Convert some of the general database methods to async (#8100)

This commit is contained in:
Patrick Cloke 2020-08-17 12:18:01 -04:00 committed by GitHub
parent e04e465b4d
commit 050e20e7ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 69 additions and 59 deletions

1
changelog.d/8100.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -332,8 +332,7 @@ class DatabasePool(object):
""" """
return self._db_pool.running return self._db_pool.running
@defer.inlineCallbacks async def _check_safe_to_upsert(self):
def _check_safe_to_upsert(self):
""" """
Is it safe to use native UPSERT? Is it safe to use native UPSERT?
@ -342,7 +341,7 @@ class DatabasePool(object):
If the background updates have not completed, wait 15 sec and check again. If the background updates have not completed, wait 15 sec and check again.
""" """
updates = yield self.simple_select_list( updates = await self.simple_select_list(
"background_updates", "background_updates",
keyvalues=None, keyvalues=None,
retcols=["update_name"], retcols=["update_name"],
@ -614,8 +613,7 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
@ -631,7 +629,7 @@ class DatabasePool(object):
`or_ignore` is True `or_ignore` is True
""" """
try: try:
yield self.runInteraction(desc, self.simple_insert_txn, table, values) await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError: except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse # We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db. # a cursor after we receive an error from the db.
@ -684,8 +682,7 @@ class DatabasePool(object):
txn.executemany(sql, vals) txn.executemany(sql, vals)
@defer.inlineCallbacks async def simple_upsert(
def simple_upsert(
self, self,
table, table,
keyvalues, keyvalues,
@ -714,14 +711,14 @@ class DatabasePool(object):
inserting inserting
lock (bool): True to lock the table when doing the upsert. lock (bool): True to lock the table when doing the upsert.
Returns: Returns:
Deferred(None or bool): Native upserts always return None. Emulated None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing upserts return True if a new entry was created, False if an existing
one was updated. one was updated.
""" """
attempts = 0 attempts = 0
while True: while True:
try: try:
result = yield self.runInteraction( return await self.runInteraction(
desc, desc,
self.simple_upsert_txn, self.simple_upsert_txn,
table, table,
@ -730,7 +727,6 @@ class DatabasePool(object):
insertion_values, insertion_values,
lock=lock, lock=lock,
) )
return result
except self.engine.module.IntegrityError as e: except self.engine.module.IntegrityError as e:
attempts += 1 attempts += 1
if attempts >= 5: if attempts >= 5:
@ -1121,8 +1117,7 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks async def simple_select_many_batch(
def simple_select_many_batch(
self, self,
table, table,
column, column,
@ -1156,7 +1151,7 @@ class DatabasePool(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
] ]
for chunk in chunks: for chunk in chunks:
rows = yield self.runInteraction( rows = await self.runInteraction(
desc, desc,
self.simple_select_many_txn, self.simple_select_many_txn,
table, table,

View File

@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set. service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply. state(ApplicationServiceState): The connectivity state to apply.
Returns: Returns:
A Deferred which resolves when the state was set successfully. An Awaitable which resolves when the state was set successfully.
""" """
return self.db_pool.simple_upsert( return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}

View File

@ -847,7 +847,8 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and """Given a list of event ids, check if we have already processed and
stored them as non outliers. stored them as non outliers.
""" """
rows = yield self.db_pool.simple_select_many_batch( rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events", table="events",
retcols=("event_id",), retcols=("event_id",),
column="event_id", column="event_id",
@ -855,6 +856,7 @@ class EventsWorkerStore(SQLBaseStore):
keyvalues={"outlier": False}, keyvalues={"outlier": False},
desc="have_events_in_timeline", desc="have_events_in_timeline",
) )
)
return {r["event_id"] for r in rows} return {r["event_id"] for r in rows}

View File

@ -17,9 +17,7 @@
import logging import logging
import re import re
from typing import Dict, List, Optional from typing import Awaitable, Dict, List, Optional
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str) id_server (str)
Returns: Returns:
Deferred Awaitable
""" """
# We need to use an upsert, in case they user had already bound the # We need to use an upsert, in case they user had already bound the
# threepid # threepid
@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id( def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str self, auth_provider: str, external_id: str, user_id: str
) -> Deferred: ) -> Awaitable:
"""Record a mapping from an external user id to a mxid """Record a mapping from an external user id to a mxid
Args: Args:

View File

@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids) return set(room_ids)
def get_membership_from_event_ids( async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str] self, member_event_ids: Iterable[str]
) -> List[dict]: ) -> List[dict]:
"""Get user_id and membership of a set of event IDs. """Get user_id and membership of a set of event IDs.
""" """
return self.db_pool.simple_select_many_batch( return await self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=member_event_ids, iterable=member_event_ids,

View File

@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test") self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote") self.alice = UserID.from_string("@alice:remote")
yield self.store.create_profile(self.frank.localpart) yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler() self.handler = hs.get_profile_handler()
self.hs = hs self.hs = hs
@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_incoming_fed_query(self): def test_incoming_fed_query(self):
yield self.store.create_profile("caroline") yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline") yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield defer.ensureDeferred( response = yield defer.ensureDeferred(

View File

@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
([], 0) ([], 0)
) )
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None None
) )

View File

@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_appservices_state_down(self): def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
self.engine.convert_param_style( self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?" "SELECT as_id FROM application_services_state WHERE state=?"
@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self): def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.UP) yield defer.ensureDeferred(
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.store.set_appservice_state(service, ApplicationServiceState.UP)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP) )
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
self.engine.convert_param_style( self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?" "SELECT as_id FROM application_services_state WHERE state=?"

View File

@ -66,9 +66,11 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self): def test_insert_1col(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_insert( yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"} table="tablename", values={"columname": "Value"}
) )
)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"INSERT INTO tablename (columname) VALUES(?)", ("Value",) "INSERT INTO tablename (columname) VALUES(?)", ("Value",)
@ -78,11 +80,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self): def test_insert_3cols(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_insert( yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename", table="tablename",
# Use OrderedDict() so we can assert on the SQL generated # Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
) )
)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3) "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3)

View File

@ -142,7 +142,8 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self): def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts): def add_event(so, ts):
return self.store.db_pool.simple_insert( return defer.ensureDeferred(
self.store.db_pool.simple_insert(
"events", "events",
{ {
"stream_ordering": so, "stream_ordering": so,
@ -157,6 +158,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
"depth": 0, "depth": 0,
}, },
) )
)
# start with the base case where there are no events in the table # start with the base case where there are no events in the table
r = yield self.store.find_first_stream_ordering_after_ts(11) r = yield self.store.find_first_stream_ordering_after_ts(11)

View File

@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_users_paginate(self): def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass") yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.set_profile_displayname(self.user.localpart, self.displayname) yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
users, total = yield self.store.get_users_paginate( users, total = yield self.store.get_users_paginate(

View File

@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_displayname(self): def test_displayname(self):
yield self.store.create_profile(self.u_frank.localpart) yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_avatar_url(self): def test_avatar_url(self):
yield self.store.create_profile(self.u_frank.localpart) yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here" self.u_frank.localpart, "http://my.site/here"