Convert additional databases to async/await part 3 (#8201)

Patrick Cloke 2020-09-01 11:04:17 -04:00 committed by GitHub
parent 7d103a594e
commit 37db6252b7
7 changed files with 121 additions and 87 deletions

changelog.d/8201.misc (new file, 1 addition)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
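The change applied across all seven files is the same mechanical pattern: a method that returned the Deferred from db_pool.runInteraction becomes a coroutine that awaits it, inline type hints replace the "(type)" annotations in the docstrings, and "deferred" wording is dropped from the Returns sections. A minimal before/after sketch of that shape, assuming a hypothetical store class and table that do not appear in this commit:

from typing import Optional

from synapse.storage._base import SQLBaseStore


class ThingStore(SQLBaseStore):  # hypothetical store, for illustration only
    # Before: the method returns the Deferred from runInteraction directly
    # and documents its types in the docstring.
    def get_thing_before(self, thing_id):
        """Fetch a thing.

        Args:
            thing_id(str): The thing to fetch.
        Returns:
            A deferred string, or None.
        """

        def get_thing_txn(txn):
            txn.execute("SELECT content FROM things WHERE thing_id = ?", (thing_id,))
            row = txn.fetchone()
            return row[0] if row else None

        return self.db_pool.runInteraction("get_thing", get_thing_txn)

    # After: the method is a coroutine with inline type hints; the
    # synchronous transaction function itself is unchanged.
    async def get_thing(self, thing_id: str) -> Optional[str]:
        """Fetch a thing.

        Args:
            thing_id: The thing to fetch.
        Returns:
            The thing's content, or None.
        """

        def get_thing_txn(txn):
            txn.execute("SELECT content FROM things WHERE thing_id = ?", (thing_id,))
            row = txn.fetchone()
            return row[0] if row else None

        return await self.db_pool.runInteraction("get_thing", get_thing_txn)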

synapse/storage/background_updates.py

@@ -433,7 +433,7 @@ class BackgroundUpdater(object):
             "background_updates", keyvalues={"update_name": update_name}
         )

-    def _background_update_progress(self, update_name: str, progress: dict):
+    async def _background_update_progress(self, update_name: str, progress: dict):
         """Update the progress of a background update

         Args:
@@ -441,7 +441,7 @@ class BackgroundUpdater(object):
             progress: The progress of the update.
         """

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "background_update_progress",
             self._background_update_progress_txn,
             update_name,
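A note on the return statements: inside a coroutine, a bare "return self.db_pool.runInteraction(...)" would hand back an unresolved Deferred as the coroutine's result, so every converted method awaits the interaction and returns the plain value. Callers in turn must await the method; a sketch of a hypothetical call site (the update name and progress payload are illustrative):

async def record_progress(updater) -> None:
    # Calling without await would only create a coroutine object; nothing
    # would run until it is awaited.
    await updater._background_update_progress(
        "my_background_update", {"last_processed_id": 42}
    )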

synapse/storage/databases/main/account_data.py

@@ -16,9 +16,7 @@
 import abc
 import logging

-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple

 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         raise NotImplementedError()

     @cached()
-    def get_account_data_for_user(self, user_id):
+    async def get_account_data_for_user(
+        self, user_id: str
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a user.

         Args:
-            user_id(str): The user to get the account_data for.
+            user_id: The user to get the account_data for.
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A 2-tuple of a dict of global account_data and a dict mapping from
+            room_id string to per room account_data dicts.
         """

         def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             return global_account_data, by_room

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         return None

     @cached(num_args=2)
-    def get_account_data_for_room(self, user_id, room_id):
+    async def get_account_data_for_room(
+        self, user_id: str, room_id: str
+    ) -> Dict[str, JsonDict]:
         """Get all the client account_data for a user for a room.

         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
         Returns:
-            A deferred dict of the room account_data
+            A dict of the room account_data
         """

         def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )

     @cached(num_args=3, max_entries=5000)
-    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+    async def get_account_data_for_room_and_type(
+        self, user_id: str, room_id: str, account_data_type: str
+    ) -> Optional[JsonDict]:
         """Get the client account_data of given type for a user for a room.

         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
-            account_data_type (str): The account data type to get.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
+            account_data_type: The account data type to get.
         Returns:
-            A deferred of the room account_data for that type, or None if
-            there isn't any set.
+            The room account_data for that type, or None if there isn't any set.
         """

         def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             return db_to_json(content_json) if content_json else None

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )

-    def get_updated_account_data_for_user(self, user_id, stream_id):
+    async def get_updated_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a that's changed for a user

         Args:
-            user_id(str): The user to get the account_data for.
-            stream_id(int): The point in the stream since which to get updates
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
         Returns:
             A deferred pair of a dict of global account_data and a dict
             mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return defer.succeed(({}, {}))
+            return ({}, {})

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
         return self._account_data_id_gen.get_current_token()

-    def _update_max_stream_id(self, next_id: int):
+    async def _update_max_stream_id(self, next_id: int) -> None:
         """Update the max stream_id

         Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))

-        return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)

synapse/storage/databases/main/end_to_end_keys.py

@@ -34,13 +34,15 @@ if TYPE_CHECKING:


 class EndToEndKeyWorkerStore(SQLBaseStore):
-    def get_e2e_device_keys_for_federation_query(self, user_id: str):
+    async def get_e2e_device_keys_for_federation_query(
+        self, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         """Get all devices (with any device keys) for a user

         Returns:
-            Deferred which resolves to (stream_id, devices)
+            (stream_id, devices)
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_e2e_device_keys_for_federation_query",
             self._get_e2e_device_keys_for_federation_query_txn,
             user_id,
@@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )

     @cached(max_entries=10000)
-    def count_e2e_one_time_keys(self, user_id, device_id):
+    async def count_e2e_one_time_keys(
+        self, user_id: str, device_id: str
+    ) -> Dict[str, int]:
         """ Count the number of one time keys the server has for a device
         Returns:
-            Dict mapping from algorithm to number of keys for that algorithm.
+            A mapping from algorithm to number of keys for that algorithm.
         """

         def _count_e2e_one_time_keys(txn):
@@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result[algorithm] = key_count
             return result

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
@@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         list_name="user_ids",
         num_args=1,
     )
-    def _get_bare_e2e_cross_signing_keys_bulk(
+    async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
@@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         the signatures for the calling user need to be fetched.

         Args:
-            user_ids (list[str]): the users whose keys are being requested
+            user_ids: the users whose keys are being requested

         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-            data.  If a user's cross-signing keys were not found, either
-            their user ID will not be in the dict, or their user ID will map
-            to None.
+            A mapping from user ID to key type to key data.  If a user's cross-signing
+            keys were not found, either their user ID will not be in the dict, or
+            their user ID will map to None.
         """

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
@@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):


 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
         """
@@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "set_e2e_device_keys", _set_e2e_device_keys_txn
         )

-    def claim_e2e_one_time_keys(self, query_list):
-        """Take a list of one time keys out of the database"""
+    async def claim_e2e_one_time_keys(
+        self, query_list: Iterable[Tuple[str, str, str]]
+    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+        """Take a list of one time keys out of the database.
+
+        Args:
+            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+        """

         @trace
         def _claim_e2e_one_time_keys(txn):
@@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             )
             return result

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )

-    def delete_e2e_keys_by_device(self, user_id, device_id):
+    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
                 {
@@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
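claim_e2e_one_time_keys gains a docstring along with its types; a sketch of a call following that documented shape (the user, device, and algorithm values are made up):

async def claim_keys_example(store) -> None:
    claimed = await store.claim_e2e_one_time_keys(
        [("@alice:example.com", "ALICEDEVICE", "signed_curve25519")]
    )
    # Result shape: user ID -> device ID -> key ID -> JSON bytes.
    for user_id, devices in claimed.items():
        for device_id, keys in devices.items():
            for key_id, key_json in keys.items():
                print(user_id, device_id, key_id, len(key_json))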

synapse/storage/databases/main/media_repository.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple

 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )

-    def get_url_cache(self, url, ts):
+    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 )
             )

-        return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+        return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)

     async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
@@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_cached_remote_media",
         )

-    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+    async def update_cached_last_access_time(
+        self,
+        local_media: Iterable[str],
+        remote_media: Iterable[Tuple[str, str]],
+        time_ms: int,
+    ):
         """Updates the last access time of the given media

         Args:
-            local_media (iterable[str]): Set of media_ids
-            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            local_media: Set of media_ids
+            remote_media: Set of (server_name, media_id)
             time_ms: Current time in milliseconds
         """
@@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
@@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
         )

-    def delete_remote_media(self, media_origin, media_id):
+    async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
         def delete_remote_media_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_remote_media", delete_remote_media_txn
         )

-    def get_expired_url_cache(self, now_ts):
+    async def get_expired_url_cache(self, now_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository_url_cache"
             " WHERE expires_ts < ?"
@@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
@@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "delete_url_cache", _delete_url_cache_txn
         )

-    def get_url_cache_media_before(self, before_ts):
+    async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository"
             " WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
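As with delete_e2e_keys_by_device above, the deletion methods here now await the interaction internally and resolve to None, so callers simply await completion instead of chaining the returned Deferred. A sketch with illustrative arguments:

async def prune_media_example(store, now_ms: int) -> None:
    # Resolves once the rows are gone; there is no useful return value.
    await store.delete_remote_media("media.example.com", "abcdefghijk")

    # Query methods still produce values after the await.
    for media_id in await store.get_expired_url_cache(now_ms):
        print("expired:", media_id)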

synapse/storage/databases/main/search.py

@@ -16,9 +16,10 @@
 import logging
 import re
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set

 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
             "count": count,
         }

-    def _find_highlights_in_postgres(self, search_query, events):
+    async def _find_highlights_in_postgres(
+        self, search_query: str, events: List[EventBase]
+    ) -> Set[str]:
         """Given a list of events and a search term, return a list of words
         that match from the content of the event.
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         highlight the matching parts.

         Args:
-            search_query (str)
-            events (list): A list of events
+            search_query
+            events: A list of events

         Returns:
-            deferred : A set of strings.
+            A set of strings.
         """

         def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             return highlight_words

-        return self.db_pool.runInteraction("_find_highlights", f)
+        return await self.db_pool.runInteraction("_find_highlights", f)


 def _to_postgres_options(options_dict):
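Only the async wrapper changes here; the Postgres-specific highlighting logic stays inside the transaction function. A sketch of an internal call site, assuming events is the list of EventBase objects being returned to the search client:

async def highlights_example(store, search_query: str, events) -> None:
    # Returns the set of words from the events' content matching the query.
    highlights = await store._find_highlights_in_postgres(search_query, events)
    print(sorted(highlights))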

synapse/storage/databases/main/user_directory.py

@@ -15,7 +15,7 @@
 import logging
 import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, Optional, Tuple

 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -365,7 +365,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         return False

-    def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+    async def update_profile_in_user_dir(
+        self, user_id: str, display_name: str, avatar_url: str
+    ) -> None:
         """
         Update or add a user's profile in the user directory.
         """
@@ -458,17 +460,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )

-    def add_users_who_share_private_room(self, room_id, user_id_tuples):
+    async def add_users_who_share_private_room(
+        self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.

         Args:
-            room_id (str)
-            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+            room_id
+            user_id_tuples: iterable of 2-tuple of user IDs.
         """

         def _add_users_who_share_room_txn(txn):
@@ -484,17 +488,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             value_values=None,
         )

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )

-    def add_users_in_public_rooms(self, room_id, user_ids):
+    async def add_users_in_public_rooms(
+        self, room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.

         Args:
-            room_id (str)
-            user_ids (list[str])
+            room_id
+            user_ids
         """

         def _add_users_in_public_rooms_txn(txn):
@@ -508,11 +514,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             value_values=None,
         )

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )

-    def delete_all_from_user_dir(self):
+    async def delete_all_from_user_dir(self) -> None:
         """Delete the entire user directory
         """
@@ -523,7 +529,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
@@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(UserDirectoryStore, self).__init__(database, db_conn, hs)

-    def remove_from_user_dir(self, user_id):
+    async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -578,7 +584,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_from_user_dir", _remove_from_user_dir_txn
         )
@@ -605,14 +611,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         return user_ids

-    def remove_user_who_share_room(self, user_id, room_id):
+    async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
         """
         Deletes entries in the users_who_share_*_rooms table. The first
         user should be a local user.

         Args:
-            user_id (str)
-            room_id (str)
+            user_id
+            room_id
         """

         def _remove_user_who_share_room_txn(txn):
@@ -632,7 +638,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             keyvalues={"user_id": user_id, "room_id": room_id},
         )

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
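Finally, a sketch of the directory-maintenance methods under their new signatures (room and user IDs are illustrative; per the docstrings, the first user of each shared-room tuple should be local):

async def directory_example(store) -> None:
    await store.update_profile_in_user_dir(
        "@alice:example.com", "Alice", "mxc://example.com/avatar"
    )
    await store.add_users_who_share_private_room(
        "!secret:example.com", [("@alice:example.com", "@bob:other.example")]
    )
    await store.add_users_in_public_rooms("!lobby:example.com", ["@alice:example.com"])

    # Each of these resolves to None once its transaction has committed.
    await store.remove_user_who_share_room(
        "@alice:example.com", "!secret:example.com"
    )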