synapse-product/synapse/storage/_base.py
Daniel Wagner-Hall cecbd636e9 /tokenrefresh POST endpoint
This allows refresh tokens to be exchanged for (access_token,
refresh_token).

It also starts issuing them on login, though no clients currently
interpret them.
2015-08-20 16:21:35 +01:00

920 lines
31 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
from collections import namedtuple
import sys
import time
import threading
# General-purpose logger for this module.
logger = logging.getLogger(__name__)
# Per-statement SQL logging (statement text, bound values and timings).
sql_logger = logging.getLogger("synapse.storage.SQL")
# Transaction lifecycle logging ("[TXN START]" / "[TXN END]").
transaction_logger = logging.getLogger("synapse.storage.txn")
# Periodic "Total database time" summary emitted by start_profiling().
perf_logger = logging.getLogger("synapse.storage.TIME")
# Metrics registry shared by the distributions below.
metrics = synapse.metrics.get_metrics_for("synapse.storage")
# Delay between a DB function being requested and it starting to run on a
# connection.
sql_scheduling_timer = metrics.register_distribution("schedule_time")
# Time spent per SQL statement, labelled by the first word of the statement.
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
# Time spent per transaction, labelled by the transaction description.
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
class LoggingTransaction(object):
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging and metrics to the .execute()
    and .executemany() methods.

    Any attribute that is not one of the slots below is read from, and
    written to, the wrapped 'txn' object.
    """
    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]

    def __init__(self, txn, name, database_engine, after_callbacks):
        """
        Args:
            txn : the cursor-like object to wrap
            name : label included in every log line for this transaction
            database_engine : object providing convert_param_style(sql)
            after_callbacks : list that call_after() appends to
        """
        # __setattr__ below forwards to the wrapped txn, so our own slots
        # have to be populated via object.__setattr__.
        object.__setattr__(self, "txn", txn)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "database_engine", database_engine)
        object.__setattr__(self, "after_callbacks", after_callbacks)

    def call_after(self, callback, *args):
        """Call the given callback on the main twisted thread after the
        transaction has finished. Used to invalidate the caches on the
        correct thread.

        This method only records (callback, args); running the callbacks is
        the responsibility of whoever owns the after_callbacks list.
        """
        self.after_callbacks.append((callback, args))

    def __getattr__(self, name):
        # Only reached for names that are not slots on this object.
        return getattr(self.txn, name)

    def __setattr__(self, name, value):
        setattr(self.txn, name, value)

    def execute(self, sql, *args):
        """Run txn.execute(sql, *args) with logging and timing.

        The result of the underlying call is not returned.
        """
        self._do_execute(self.txn.execute, sql, *args)

    def executemany(self, sql, *args):
        """Run txn.executemany(sql, *args) with logging and timing.

        The result of the underlying call is not returned.
        """
        self._do_execute(self.txn.executemany, sql, *args)

    def _do_execute(self, func, sql, *args):
        """Log, time and run a single call on the wrapped transaction.

        Args:
            func : bound execute/executemany of the wrapped txn
            sql : statement text; passed through the database engine's
                convert_param_style() before being handed to func
            *args : forwarded to func unchanged
        Returns:
            Whatever func returns.
        Raises:
            Whatever func raises (after logging it at debug level).
        """
        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)

        sql = self.database_engine.convert_param_style(sql)

        if args:
            try:
                sql_logger.debug(
                    "[SQL values] {%s} %r",
                    self.name, args[0]
                )
            except Exception:
                # Don't let logging failures stop SQL from working. Only
                # ordinary exceptions are suppressed so that
                # KeyboardInterrupt / SystemExit still propagate.
                pass

        start = time.time() * 1000

        try:
            return func(
                sql, *args
            )
        except Exception as e:
            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
            raise
        finally:
            msecs = (time.time() * 1000) - start
            sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
            # Label the timing with the first word of the statement.
            sql_query_timer.inc_by(msecs, sql.split()[0])
class PerformanceCounters(object):
    """Keeps a running (count, total duration) per key and can report which
    keys consumed the largest share of a given interval."""

    def __init__(self):
        # key -> (number of update() calls, summed duration)
        self.current_counters = {}
        # Copy of current_counters as it was at the last interval() call.
        self.previous_counters = {}

    def update(self, key, start_time, end_time=None):
        """Record one more occurrence of `key`.

        Args:
            key : counter to update
            start_time : start timestamp
            end_time : end timestamp; defaults to time.time() * 1000
        Returns:
            The end time that was used.
        """
        if end_time is None:
            end_time = time.time() * 1000

        seen, total = self.current_counters.get(key, (0, 0))
        self.current_counters[key] = (seen + 1, total + (end_time - start_time))
        return end_time

    def interval(self, interval_duration, limit=3):
        """Summarise activity since the previous call.

        Args:
            interval_duration : length of the interval, in the same units as
                the recorded durations
            limit : maximum number of keys to include
        Returns:
            A string of the form "name(count): pct%, ..." listing the keys
            with the largest share of the interval first.
        """
        earlier = self.previous_counters

        deltas = []
        for key, (seen, total) in self.current_counters.items():
            seen_before, total_before = earlier.get(key, (0, 0))
            share = (total - total_before) / interval_duration
            deltas.append((share, seen - seen_before, key))

        # Snapshot for the next call before ranking this one.
        self.previous_counters = dict(self.current_counters)

        busiest = sorted(deltas, reverse=True)[:limit]
        return ", ".join(
            "%s(%d): %.3f%%" % (key, calls, 100 * share)
            for share, calls, key in busiest
        )
class SQLBaseStore(object):
    """Base class for the SQL-backed stores.

    Provides the transaction plumbing (runInteraction / runWithConnection),
    coarse performance accounting, and a family of "_simple_*" helpers that
    build single-table SQL statements from dicts of column name -> value.
    """

    # Counter used to give each transaction a distinguishing suffix in logs.
    _TXN_ID = 0

    def __init__(self, hs):
        """
        Args:
            hs : object providing get_db_pool(), get_clock(),
                config.event_cache_size and database_engine
        """
        self.hs = hs
        self._db_pool = hs.get_db_pool()
        self._clock = hs.get_clock()

        self._previous_txn_total_time = 0
        self._current_txn_total_time = 0

        self._previous_loop_ts = 0

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()
        self._get_event_counters = PerformanceCounters()

        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                      max_entries=hs.config.event_cache_size)

        self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000)

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0

        self._pending_ds = []

        self.database_engine = hs.database_engine

        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
        self._pushers_id_gen = IdGenerator("pushers", "id", self)
        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")

    def start_profiling(self):
        """Start periodically logging the share of wall-clock time spent in
        database transactions, together with the top three transaction and
        get-event counters for the period.
        """
        self._previous_loop_ts = self._clock.time_msec()

        def loop():
            curr = self._current_txn_total_time
            prev = self._previous_txn_total_time
            self._previous_txn_total_time = curr

            time_now = self._clock.time_msec()
            time_then = self._previous_loop_ts
            self._previous_loop_ts = time_now

            ratio = (curr - prev)/(time_now - time_then)

            top_three_counters = self._txn_perf_counters.interval(
                time_now - time_then, limit=3
            )

            top_3_event_counters = self._get_event_counters.interval(
                time_now - time_then, limit=3
            )

            perf_logger.info(
                "Total database time: %.3f%% {%s} {%s}",
                ratio * 100, top_three_counters, top_3_event_counters
            )

        # Interval is presumably in milliseconds (i.e. every 10s) - confirm
        # against the clock implementation.
        self._clock.looping_call(loop, 10000)

    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
        """Run func inside a transaction on the given connection.

        A fresh cursor is wrapped in a LoggingTransaction and handed to func;
        on success the connection is committed. If the engine raises
        OperationalError, or a DatabaseError that the engine reports as a
        deadlock, the connection is rolled back and func is retried, up to
        N (5) retries.

        Args:
            conn : database connection providing cursor()/commit()/rollback()
            desc : description used for logging and metrics
            after_callbacks : list handed to the LoggingTransaction for
                call_after()
            func : called as func(txn, *args, **kwargs)
        Returns:
            Whatever func returns.
        """
        start = time.time() * 1000
        txn_id = self._TXN_ID

        # We don't really need these to be unique, so lets stop it from
        # growing really large. (sys.maxsize rather than sys.maxint so that
        # this also works on Python 3.)
        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxsize - 1)

        name = "%s-%x" % (desc, txn_id, )

        transaction_logger.debug("[TXN START] {%s}", name)

        try:
            i = 0
            N = 5
            while True:
                try:
                    txn = conn.cursor()
                    txn = LoggingTransaction(
                        txn, name, self.database_engine, after_callbacks
                    )
                    r = func(txn, *args, **kwargs)
                    conn.commit()
                    return r
                except self.database_engine.module.OperationalError as e:
                    # This can happen if the database disappears mid
                    # transaction.
                    logger.warning(
                        "[TXN OPERROR] {%s} %s %d/%d",
                        name, e, i, N
                    )

                    if i < N:
                        i += 1
                        try:
                            conn.rollback()
                        except self.database_engine.module.Error as e1:
                            logger.warning(
                                "[TXN EROLL] {%s} %s",
                                name, e1,
                            )
                        continue
                    raise
                except self.database_engine.module.DatabaseError as e:
                    if self.database_engine.is_deadlock(e):
                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                        if i < N:
                            i += 1
                            try:
                                conn.rollback()
                            except self.database_engine.module.Error as e1:
                                logger.warning(
                                    "[TXN EROLL] {%s} %s",
                                    name, e1,
                                )
                            continue
                    raise
        except Exception as e:
            logger.debug("[TXN FAIL] {%s} %s", name, e)
            raise
        finally:
            end = time.time() * 1000
            duration = end - start

            transaction_logger.debug("[TXN END] {%s} %f", name, duration)

            self._current_txn_total_time += duration
            self._txn_perf_counters.update(desc, start, end)
            sql_txn_timer.inc_by(duration, desc)

    @defer.inlineCallbacks
    def runInteraction(self, desc, func, *args, **kwargs):
        """Run func in a new transaction on a connection obtained through
        the db_pool's runWithConnection().

        Once the transaction has completed successfully, every callback
        registered through LoggingTransaction.call_after() is invoked.

        Args:
            desc : description of the transaction, for logging and metrics
            func : called as func(txn, *args, **kwargs)
        Returns:
            Deferred: resolves to the result of func.
        """
        current_context = LoggingContext.current_context()

        start_time = time.time() * 1000

        after_callbacks = []

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runInteraction") as context:
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                current_context.copy_to(context)
                return self._new_transaction(
                    conn, desc, after_callbacks, func, *args, **kwargs
                )

        result = yield preserve_context_over_fn(
            self._db_pool.runWithConnection,
            inner_func, *args, **kwargs
        )

        for after_callback, after_args in after_callbacks:
            after_callback(*after_args)
        defer.returnValue(result)

    @defer.inlineCallbacks
    def runWithConnection(self, func, *args, **kwargs):
        """Wraps the .runWithConnection() method on the underlying db_pool.

        Unlike runInteraction, no transaction is opened here: func receives
        the connection itself.

        Args:
            func : called as func(conn, *args, **kwargs)
        Returns:
            Deferred: resolves to the result of func.
        """
        current_context = LoggingContext.current_context()

        start_time = time.time() * 1000

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runWithConnection") as context:
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                current_context.copy_to(context)

                return func(conn, *args, **kwargs)

        result = yield preserve_context_over_fn(
            self._db_pool.runWithConnection,
            inner_func, *args, **kwargs
        )

        defer.returnValue(result)

    def cursor_to_dict(self, cursor):
        """Converts a SQL cursor into an list of dicts.

        Args:
            cursor : The DBAPI cursor which has executed a query.
        Returns:
            A list of dicts where the key is the column header.
        """
        col_headers = [column[0] for column in cursor.description]
        return [dict(zip(col_headers, row)) for row in cursor.fetchall()]

    def _execute(self, desc, decoder, query, *args):
        """Runs a single query for a result set.

        Args:
            desc - Description of the transaction, for logging and metrics
            decoder - The function which can resolve the cursor results to
                something meaningful. If falsy, txn.fetchall() is returned.
            query - The query string to execute
            *args - Query args.
        Returns:
            Deferred: the result of decoder(txn), or of txn.fetchall()
        """
        def interaction(txn):
            txn.execute(query, args)
            if decoder:
                return decoder(txn)
            else:
                return txn.fetchall()

        return self.runInteraction(desc, interaction)

    def _execute_and_decode(self, desc, query, *args):
        """Runs a single query and returns the rows as a list of dicts (see
        cursor_to_dict)."""
        return self._execute(desc, self.cursor_to_dict, query, *args)

    # "Simple" SQL API methods that operate on a single table with no JOINs,
    # no complex WHERE clauses, just a dict of values for columns.

    @defer.inlineCallbacks
    def _simple_insert(self, table, values, or_ignore=False,
                       desc="_simple_insert"):
        """Executes an INSERT query on the named table.

        Args:
            table : string giving the table name
            values : dict of new column names and values for them
            or_ignore : if true, an IntegrityError from the database is
                swallowed instead of being raised
            desc : description of the transaction, for logging and metrics
        """
        try:
            yield self.runInteraction(
                desc,
                self._simple_insert_txn, table, values,
            )
        except self.database_engine.module.IntegrityError:
            # We have to do or_ignore flag at this layer, since we can't reuse
            # a cursor after we receive an error from the db.
            if not or_ignore:
                raise

    @log_function
    def _simple_insert_txn(self, txn, table, values):
        """Inserts one row built from the `values` dict into `table`."""
        keys, vals = zip(*values.items())

        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(keys),
            ", ".join("?" for _ in keys)
        )

        txn.execute(sql, vals)

    def _simple_insert_many_txn(self, txn, table, values):
        """Inserts several rows into `table` with a single executemany().

        Args:
            txn : Transaction object
            table : string giving the table name
            values : list of dicts of column name -> value; every non-empty
                dict must have the same set of keys. Empty dicts are skipped.
        Raises:
            RuntimeError: if the dicts do not all share the same keys.
        """
        if not values:
            return

        # Sort each row's items by column name so that we don't rely on
        # dictionary iteration order.
        rows = [
            sorted(row.items(), key=lambda kv: kv[0])
            for row in values
            if row
        ]

        if not rows:
            # Every dict was empty: there is nothing to insert.
            return

        # This is a *slight* abomination to get a list of tuples of key names
        # and a list of tuples of value names.
        #
        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
        keys, vals = zip(*[zip(*row) for row in rows])

        for k in keys:
            if k != keys[0]:
                raise RuntimeError(
                    "All items must have the same keys"
                )

        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(keys[0]),
            ", ".join("?" for _ in keys[0])
        )

        txn.executemany(sql, vals)

    def _simple_upsert(self, table, keyvalues, values,
                       insertion_values=None, desc="_simple_upsert", lock=True):
        """
        Args:
            table (str): The table to upsert into
            keyvalues (dict): The unique key tables and their new values
            values (dict): The nonunique columns and their new values
            insertion_values (dict): key/values to use when inserting
            desc (str): description of the transaction
            lock (bool): whether to lock the table before updating
        Returns: A deferred
        """
        return self.runInteraction(
            desc,
            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
            lock
        )

    def _simple_upsert_txn(self, txn, table, keyvalues, values,
                           insertion_values=None, lock=True):
        """Updates the row matching `keyvalues`, or inserts one if no row
        was updated.

        Args:
            txn : Transaction object
            table (str): The table to upsert into
            keyvalues (dict): The unique key columns and their values
            values (dict): The nonunique columns and their new values
            insertion_values (dict): extra key/values used only when a new
                row has to be inserted
            lock (bool): whether to lock the table first
        """
        if insertion_values is None:
            insertion_values = {}

        # We need to lock the table :(, unless we're *really* careful
        if lock:
            self.database_engine.lock_table(txn, table)

        # Try to update
        sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in values),
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )
        sqlargs = list(values.values()) + list(keyvalues.values())
        logger.debug(
            "[SQL] %s Args=%s",
            sql, sqlargs,
        )

        txn.execute(sql, sqlargs)
        if txn.rowcount == 0:
            # We didn't update any rows so insert a new one
            allvalues = {}
            allvalues.update(keyvalues)
            allvalues.update(values)
            allvalues.update(insertion_values)

            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                table,
                ", ".join(allvalues),
                ", ".join("?" for _ in allvalues)
            )
            insert_args = list(allvalues.values())
            # Log the arguments that are actually bound to the INSERT.
            logger.debug(
                "[SQL] %s Args=%s",
                sql, insert_args,
            )
            txn.execute(sql, insert_args)

    def _simple_select_one(self, table, keyvalues, retcols,
                           allow_none=False, desc="_simple_select_one"):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning the requested columns from it as a
        dict.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcols : list of strings giving the names of the columns to return
            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
        """
        return self.runInteraction(
            desc,
            self._simple_select_one_txn,
            table, keyvalues, retcols, allow_none,
        )

    def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                  allow_none=False,
                                  desc="_simple_select_one_onecol"):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning a single column from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcol : string giving the name of the column to return
            allow_none : If true, return None instead of failing if the SELECT
              statement returns no rows
        """
        return self.runInteraction(
            desc,
            self._simple_select_one_onecol_txn,
            table, keyvalues, retcol, allow_none=allow_none,
        )

    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
                                      allow_none=False):
        """Returns `retcol` from the first row matching `keyvalues`.

        Raises:
            StoreError: 404 if no row matched and allow_none is false.
        """
        ret = self._simple_select_onecol_txn(
            txn,
            table=table,
            keyvalues=keyvalues,
            retcol=retcol,
        )

        if ret:
            return ret[0]
        else:
            if allow_none:
                return None
            else:
                raise StoreError(404, "No row found")

    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
        """Returns a list of the values of `retcol` for every row matching
        `keyvalues`."""
        sql = (
            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
        ) % {
            "retcol": retcol,
            "table": table,
            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
        }

        txn.execute(sql, list(keyvalues.values()))

        return [r[0] for r in txn.fetchall()]

    def _simple_select_onecol(self, table, keyvalues, retcol,
                              desc="_simple_select_onecol"):
        """Executes a SELECT query on the named table, which returns a list
        comprising of the values of the named column from the selected rows.

        Args:
            table (str): table name
            keyvalues (dict): column names and values to select the rows with
            retcol (str): column whose value we wish to retrieve.

        Returns:
            Deferred: Results in a list
        """
        return self.runInteraction(
            desc,
            self._simple_select_onecol_txn,
            table, keyvalues, retcol
        )

    def _simple_select_list(self, table, keyvalues, retcols,
                            desc="_simple_select_list"):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with,
            or None to not apply a WHERE clause.
            retcols : list of strings giving the names of the columns to return
        """
        return self.runInteraction(
            desc,
            self._simple_select_list_txn,
            table, keyvalues, retcols
        )

    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            txn : Transaction object
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows
                with; if empty or None, no WHERE clause is applied
            retcols : list of strings giving the names of the columns to return
        """
        if keyvalues:
            sql = "SELECT %s FROM %s WHERE %s" % (
                ", ".join(retcols),
                table,
                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
            )
            txn.execute(sql, list(keyvalues.values()))
        else:
            sql = "SELECT %s FROM %s" % (
                ", ".join(retcols),
                table
            )
            txn.execute(sql)

        return self.cursor_to_dict(txn)

    def _simple_update_one(self, table, keyvalues, updatevalues,
                           desc="_simple_update_one"):
        """Executes an UPDATE query on the named table, setting new values for
        columns in a row matching the key values.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            updatevalues : dict giving column names and values to update

        Putting the update column in the 'keyvalues' dict as well gives a
        compare-and-set. To read columns back in the same transaction, see
        _simple_selectupdate_one.
        """
        return self.runInteraction(
            desc,
            self._simple_update_one_txn,
            table, keyvalues, updatevalues,
        )

    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
        """Updates the single row matching `keyvalues`.

        Raises:
            StoreError: 404 if no row was updated, 500 if more than one was.
        """
        update_sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in updatevalues),
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )

        txn.execute(
            update_sql,
            list(updatevalues.values()) + list(keyvalues.values())
        )

        if txn.rowcount == 0:
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched")

    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
                               allow_none=False):
        """Selects `retcols` from the row matching `keyvalues`.

        Returns:
            dict of column name -> value, or None if no row matched and
            allow_none is true.
        Raises:
            StoreError: 404 if no row matched and allow_none is false, 500
            if txn.rowcount reports more than one row.
        """
        select_sql = "SELECT %s FROM %s WHERE %s" % (
            ", ".join(retcols),
            table,
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )

        txn.execute(select_sql, list(keyvalues.values()))

        row = txn.fetchone()
        if not row:
            if allow_none:
                return None
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched")

        return dict(zip(retcols, row))

    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
                                 retcols=None, allow_none=False,
                                 desc="_simple_selectupdate_one"):
        """ Combined SELECT then UPDATE.

        If retcols is given, those columns are SELECTed *before* the UPDATE
        is performed and returned as a dict. Both statements run within the
        same transaction, allowing an atomic get-and-set.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            updatevalues : optional dict giving column names and values to
                update
            retcols : optional list of column names to return
            allow_none : passed to the SELECT; if true, a missing row yields
                None rather than an error
        """
        def func(txn):
            ret = None
            if retcols:
                ret = self._simple_select_one_txn(
                    txn,
                    table=table,
                    keyvalues=keyvalues,
                    retcols=retcols,
                    allow_none=allow_none,
                )

            if updatevalues:
                self._simple_update_one_txn(
                    txn,
                    table=table,
                    keyvalues=keyvalues,
                    updatevalues=updatevalues,
                )

                # A rowcount of zero is deliberately not checked here.
                if txn.rowcount > 1:
                    raise StoreError(500, "More than one row matched")

            return ret
        return self.runInteraction(desc, func)

    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
        """Executes a DELETE query on the named table, expecting to delete a
        single row.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        Raises:
            StoreError: 404 if no row was deleted, 500 if more than one was.
        """
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        def func(txn):
            txn.execute(sql, list(keyvalues.values()))
            if txn.rowcount == 0:
                raise StoreError(404, "No row found")
            if txn.rowcount > 1:
                raise StoreError(500, "more than one row matched")
        return self.runInteraction(desc, func)

    def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
        """Executes a DELETE query on the named table.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        """
        # table and keyvalues must be forwarded: _simple_delete_txn cannot be
        # called with the transaction alone.
        return self.runInteraction(
            desc, self._simple_delete_txn, table, keyvalues
        )

    def _simple_delete_txn(self, txn, table, keyvalues):
        """Deletes every row of `table` matching `keyvalues`."""
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        return txn.execute(sql, list(keyvalues.values()))

    def _simple_max_id(self, table):
        """Executes a SELECT query on the named table, expecting to return the
        max value for the column "id".

        Args:
            table : string giving the table name
        Returns:
            Deferred: the maximum id, or 0 if MAX(id) is NULL.
        """
        sql = "SELECT MAX(id) AS id FROM %s" % table

        def func(txn):
            txn.execute(sql)
            max_id = self.cursor_to_dict(txn)[0]["id"]
            if max_id is None:
                return 0
            return max_id

        return self.runInteraction("_simple_max_id", func)

    def get_next_stream_id(self):
        """Returns the current value of self._next_stream_id and increments
        it, while holding self._next_stream_id_lock.

        NOTE(review): neither attribute is initialised in __init__ above;
        presumably they are set elsewhere (e.g. by a subclass) - confirm.
        """
        with self._next_stream_id_lock:
            i = self._next_stream_id
            self._next_stream_id += 1
            return i
class _RollbackButIsFineException(Exception):
    """Raised to roll a transaction back in cases where nothing has
    actually gone wrong."""
class Table(object):
    """Base class describing a single table: its name, its columns and the
    tuple type that result rows are decoded into. Subclasses fill in the
    three attributes below.
    """

    # str: The name of the table
    table_name = None

    # list: The field names
    fields = None

    # Type: A tuple type used to decode the results
    EntryType = None

    _select_where_clause = "SELECT %s FROM %s WHERE %s"
    _select_clause = "SELECT %s FROM %s"
    _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"

    @classmethod
    def select_statement(cls, where_clause=None):
        """Build a SELECT over all of the table's fields.

        Args:
            where_clause (str): The WHERE clause to use, if any.

        Returns:
            str: An SQL statement to select rows from the table, restricted
            by the WHERE clause when one is given.
        """
        columns = ", ".join(cls.fields)
        if not where_clause:
            return cls._select_clause % (columns, cls.table_name)
        return cls._select_where_clause % (columns, cls.table_name, where_clause)

    @classmethod
    def insert_statement(cls):
        """Build a "REPLACE INTO" statement with one "?" placeholder per
        field.

        Returns:
            str: The SQL statement.
        """
        columns = ", ".join(cls.fields)
        placeholders = ", ".join("?" for _ in range(len(cls.fields)))
        return cls._insert_clause % (cls.table_name, columns, placeholders)

    @classmethod
    def decode_single_result(cls, results):
        """Convert the first row of `results` into an `EntryType`.

        Args:
            results (list): An iterable of row tuples

        Returns:
            EntryType: built from the first row, or None if there are no
            rows.
        """
        rows = list(results)
        return cls.EntryType(*rows[0]) if rows else None

    @classmethod
    def decode_results(cls, results):
        """Convert every row of `results` into an `EntryType`.

        Args:
            results (list): An iterable of row tuples

        Returns:
            list: A list of `EntryType`
        """
        decoded = []
        for row in results:
            decoded.append(cls.EntryType(*row))
        return decoded

    @classmethod
    def get_fields_string(cls, prefix=None):
        """Return the field names joined with ", ", each qualified as
        "prefix.field" when a prefix is given."""
        if not prefix:
            return ", ".join(cls.fields)
        return ", ".join("%s.%s" % (prefix, f) for f in cls.fields)
class JoinHelper(object):
    """Helps with joins across several tables: collects the tables' field
    names without repeats (first table to declare a field wins) and builds
    a namedtuple type for decoding the joined rows.

    Attributes:
        tables (list): List of `Table` classes
        EntryType (type)
    """

    def __init__(self, *tables):
        self.tables = tables

        unique_fields = []
        for tbl in tables:
            # Only fields not contributed by an earlier table are added.
            fresh = [name for name in tbl.fields if name not in unique_fields]
            unique_fields.extend(fresh)

        self.EntryType = namedtuple("JoinHelperEntry", unique_fields)

    def get_fields(self, **prefixes):
        """Get a string representing a list of fields for use in SELECT
        statements with the given prefixes applied to each.

        `prefixes` maps a table class's name to the prefix to use for the
        fields taken from that table.

        For example::

            JoinHelper(PdusTable, StateTable).get_fields(
                PdusTable="pdus",
                StateTable="state"
            )
        """
        qualified = []
        for field in self.EntryType._fields:
            # The first table declaring the field supplies its prefix.
            owner = next(
                (tbl for tbl in self.tables if field in tbl.fields), None
            )
            if owner is not None:
                qualified.append("%s.%s" % (prefixes[owner.__name__], field))
        return ", ".join(qualified)

    def decode_results(self, rows):
        """Convert each row tuple into an instance of `EntryType`."""
        decoded = []
        for row in rows:
            decoded.append(self.EntryType(*row))
        return decoded