Move DB pool and helper functions into dedicated Database class

This commit is contained in:
Erik Johnston 2019-12-04 13:52:46 +00:00
parent ddbbfc9512
commit 756d4942f5
62 changed files with 2377 additions and 2295 deletions

View file

@ -45,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
return self.simple_select_one(
return self.db.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@ -94,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
return self.runInteraction(
return self.db.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def set_account_validity_for_user_txn(txn):
self.simple_update_txn(
self.db.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@ -151,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
yield self.runInteraction(
yield self.db.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises:
StoreError: The provided token is already set for another user.
"""
yield self.simple_update_one(
yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
@ -229,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
res = yield self.runInteraction(
res = yield self.db.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
yield self.simple_update_one(
yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args:
user_id (str): ID of the user to remove from the account validity table.
"""
yield self.simple_delete_one(
yield self.db.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
return self.simple_update_one(
return self.db.simple_update_one(
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"admin": 1 if admin else 0},
@ -316,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]
@ -332,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user 'user_type' is null or empty string
"""
res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
res = yield self.db.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
return res
@cachedInlineCallbacks()
@ -345,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
res = yield self.runInteraction(
res = yield self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
return res
def is_real_user_txn(self, txn, user_id):
res = self.simple_select_one_onecol_txn(
res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@ -361,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None
def is_support_user_txn(self, txn, user_id):
res = self.simple_select_one_onecol_txn(
res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@ -380,7 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
return self.db.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@ -394,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: the mxid of the user, or None if they are not known
"""
return await self.simple_select_one_onecol(
return await self.db.simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
@ -408,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
ret = yield self.db.runInteraction("count_users", _count_users)
return ret
def count_daily_user_type(self):
@ -445,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
@ -459,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
ret = yield self.db.runInteraction("count_users", _count_users)
return ret
@defer.inlineCallbacks
@ -468,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_real_users", _count_users)
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
@defer.inlineCallbacks
@ -503,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return (
(
yield self.runInteraction(
yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
@ -520,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
user_id = yield self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@ -536,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
ret = self.simple_select_one_txn(
ret = self.db.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
@ -549,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self.simple_upsert(
yield self.db.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@ -557,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self.simple_select_list(
ret = yield self.db.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
@ -566,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret
def user_delete_threepid(self, user_id, medium, address):
return self.simple_delete(
return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid",
@ -579,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id: The user id to delete all threepids of
"""
return self.simple_delete(
return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id},
desc="user_delete_threepids",
@ -601,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
# We need to use an upsert, in case they user had already bound the
# threepid
return self.simple_upsert(
return self.db.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@ -627,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore):
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
return self.simple_select_list(
return self.db.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@ -648,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
return self.simple_delete(
return self.db.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@ -671,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@ -689,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.Deferred(bool): The requested value.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@ -756,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return None
return rows[0]
return self.runInteraction(
return self.db.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
@ -776,18 +778,18 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def delete_threepid_session_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
)
return self.runInteraction(
return self.db.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@ -857,7 +859,7 @@ class RegistrationBackgroundUpdateStore(
(last_user, batch_size),
)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return True, 0
@ -880,7 +882,7 @@ class RegistrationBackgroundUpdateStore(
else:
return False, len(rows)
end, nb_processed = yield self.runInteraction(
end, nb_processed = yield self.db.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
@ -911,7 +913,7 @@ class RegistrationBackgroundUpdateStore(
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
yield self.db.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
@ -961,7 +963,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
next_id = self._access_tokens_id_gen.get_next()
yield self.simple_insert(
yield self.db.simple_insert(
"access_tokens",
{
"id": next_id,
@ -1003,7 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
"""
return self.runInteraction(
return self.db.runInteraction(
"register_user",
self._register_user,
user_id,
@ -1037,7 +1039,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
self.simple_select_one_txn(
self.db.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@ -1045,7 +1047,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False,
)
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@ -1059,7 +1061,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
else:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"users",
values={
@ -1114,7 +1116,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
return self.simple_insert(
return self.db.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@ -1132,12 +1134,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def user_set_password_hash_txn(txn):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
return self.db.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@ -1152,7 +1156,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@ -1160,7 +1164,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_version", f)
return self.db.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server
@ -1176,7 +1180,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@ -1184,7 +1188,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_server_notice_sent", f)
return self.db.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
@ -1230,11 +1234,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
return self.runInteraction("user_delete_access_tokens", f)
return self.db.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
self.simple_delete_one_txn(
self.db.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@ -1242,11 +1246,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f)
return self.db.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@ -1261,7 +1265,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
return self.simple_insert(
return self.db.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@ -1274,7 +1278,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
return self.simple_delete(
return self.db.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@ -1285,7 +1289,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@ -1315,7 +1319,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
row = self.simple_select_one_txn(
row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@ -1333,7 +1337,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
400, "This client_secret does not match the provided session_id"
)
row = self.simple_select_one_txn(
row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@ -1358,7 +1362,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Looks good. Validate the session
self.simple_update_txn(
self.db.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@ -1368,7 +1372,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
return self.runInteraction(
return self.db.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
@ -1401,7 +1405,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
if validated_at:
insertion_values["validated_at"] = validated_at
return self.simple_upsert(
return self.db.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@ -1439,7 +1443,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@ -1452,7 +1456,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Create a new validation token with this session ID
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@ -1463,7 +1467,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
return self.runInteraction(
return self.db.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
@ -1478,7 +1482,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
return txn.execute(sql, (ts,))
return self.runInteraction(
return self.db.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
@ -1493,7 +1497,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
deactivated (bool): The value to set for `deactivated`.
"""
yield self.runInteraction(
yield self.db.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@ -1501,7 +1505,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@ -1529,14 +1533,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.execute(sql, [])
res = self.cursor_to_dict(txn)
res = self.db.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
yield self.runInteraction(
yield self.db.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
@ -1560,7 +1564,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expiration_ts,
)
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},