mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 11:26:09 -04:00
Run black on the rest of the storage module (#4996)
This commit is contained in:
parent
3039d61baf
commit
7efd1d87c2
42 changed files with 2129 additions and 2453 deletions
|
@ -41,7 +41,7 @@ try:
|
|||
MAX_TXN_ID = sys.maxint - 1
|
||||
except AttributeError:
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2**63 - 1
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
|
@ -76,12 +76,18 @@ class LoggingTransaction(object):
|
|||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method."""
|
||||
|
||||
__slots__ = [
|
||||
"txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
|
||||
"txn",
|
||||
"name",
|
||||
"database_engine",
|
||||
"after_callbacks",
|
||||
"exception_callbacks",
|
||||
]
|
||||
|
||||
def __init__(self, txn, name, database_engine, after_callbacks,
|
||||
exception_callbacks):
|
||||
def __init__(
|
||||
self, txn, name, database_engine, after_callbacks, exception_callbacks
|
||||
):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
|
@ -110,6 +116,7 @@ class LoggingTransaction(object):
|
|||
def execute_batch(self, sql, args):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import execute_batch
|
||||
|
||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
||||
else:
|
||||
for val in args:
|
||||
|
@ -134,10 +141,7 @@ class LoggingTransaction(object):
|
|||
sql = self.database_engine.convert_param_style(sql)
|
||||
if args:
|
||||
try:
|
||||
sql_logger.debug(
|
||||
"[SQL values] {%s} %r",
|
||||
self.name, args[0]
|
||||
)
|
||||
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
|
||||
except Exception:
|
||||
# Don't let logging failures stop SQL from working
|
||||
pass
|
||||
|
@ -145,9 +149,7 @@ class LoggingTransaction(object):
|
|||
start = time.time()
|
||||
|
||||
try:
|
||||
return func(
|
||||
sql, *args
|
||||
)
|
||||
return func(sql, *args)
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
raise
|
||||
|
@ -176,11 +178,9 @@ class PerformanceCounters(object):
|
|||
counters = []
|
||||
for name, (count, cum_time) in iteritems(self.current_counters):
|
||||
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
||||
counters.append((
|
||||
(cum_time - prev_time) / interval_duration,
|
||||
count - prev_count,
|
||||
name
|
||||
))
|
||||
counters.append(
|
||||
((cum_time - prev_time) / interval_duration, count - prev_count, name)
|
||||
)
|
||||
|
||||
self.previous_counters = dict(self.current_counters)
|
||||
|
||||
|
@ -212,8 +212,9 @@ class SQLBaseStore(object):
|
|||
self._txn_perf_counters = PerformanceCounters()
|
||||
self._get_event_counters = PerformanceCounters()
|
||||
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
self._get_event_cache = Cache(
|
||||
"*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
|
||||
)
|
||||
|
||||
self._event_fetch_lock = threading.Condition()
|
||||
self._event_fetch_list = []
|
||||
|
@ -239,7 +240,7 @@ class SQLBaseStore(object):
|
|||
0.0,
|
||||
run_as_background_process,
|
||||
"upsert_safety_check",
|
||||
self._check_safe_to_upsert
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -271,7 +272,7 @@ class SQLBaseStore(object):
|
|||
15.0,
|
||||
run_as_background_process,
|
||||
"upsert_safety_check",
|
||||
self._check_safe_to_upsert
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
def start_profiling(self):
|
||||
|
@ -298,13 +299,16 @@ class SQLBaseStore(object):
|
|||
|
||||
perf_logger.info(
|
||||
"Total database time: %.3f%% {%s} {%s}",
|
||||
ratio * 100, top_three_counters, top_3_event_counters
|
||||
ratio * 100,
|
||||
top_three_counters,
|
||||
top_3_event_counters,
|
||||
)
|
||||
|
||||
self._clock.looping_call(loop, 10000)
|
||||
|
||||
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
|
||||
func, *args, **kwargs):
|
||||
def _new_transaction(
|
||||
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
|
||||
):
|
||||
start = time.time()
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
|
@ -312,7 +316,7 @@ class SQLBaseStore(object):
|
|||
# growing really large.
|
||||
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
|
||||
|
||||
name = "%s-%x" % (desc, txn_id, )
|
||||
name = "%s-%x" % (desc, txn_id)
|
||||
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
|
||||
|
@ -323,7 +327,10 @@ class SQLBaseStore(object):
|
|||
try:
|
||||
txn = conn.cursor()
|
||||
txn = LoggingTransaction(
|
||||
txn, name, self.database_engine, after_callbacks,
|
||||
txn,
|
||||
name,
|
||||
self.database_engine,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
)
|
||||
r = func(txn, *args, **kwargs)
|
||||
|
@ -334,7 +341,10 @@ class SQLBaseStore(object):
|
|||
# transaction.
|
||||
logger.warning(
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name, exception_to_unicode(e), i, N
|
||||
name,
|
||||
exception_to_unicode(e),
|
||||
i,
|
||||
N,
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
|
@ -342,8 +352,7 @@ class SQLBaseStore(object):
|
|||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warning(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, exception_to_unicode(e1),
|
||||
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
@ -357,7 +366,8 @@ class SQLBaseStore(object):
|
|||
except self.database_engine.module.Error as e1:
|
||||
logger.warning(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, exception_to_unicode(e1),
|
||||
name,
|
||||
exception_to_unicode(e1),
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
@ -396,16 +406,17 @@ class SQLBaseStore(object):
|
|||
exception_callbacks = []
|
||||
|
||||
if LoggingContext.current_context() == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Starting db txn '%s' from sentinel context",
|
||||
desc,
|
||||
)
|
||||
logger.warn("Starting db txn '%s' from sentinel context", desc)
|
||||
|
||||
try:
|
||||
result = yield self.runWithConnection(
|
||||
self._new_transaction,
|
||||
desc, after_callbacks, exception_callbacks, func,
|
||||
*args, **kwargs
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
|
@ -434,7 +445,7 @@ class SQLBaseStore(object):
|
|||
parent_context = LoggingContext.current_context()
|
||||
if parent_context == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Starting db connection from sentinel context: metrics will be lost",
|
||||
"Starting db connection from sentinel context: metrics will be lost"
|
||||
)
|
||||
parent_context = None
|
||||
|
||||
|
@ -453,9 +464,7 @@ class SQLBaseStore(object):
|
|||
return func(conn, *args, **kwargs)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
@ -469,9 +478,7 @@ class SQLBaseStore(object):
|
|||
A list of dicts where the key is the column header.
|
||||
"""
|
||||
col_headers = list(intern(str(column[0])) for column in cursor.description)
|
||||
results = list(
|
||||
dict(zip(col_headers, row)) for row in cursor
|
||||
)
|
||||
results = list(dict(zip(col_headers, row)) for row in cursor)
|
||||
return results
|
||||
|
||||
def _execute(self, desc, decoder, query, *args):
|
||||
|
@ -485,6 +492,7 @@ class SQLBaseStore(object):
|
|||
Returns:
|
||||
The result of decoder(results)
|
||||
"""
|
||||
|
||||
def interaction(txn):
|
||||
txn.execute(query, args)
|
||||
if decoder:
|
||||
|
@ -498,8 +506,7 @@ class SQLBaseStore(object):
|
|||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_insert(self, table, values, or_ignore=False,
|
||||
desc="_simple_insert"):
|
||||
def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
|
@ -511,10 +518,7 @@ class SQLBaseStore(object):
|
|||
`or_ignore` is True
|
||||
"""
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_insert_txn, table, values,
|
||||
)
|
||||
yield self.runInteraction(desc, self._simple_insert_txn, table, values)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||
# a cursor after we receive an error from the db.
|
||||
|
@ -530,15 +534,13 @@ class SQLBaseStore(object):
|
|||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys),
|
||||
", ".join("?" for _ in keys)
|
||||
", ".join("?" for _ in keys),
|
||||
)
|
||||
|
||||
txn.execute(sql, vals)
|
||||
|
||||
def _simple_insert_many(self, table, values, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_insert_many_txn, table, values
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
|
||||
|
||||
@staticmethod
|
||||
def _simple_insert_many_txn(txn, table, values):
|
||||
|
@ -553,24 +555,18 @@ class SQLBaseStore(object):
|
|||
#
|
||||
# The sort is to ensure that we don't rely on dictionary iteration
|
||||
# order.
|
||||
keys, vals = zip(*[
|
||||
zip(
|
||||
*(sorted(i.items(), key=lambda kv: kv[0]))
|
||||
)
|
||||
for i in values
|
||||
if i
|
||||
])
|
||||
keys, vals = zip(
|
||||
*[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
|
||||
)
|
||||
|
||||
for k in keys:
|
||||
if k != keys[0]:
|
||||
raise RuntimeError(
|
||||
"All items must have the same keys"
|
||||
)
|
||||
raise RuntimeError("All items must have the same keys")
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys[0]),
|
||||
", ".join("?" for _ in keys[0])
|
||||
", ".join("?" for _ in keys[0]),
|
||||
)
|
||||
|
||||
txn.executemany(sql, vals)
|
||||
|
@ -583,7 +579,7 @@ class SQLBaseStore(object):
|
|||
values,
|
||||
insertion_values={},
|
||||
desc="_simple_upsert",
|
||||
lock=True
|
||||
lock=True,
|
||||
):
|
||||
"""
|
||||
|
||||
|
@ -635,13 +631,7 @@ class SQLBaseStore(object):
|
|||
)
|
||||
|
||||
def _simple_upsert_txn(
|
||||
self,
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values={},
|
||||
lock=True,
|
||||
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||
):
|
||||
"""
|
||||
Pick the UPSERT method which works best on the platform. Either the
|
||||
|
@ -665,11 +655,7 @@ class SQLBaseStore(object):
|
|||
and table not in self._unsafe_to_upsert_tables
|
||||
):
|
||||
return self._simple_upsert_txn_native_upsert(
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||
)
|
||||
else:
|
||||
return self._simple_upsert_txn_emulated(
|
||||
|
@ -714,7 +700,7 @@ class SQLBaseStore(object):
|
|||
# SELECT instead to see if it exists.
|
||||
sql = "SELECT 1 FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sqlargs = list(keyvalues.values())
|
||||
txn.execute(sql, sqlargs)
|
||||
|
@ -726,7 +712,7 @@ class SQLBaseStore(object):
|
|||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in values),
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sqlargs = list(values.values()) + list(keyvalues.values())
|
||||
|
||||
|
@ -773,19 +759,14 @@ class SQLBaseStore(object):
|
|||
latter = "NOTHING"
|
||||
else:
|
||||
allvalues.update(values)
|
||||
latter = (
|
||||
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
)
|
||||
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
|
||||
sql = (
|
||||
"INSERT INTO %s (%s) VALUES (%s) "
|
||||
"ON CONFLICT (%s) DO %s"
|
||||
) % (
|
||||
sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues),
|
||||
", ".join(k for k in keyvalues),
|
||||
latter
|
||||
latter,
|
||||
)
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
|
||||
|
@ -870,8 +851,8 @@ class SQLBaseStore(object):
|
|||
latter = "NOTHING"
|
||||
value_values = [() for x in range(len(key_values))]
|
||||
else:
|
||||
latter = (
|
||||
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
|
||||
latter = "UPDATE SET " + ", ".join(
|
||||
k + "=EXCLUDED." + k for k in value_names
|
||||
)
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
|
||||
|
@ -889,8 +870,9 @@ class SQLBaseStore(object):
|
|||
|
||||
return txn.execute_batch(sql, args)
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
def _simple_select_one(
|
||||
self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
|
@ -903,14 +885,17 @@ class SQLBaseStore(object):
|
|||
statement returns no rows
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_txn,
|
||||
table, keyvalues, retcols, allow_none,
|
||||
desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol"):
|
||||
def _simple_select_one_onecol(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol",
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
|
@ -922,17 +907,18 @@ class SQLBaseStore(object):
|
|||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_onecol_txn,
|
||||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
def _simple_select_one_onecol_txn(
|
||||
cls, txn, table, keyvalues, retcol, allow_none=False
|
||||
):
|
||||
ret = cls._simple_select_onecol_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcol=retcol,
|
||||
txn, table=table, keyvalues=keyvalues, retcol=retcol
|
||||
)
|
||||
|
||||
if ret:
|
||||
|
@ -945,12 +931,7 @@ class SQLBaseStore(object):
|
|||
|
||||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
}
|
||||
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
|
||||
|
||||
if keyvalues:
|
||||
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
|
||||
|
@ -960,8 +941,9 @@ class SQLBaseStore(object):
|
|||
|
||||
return [r[0] for r in txn]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
def _simple_select_onecol(
|
||||
self, table, keyvalues, retcol, desc="_simple_select_onecol"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
|
@ -974,13 +956,12 @@ class SQLBaseStore(object):
|
|||
Deferred: Results in a list
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
desc, self._simple_select_onecol_txn, table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols,
|
||||
desc="_simple_select_list"):
|
||||
def _simple_select_list(
|
||||
self, table, keyvalues, retcols, desc="_simple_select_list"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
|
@ -994,9 +975,7 @@ class SQLBaseStore(object):
|
|||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
desc, self._simple_select_list_txn, table, keyvalues, retcols
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1016,22 +995,26 @@ class SQLBaseStore(object):
|
|||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
else:
|
||||
sql = "SELECT %s FROM %s" % (
|
||||
", ".join(retcols),
|
||||
table
|
||||
)
|
||||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||
txn.execute(sql)
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
||||
keyvalues={}, desc="_simple_select_many_batch",
|
||||
batch_size=100):
|
||||
def _simple_select_many_batch(
|
||||
self,
|
||||
table,
|
||||
column,
|
||||
iterable,
|
||||
retcols,
|
||||
keyvalues={},
|
||||
desc="_simple_select_many_batch",
|
||||
batch_size=100,
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
|
@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
|
|||
it_list = list(iterable)
|
||||
|
||||
chunks = [
|
||||
it_list[i:i + batch_size]
|
||||
for i in range(0, len(it_list), batch_size)
|
||||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
||||
]
|
||||
for chunk in chunks:
|
||||
rows = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_many_txn,
|
||||
table, column, chunk, keyvalues, retcols
|
||||
table,
|
||||
column,
|
||||
chunk,
|
||||
keyvalues,
|
||||
retcols,
|
||||
)
|
||||
|
||||
results.extend(rows)
|
||||
|
@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
|
|||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in iteritems(keyvalues):
|
||||
|
@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
|
|||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
||||
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
desc, self._simple_update_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
|
|||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
list(updatevalues.values()) + list(keyvalues.values())
|
||||
)
|
||||
txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
|
||||
|
||||
return txn.rowcount
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
desc="_simple_update_one"):
|
||||
def _simple_update_one(
|
||||
self, table, keyvalues, updatevalues, desc="_simple_update_one"
|
||||
):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
|
@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
|
|||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_one_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
desc, self._simple_update_one_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
|
|||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
@staticmethod
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
txn.execute(select_sql, list(keyvalues.values()))
|
||||
|
@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
|
|||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_one_txn, table, keyvalues
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_one_txn(txn, table, keyvalues):
|
||||
|
@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
|
|||
"""
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
|
@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
|
|||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
def _simple_delete(self, table, keyvalues, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_txn, table, keyvalues
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_txn(txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
return txn.execute(sql, list(keyvalues.values()))
|
||||
|
@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
|
|||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in iteritems(keyvalues):
|
||||
|
@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
|
|||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
||||
return txn.execute(sql, values)
|
||||
|
||||
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
|
||||
max_value, limit=100000):
|
||||
def _get_cache_dict(
|
||||
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
|
||||
):
|
||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||
# It doesn't really matter how many we get, the StreamChangeCache will
|
||||
# do the right thing to ensure it respects the max size of cache.
|
||||
|
@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
|
|||
txn = db_conn.cursor()
|
||||
txn.execute(sql, (int(max_value),))
|
||||
|
||||
cache = {
|
||||
row[0]: int(row[1])
|
||||
for row in txn
|
||||
}
|
||||
cache = {row[0]: int(row[1]) for row in txn}
|
||||
|
||||
txn.close()
|
||||
|
||||
|
@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
|
|||
# be safe.
|
||||
for chunk in batch_iter(members_changed, 50):
|
||||
keys = itertools.chain([room_id], chunk)
|
||||
self._send_invalidation_to_replication(
|
||||
txn, _CURRENT_STATE_CACHE_NAME, keys,
|
||||
)
|
||||
self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
|
||||
|
||||
def _invalidate_state_caches(self, room_id, members_changed):
|
||||
"""Invalidates caches that are based on the current state, but does
|
||||
|
@ -1356,22 +1317,12 @@ class SQLBaseStore(object):
|
|||
changed
|
||||
"""
|
||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||
self._attempt_to_invalidate_cache(
|
||||
"is_host_joined", (room_id, host,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"was_host_joined", (room_id, host,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
||||
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_users_in_room", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_room_summary", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_current_state_ids", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
|
||||
|
||||
def _attempt_to_invalidate_cache(self, cache_name, key):
|
||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||
|
@ -1419,7 +1370,7 @@ class SQLBaseStore(object):
|
|||
"cache_func": cache_name,
|
||||
"keys": list(keys),
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
|
@ -1435,11 +1386,10 @@ class SQLBaseStore(object):
|
|||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit,))
|
||||
txn.execute(sql, (last_id, limit))
|
||||
return txn.fetchall()
|
||||
return self.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
|
@ -1447,8 +1397,9 @@ class SQLBaseStore(object):
|
|||
else:
|
||||
return 0
|
||||
|
||||
def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
|
||||
desc="_simple_select_list_paginate"):
|
||||
def _simple_select_list_paginate(
|
||||
self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
|
||||
):
|
||||
"""Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
|
@ -1468,11 +1419,16 @@ class SQLBaseStore(object):
|
|||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_paginate_txn,
|
||||
table, keyvalues, pagevalues, retcols
|
||||
table,
|
||||
keyvalues,
|
||||
pagevalues,
|
||||
retcols,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
|
||||
def _simple_select_list_paginate_txn(
|
||||
cls, txn, table, keyvalues, pagevalues, retcols
|
||||
):
|
||||
"""Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
|
@ -1497,22 +1453,23 @@ class SQLBaseStore(object):
|
|||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
" ? ASC LIMIT ? OFFSET ?"
|
||||
" ? ASC LIMIT ? OFFSET ?",
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
|
||||
else:
|
||||
sql = "SELECT %s FROM %s ORDER BY %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" ? ASC LIMIT ? OFFSET ?"
|
||||
" ? ASC LIMIT ? OFFSET ?",
|
||||
)
|
||||
txn.execute(sql, pagevalues)
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
|
||||
desc="get_user_list_paginate"):
|
||||
def get_user_list_paginate(
|
||||
self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
|
||||
):
|
||||
"""Get a list of users from start row to a limit number of rows. This will
|
||||
return a json object with users and total number of users in users list.
|
||||
|
||||
|
@ -1532,16 +1489,13 @@ class SQLBaseStore(object):
|
|||
users = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_paginate_txn,
|
||||
table, keyvalues, pagevalues, retcols
|
||||
table,
|
||||
keyvalues,
|
||||
pagevalues,
|
||||
retcols,
|
||||
)
|
||||
count = yield self.runInteraction(
|
||||
desc,
|
||||
self.get_user_count_txn
|
||||
)
|
||||
retval = {
|
||||
"users": users,
|
||||
"total": count
|
||||
}
|
||||
count = yield self.runInteraction(desc, self.get_user_count_txn)
|
||||
retval = {"users": users, "total": count}
|
||||
defer.returnValue(retval)
|
||||
|
||||
def get_user_count_txn(self, txn):
|
||||
|
@ -1556,8 +1510,9 @@ class SQLBaseStore(object):
|
|||
txn.execute(sql_count)
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def _simple_search_list(self, table, term, col, retcols,
|
||||
desc="_simple_search_list"):
|
||||
def _simple_search_list(
|
||||
self, table, term, col, retcols, desc="_simple_search_list"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
|
@ -1572,9 +1527,7 @@ class SQLBaseStore(object):
|
|||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_search_list_txn,
|
||||
table, term, col, retcols
|
||||
desc, self._simple_search_list_txn, table, term, col, retcols
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1593,11 +1546,7 @@ class SQLBaseStore(object):
|
|||
defer.Deferred: resolves to list[dict[str, Any]] or None
|
||||
"""
|
||||
if term:
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
col
|
||||
)
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
|
||||
termvalues = ["%%" + term + "%%"]
|
||||
txn.execute(sql, termvalues)
|
||||
else:
|
||||
|
@ -1618,6 +1567,7 @@ class _RollbackButIsFineException(Exception):
|
|||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue