Convert additional databases to async/await part 3 (#8201)

This commit is contained in:
Patrick Cloke 2020-09-01 11:04:17 -04:00 committed by GitHub
parent 7d103a594e
commit 37db6252b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 121 additions and 87 deletions

View file

@ -16,9 +16,7 @@
import abc
import logging
from typing import List, Optional, Tuple
from twisted.internet import defer
from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
async def get_account_data_for_user(
self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
user_id(str): The user to get the account_data for.
user_id: The user to get the account_data for.
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
A 2-tuple of a dict of global account_data and a dict mapping from
room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
Returns:
A deferred dict of the room account_data
A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
account_data_type (str): The account data type to get.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
account_data_type: The account data type to get.
Returns:
A deferred of the room account_data for that type, or None if
there isn't any set.
The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
async def get_updated_account_data_for_user(
self, user_id: str, stream_id: int
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
user_id(str): The user to get the account_data for.
stream_id(int): The point in the stream since which to get updates
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
return defer.succeed(({}, {}))
return ({}, {})
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_max_stream_id(self, next_id: int):
async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)