Pass Database into the data store

This commit is contained in:
Erik Johnston 2019-12-06 13:40:02 +00:00
parent d64bb32a73
commit d537be1ebd
5 changed files with 24 additions and 28 deletions

View File

@ -238,8 +238,7 @@ class HomeServer(object):
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
with self.get_db_conn() as conn: with self.get_db_conn() as conn:
datastore = self.DATASTORE_CLASS(conn, self) self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
self.datastores = DataStores(datastore, conn, self)
conn.commit() conn.commit()
self.start_time = int(self.get_clock().time()) self.start_time = int(self.get_clock().time())
logger.info("Finished setting up.") logger.info("Finished setting up.")

View File

@ -41,7 +41,7 @@ class SQLBaseStore(object):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self.db = Database(hs) # In future this will be passed in self.db = database
self.rand = random.SystemRandom() self.rand = random.SystemRandom()
def _invalidate_state_caches(self, room_id, members_changed): def _invalidate_state_caches(self, room_id, members_changed):

View File

@ -379,7 +379,7 @@ class BackgroundUpdater(object):
logger.debug("[SQL] %s", sql) logger.debug("[SQL] %s", sql)
c.execute(sql) c.execute(sql)
if isinstance(self.db.database_engine, engines.PostgresEngine): if isinstance(self.db.engine, engines.PostgresEngine):
runner = create_index_psql runner = create_index_psql
elif psql_only: elif psql_only:
runner = None runner = None

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.database import Database
class DataStores(object): class DataStores(object):
"""The various data stores. """The various data stores.
@ -20,7 +22,8 @@ class DataStores(object):
These are low level interfaces to physical databases. These are low level interfaces to physical databases.
""" """
def __init__(self, main_store, db_conn, hs): def __init__(self, main_store_class, db_conn, hs):
# Note we pass in the main store here as workers use a different main # Note we pass in the main store here as workers use a different main
# store. # store.
self.main = main_store database = Database(hs)
self.main = main_store_class(database, db_conn, hs)

View File

@ -234,7 +234,7 @@ class Database(object):
# to watch it # to watch it
self._txn_perf_counters = PerformanceCounters() self._txn_perf_counters = PerformanceCounters()
self.database_engine = hs.database_engine self.engine = hs.database_engine
# A set of tables that are not safe to use native upserts in. # A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@ -242,10 +242,10 @@ class Database(object):
# We add the user_directory_search table to the blacklist on SQLite # We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it # because the existing search table does not have an index, making it
# unsafe to use native upserts. # unsafe to use native upserts.
if isinstance(self.database_engine, Sqlite3Engine): if isinstance(self.engine, Sqlite3Engine):
self._unsafe_to_upsert_tables.add("user_directory_search") self._unsafe_to_upsert_tables.add("user_directory_search")
if self.database_engine.can_native_upsert: if self.engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished # Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update. # background updates of tables that aren't safe to update.
self._clock.call_later( self._clock.call_later(
@ -331,7 +331,7 @@ class Database(object):
cursor = LoggingTransaction( cursor = LoggingTransaction(
conn.cursor(), conn.cursor(),
name, name,
self.database_engine, self.engine,
after_callbacks, after_callbacks,
exception_callbacks, exception_callbacks,
) )
@ -339,7 +339,7 @@ class Database(object):
r = func(cursor, *args, **kwargs) r = func(cursor, *args, **kwargs)
conn.commit() conn.commit()
return r return r
except self.database_engine.module.OperationalError as e: except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid # This can happen if the database disappears mid
# transaction. # transaction.
logger.warning( logger.warning(
@ -353,20 +353,20 @@ class Database(object):
i += 1 i += 1
try: try:
conn.rollback() conn.rollback()
except self.database_engine.module.Error as e1: except self.engine.module.Error as e1:
logger.warning( logger.warning(
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
) )
continue continue
raise raise
except self.database_engine.module.DatabaseError as e: except self.engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e): if self.engine.is_deadlock(e):
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N: if i < N:
i += 1 i += 1
try: try:
conn.rollback() conn.rollback()
except self.database_engine.module.Error as e1: except self.engine.module.Error as e1:
logger.warning( logger.warning(
"[TXN EROLL] {%s} %s", "[TXN EROLL] {%s} %s",
name, name,
@ -494,7 +494,7 @@ class Database(object):
sql_scheduling_timer.observe(sched_duration_sec) sql_scheduling_timer.observe(sched_duration_sec)
context.add_database_scheduled(sched_duration_sec) context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn): if self.engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection") logger.debug("Reconnecting closed database connection")
conn.reconnect() conn.reconnect()
@ -561,7 +561,7 @@ class Database(object):
""" """
try: try:
yield self.runInteraction(desc, self.simple_insert_txn, table, values) yield self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.database_engine.module.IntegrityError: except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse # We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db. # a cursor after we receive an error from the db.
if not or_ignore: if not or_ignore:
@ -660,7 +660,7 @@ class Database(object):
lock=lock, lock=lock,
) )
return result return result
except self.database_engine.module.IntegrityError as e: except self.engine.module.IntegrityError as e:
attempts += 1 attempts += 1
if attempts >= 5: if attempts >= 5:
# don't retry forever, because things other than races # don't retry forever, because things other than races
@ -692,10 +692,7 @@ class Database(object):
upserts return True if a new entry was created, False if an existing upserts return True if a new entry was created, False if an existing
one was updated. one was updated.
""" """
if ( if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
return self.simple_upsert_txn_native_upsert( return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values txn, table, keyvalues, values, insertion_values=insertion_values
) )
@ -726,7 +723,7 @@ class Database(object):
""" """
# We need to lock the table :(, unless we're *really* careful # We need to lock the table :(, unless we're *really* careful
if lock: if lock:
self.database_engine.lock_table(txn, table) self.engine.lock_table(txn, table)
def _getwhere(key): def _getwhere(key):
# If the value we're passing in is None (aka NULL), we need to use # If the value we're passing in is None (aka NULL), we need to use
@ -828,10 +825,7 @@ class Database(object):
Returns: Returns:
None None
""" """
if ( if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
return self.simple_upsert_many_txn_native_upsert( return self.simple_upsert_many_txn_native_upsert(
txn, table, key_names, key_values, value_names, value_values txn, table, key_names, key_values, value_names, value_values
) )
@ -1301,7 +1295,7 @@ class Database(object):
"limit": limit, "limit": limit,
} }
sql = self.database_engine.convert_param_style(sql) sql = self.engine.convert_param_style(sql)
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))