Convert additional database code to async/await. (#8195)

This commit is contained in:
Patrick Cloke 2020-08-28 07:54:27 -04:00 committed by GitHub
parent d5e73cb6aa
commit 5c03134d0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 246 additions and 175 deletions

1
changelog.d/8195.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -14,11 +14,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import TYPE_CHECKING
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,19 +40,19 @@ class AppServiceTransaction(object):
self.id = id self.id = id
self.events = events self.events = events
def send(self, as_api): async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface. """Sends this transaction using the provided AS API interface.
Args: Args:
as_api(ApplicationServiceApi): The API to use to send. as_api: The API to use to send.
Returns: Returns:
An Awaitable which resolves to True if the transaction was sent. True if the transaction was sent.
""" """
return as_api.push_bulk( return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id service=self.service, events=self.events, txn_id=self.id
) )
def complete(self, store): async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful. """Completes this transaction as successful.
Marks this transaction ID on the application service and removes the Marks this transaction ID on the application service and removes the
@ -55,10 +60,8 @@ class AppServiceTransaction(object):
Args: Args:
store: The database store to operate on. store: The database store to operate on.
Returns:
A Deferred which resolves to True if the transaction was completed.
""" """
return store.complete_appservice_txn(service=self.service, txn_id=self.id) await store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService(object): class ApplicationService(object):

View File

@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
""" """
import logging import logging
from typing import Optional, Tuple
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
@ -36,19 +37,21 @@ class TransactionActions(object):
self.store = datastore self.store = datastore
@log_function @log_function
def have_responded(self, origin, transaction): async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[Tuple[int, JsonDict]]:
"""Have we already responded to a transaction with the same id and """Have we already responded to a transaction with the same id and
origin? origin?
Returns: Returns:
Deferred: Results in `None` if we have not previously responded to `None` if we have not previously responded to this transaction or a
this transaction or a 2-tuple of `(int, dict)` representing the 2-tuple of `(int, dict)` representing the response code and response body.
response code and response body.
""" """
if not transaction.transaction_id: transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id") raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.get_received_txn_response(transaction.transaction_id, origin) return await self.store.get_received_txn_response(transaction_id, origin)
@log_function @log_function
async def set_response( async def set_response(

View File

@ -1879,8 +1879,8 @@ class FederationHandler(BaseHandler):
else: else:
return None return None
def get_min_depth_for_context(self, context): async def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context) return await self.store.get_min_depth(context)
async def _handle_new_event( async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False self, origin, event, state=None, auth_events=None, backfilled=False

View File

@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}
) )
def create_appservice_txn(self, service, events): async def create_appservice_txn(self, service, events):
"""Atomically creates a new transaction for this application service """Atomically creates a new transaction for this application service
with the given list of events. with the given list of events.
@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
) )
return AppServiceTransaction(service=service, id=new_txn_id, events=events) return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn "create_appservice_txn", _create_appservice_txn
) )
def complete_appservice_txn(self, txn_id, service): async def complete_appservice_txn(self, txn_id, service) -> None:
"""Completes an application service transaction. """Completes an application service transaction.
Args: Args:
txn_id(str): The transaction ID being completed. txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent service(ApplicationService): The application service which was sent
this transaction. this transaction.
Returns:
A Deferred which resolves if this transaction was stored
successfully.
""" """
txn_id = int(txn_id) txn_id = int(txn_id)
@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore(
{"txn_id": txn_id, "as_id": service.id}, {"txn_id": txn_id, "as_id": service.id},
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn "complete_appservice_txn", _complete_appservice_txn
) )
@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore(
else: else:
return int(last_txn_id[0]) # select 'last_txn' col return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos): async def set_appservice_last_pos(self, pos) -> None:
def set_appservice_last_pos_txn(txn): def set_appservice_last_pos_txn(txn):
txn.execute( txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn "set_appservice_last_pos", set_appservice_last_pos_txn
) )

View File

@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
@trace @trace
def delete_device_msgs_for_remote(self, destination, up_to_stream_id): async def delete_device_msgs_for_remote(
self, destination: str, up_to_stream_id: int
) -> None:
"""Used to delete messages when the remote destination acknowledges """Used to delete messages when the remote destination acknowledges
their receipt. their receipt.
Args: Args:
destination(str): The destination server_name destination: The destination server_name
up_to_stream_id(int): Where to delete messages up to. up_to_stream_id: Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
""" """
def delete_messages_for_remote_destination_txn(txn): def delete_messages_for_remote_destination_txn(txn):
@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
txn.execute(sql, (destination, up_to_stream_id)) txn.execute(sql, (destination, up_to_stream_id))
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
) )

View File

@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions return sessions
def get_e2e_room_keys_multi(self, user_id, version, room_keys): async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
"""Get multiple room keys at a time. The difference between this function and """Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for multiple specific keys at a time, whereas get_e2e_room_keys is used for
@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
that we want to query that we want to query
Returns: Returns:
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi", "get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn, self._get_e2e_room_keys_multi_txn,
user_id, user_id,
@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No current backup version") raise StoreError(404, "No current backup version")
return row[0] return row[0]
def get_e2e_room_keys_version_info(self, user_id, version=None): async def get_e2e_room_keys_version_info(self, user_id, version=None):
"""Get info metadata about a version of our room_keys backup. """Get info metadata about a version of our room_keys backup.
Args: Args:
@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Raises: Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present StoreError: with code 404 if there are no e2e_room_keys_versions present
Returns: Returns:
A deferred dict giving the info metadata for this backup version, with A dict giving the info metadata for this backup version, with
fields including: fields including:
version(str) version(str)
algorithm(str) algorithm(str)
@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0 result["etag"] = 0
return result return result
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
) )
@trace @trace
def create_e2e_room_keys_version(self, user_id, info): async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store """Atomically creates a new version of this user's e2e_room_keys store
with the given version info. with the given version info.
@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
info(dict): the info about the backup version to be created info(dict): the info about the backup version to be created
Returns: Returns:
A deferred string for the newly created version ID The newly created version ID
""" """
def _create_e2e_room_keys_version_txn(txn): def _create_e2e_room_keys_version_txn(txn):
@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version return new_version
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
) )
@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
) )
@trace @trace
def delete_e2e_room_keys_version(self, user_id, version=None): async def delete_e2e_room_keys_version(
self, user_id: str, version: Optional[str] = None
) -> None:
"""Delete a given backup version of the user's room keys. """Delete a given backup version of the user's room keys.
Doesn't delete their actual key data. Doesn't delete their actual key data.
Args: Args:
user_id(str): the user whose backup version we're deleting user_id: the user whose backup version we're deleting
version(str): Optional. the version ID of the backup version we're deleting version: Optional. the version ID of the backup version we're deleting
If missing, we delete the current backup version info. If missing, we delete the current backup version info.
Raises: Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present, StoreError: with code 404 if there are no e2e_room_keys_versions present,
@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version}, keyvalues={"user_id": user_id, "version": this_version},
) )
return self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version}, keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1}, updatevalues={"deleted": 1},
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
) )

View File

@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given: include the given events in result include_given: include the given events in result
Returns: Returns:
list of event_ids An awaitable which resolve to a list of event_ids
""" """
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids",
@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results) return list(results)
def get_auth_chain_difference(self, state_sets: List[Set[str]]): async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as """Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm). per state res v2 algorithm).
@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
chain. chain.
Returns: Returns:
Deferred[Set[str]] The set of the difference in auth chains.
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_auth_chain_difference", "get_auth_chain_difference",
self._get_auth_chain_difference_txn, self._get_auth_chain_difference_txn,
state_sets, state_sets,
@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them. # Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n} return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_with_depth_in_room(self, room_id): async def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room", "get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn, self.get_oldest_events_with_depth_in_room_txn,
room_id, room_id,
@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else: else:
return max(row["depth"] for row in rows) return max(row["depth"] for row in rows)
def get_prev_events_for_room(self, room_id: str): async def get_prev_events_for_room(self, room_id: str) -> List[str]:
""" """
Gets a subset of the current forward extremities in the given room. Gets a subset of the current forward extremities in the given room.
@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
events which refer to hundreds of prev_events. events which refer to hundreds of prev_events.
Args: Args:
room_id (str): room_id room_id: room_id
Returns: Returns:
Deferred[List[str]]: the event ids of the forward extremites The event ids of the forward extremities.
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
) )
@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row[0] for row in txn] return [row[0] for row in txn]
def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): async def get_rooms_with_many_extremities(
self, min_count: int, limit: int, room_id_filter: Iterable[str]
) -> List[str]:
"""Get the top rooms with at least N extremities. """Get the top rooms with at least N extremities.
Args: Args:
min_count (int): The minimum number of extremities min_count: The minimum number of extremities
limit (int): The maximum number of rooms to return. limit: The maximum number of rooms to return.
room_id_filter (iterable[str]): room_ids to exclude from the results room_id_filter: room_ids to exclude from the results
Returns: Returns:
Deferred[list]: At most `limit` room IDs that have at least At most `limit` room IDs that have at least `min_count` extremities,
`min_count` extremities, sorted by extremity count. sorted by extremity count.
""" """
def _get_rooms_with_many_extremities_txn(txn): def _get_rooms_with_many_extremities_txn(txn):
@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args) txn.execute(sql, query_args)
return [room_id for room_id, in txn] return [room_id for room_id, in txn]
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
) )
@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_latest_event_ids_in_room", desc="get_latest_event_ids_in_room",
) )
def get_min_depth(self, room_id): async def get_min_depth(self, room_id: str) -> int:
""" For hte given room, get the minimum depth we have seen for it. """For the given room, get the minimum depth we have seen for it.
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id "get_min_depth", self._get_min_depth_interaction, room_id
) )
@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None return int(min_depth) if min_depth is not None else None
def get_forward_extremeties_for_room(self, room_id, stream_ordering): async def get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward """For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time". extremeties of the room at that point in "time".
@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
stream_orderings from that point. stream_orderings from that point.
Args: Args:
room_id (str): room_id:
stream_ordering (int): stream_ordering:
Returns: Returns:
deferred, which resolves to a list of event_ids A list of event_ids
""" """
# We want to make the cache more effective, so we clamp to the last # We want to make the cache more effective, so we clamp to the last
# change before the given ordering. # change before the given ordering.
@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if last_change > self.stream_ordering_month_ago: if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering) stream_ordering = min(last_change, stream_ordering)
return self._get_forward_extremeties_for_room(room_id, stream_ordering) return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2) @cached(max_entries=5000, num_args=2)
def _get_forward_extremeties_for_room(self, room_id, stream_ordering): async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward """For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time". extremeties of the room at that point in "time".
@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn] return [event_id for event_id, in txn]
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
) )
async def get_backfill_events(self, room_id, event_list, limit): async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and """Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit` including) the events in event_list. Return a list of max size `limit`
Args: Args:
txn room_id
room_id (str) event_list
event_list (list) limit
limit (int)
""" """
event_ids = await self.db_pool.runInteraction( event_ids = await self.db_pool.runInteraction(
"get_backfill_events", "get_backfill_events",
@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
_delete_old_forward_extrem_cache_txn, _delete_old_forward_extrem_cache_txn,
) )
def clean_room_for_join(self, room_id): async def clean_room_for_join(self, room_id):
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id "clean_room_for_join", self._clean_room_for_join_txn, room_id
) )

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group", desc="get_invited_users_in_group",
) )
def get_rooms_in_group(self, group_id: str, include_private: bool = False): async def get_rooms_in_group(
self, group_id: str, include_private: bool = False
) -> List[Dict[str, Union[str, bool]]]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that """Retrieve the rooms that belong to a given group. Does not return rooms that
lack members. lack members.
@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results include_private: Whether to return private rooms in results
Returns: Returns:
Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the A list of dictionaries, each in the form of:
form of:
{ {
"room_id": "!a_room_id:example.com", # The ID of the room "room_id": "!a_room_id:example.com", # The ID of the room
@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn for room_id, is_public in txn
] ]
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_rooms_in_group", _get_rooms_in_group_txn "get_rooms_in_group", _get_rooms_in_group_txn
) )
def get_rooms_for_summary_by_category( async def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False, self, group_id: str, include_private: bool = False,
): ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request """Get the rooms and categories that should be included in a summary request
Args: Args:
@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results include_private: Whether to return private rooms in results
Returns: Returns:
Deferred[Tuple[List, Dict]]: A tuple containing: A tuple containing:
* A list of dictionaries with the keys: * A list of dictionaries with the keys:
* "room_id": str, the room ID * "room_id": str, the room ID
@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories return rooms, categories
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn "get_rooms_for_summary", _get_rooms_for_summary_txn
) )
@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room", desc="get_local_groups_for_room",
) )
def get_users_for_summary_by_role(self, group_id, include_private=False): async def get_users_for_summary_by_role(self, group_id, include_private=False):
"""Get the users and roles that should be included in a summary request """Get the users and roles that should be included in a summary request
Returns ([users], [roles]) Returns:
([users], [roles])
""" """
def _get_users_for_summary_txn(txn): def _get_users_for_summary_txn(txn):
@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles return users, roles
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn "get_users_for_summary_by_role", _get_users_for_summary_txn
) )
@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
def get_users_membership_info_in_group(self, group_id, user_id): async def get_users_membership_info_in_group(self, group_id, user_id):
"""Get a dict describing the membership of a user in a group. """Get a dict describing the membership of a user in a group.
Example if joined: Example if joined:
@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": False, "is_privileged": False,
} }
Returns an empty dict if the user is not join/invite/etc Returns:
An empty dict if the user is not join/invite/etc
""" """
def _get_users_membership_in_group_txn(txn): def _get_users_membership_in_group_txn(txn):
@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {} return {}
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn "get_users_membership_info_in_group", _get_users_membership_in_group_txn
) )
@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user", desc="get_publicised_groups_for_user",
) )
def get_attestations_need_renewals(self, valid_until_ms): async def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time """Get all attestations that need to be renewed until givent time
""" """
@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
txn.execute(sql, (valid_until_ms,)) txn.execute(sql, (valid_until_ms,))
return self.db_pool.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn "get_attestations_need_renewals", _get_attestations_need_renewals_txn
) )
@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups", desc="get_joined_groups",
) )
def get_all_groups_for_user(self, user_id, now_token): async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn): def _get_all_groups_for_user_txn(txn):
sql = """ sql = """
SELECT group_id, type, membership, u.content SELECT group_id, type, membership, u.content
@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in txn for row in txn
] ]
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn "get_all_groups_for_user", _get_all_groups_for_user_txn
) )
@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore):
desc="set_group_join_policy", desc="set_group_join_policy",
) )
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): async def add_room_to_summary(
return self.db_pool.runInteraction( self,
group_id: str,
room_id: str,
category_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
Args:
group_id
room_id
category_id: If not None then adds the category to the end of
the summary if its not already there.
order: If not None inserts the room at that position, e.g. an order
of 1 will put the room first. Otherwise, the room gets added to
the end.
is_public
"""
await self.db_pool.runInteraction(
"add_room_to_summary", "add_room_to_summary",
self._add_room_to_summary_txn, self._add_room_to_summary_txn,
group_id, group_id,
@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
) )
def _add_room_to_summary_txn( def _add_room_to_summary_txn(
self, txn, group_id, room_id, category_id, order, is_public self,
): txn,
group_id: str,
room_id: str,
category_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary. """Add (or update) room's entry in summary.
Args: Args:
group_id (str) txn
room_id (str) group_id
category_id (str): If not None then adds the category to the end of room_id
the summary if its not already there. [Optional] category_id: If not None then adds the category to the end of
order (int): If not None inserts the room at that position, e.g. the summary if its not already there.
an order of 1 will put the room first. Otherwise, the room gets order: If not None inserts the room at that position, e.g. an order
added to the end. of 1 will put the room first. Otherwise, the room gets added to
the end.
is_public
""" """
room_in_group = self.db_pool.simple_select_one_onecol_txn( room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_group_role", desc="remove_group_role",
) )
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): async def add_user_to_summary(
return self.db_pool.runInteraction( self,
group_id: str,
user_id: str,
role_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
Args:
group_id
user_id
role_id: If not None then adds the role to the end of the summary if
its not already there.
order: If not None inserts the user at that position, e.g. an order
of 1 will put the user first. Otherwise, the user gets added to
the end.
is_public
"""
await self.db_pool.runInteraction(
"add_user_to_summary", "add_user_to_summary",
self._add_user_to_summary_txn, self._add_user_to_summary_txn,
group_id, group_id,
@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
) )
def _add_user_to_summary_txn( def _add_user_to_summary_txn(
self, txn, group_id, user_id, role_id, order, is_public self,
txn,
group_id: str,
user_id: str,
role_id: str,
order: int,
is_public: Optional[bool],
): ):
"""Add (or update) user's entry in summary. """Add (or update) user's entry in summary.
Args: Args:
group_id (str) txn
user_id (str) group_id
role_id (str): If not None then adds the role to the end of user_id
the summary if its not already there. [Optional] role_id: If not None then adds the role to the end of the summary if
order (int): If not None inserts the user at that position, e.g. its not already there.
an order of 1 will put the user first. Otherwise, the user gets order: If not None inserts the user at that position, e.g. an order
added to the end. of 1 will put the user first. Otherwise, the user gets added to
the end.
is_public
""" """
user_in_group = self.db_pool.simple_select_one_onecol_txn( user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore):
desc="add_group_invite", desc="add_group_invite",
) )
def add_user_to_group( async def add_user_to_group(
self, self,
group_id, group_id: str,
user_id, user_id: str,
is_admin=False, is_admin: bool = False,
is_public=True, is_public: bool = True,
local_attestation=None, local_attestation: dict = None,
remote_attestation=None, remote_attestation: dict = None,
): ) -> None:
"""Add a user to the group server. """Add a user to the group server.
Args: Args:
group_id (str) group_id
user_id (str) user_id
is_admin (bool) is_admin
is_public (bool) is_public
local_attestation (dict): The attestation the GS created to give local_attestation: The attestation the GS created to give to the remote
to the remote server. Optional if the user and group are on the
same server
remote_attestation (dict): The attestation given to GS by remote
server. Optional if the user and group are on the same server server. Optional if the user and group are on the same server
remote_attestation: The attestation given to GS by remote server.
Optional if the user and group are on the same server
""" """
def _add_user_to_group_txn(txn): def _add_user_to_group_txn(txn):
@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id): async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn): def _remove_user_from_group_txn(txn):
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn "remove_user_from_group", _remove_user_from_group_txn
) )
@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_room_in_group_visibility", desc="update_room_in_group_visibility",
) )
def remove_room_from_group(self, group_id, room_id): async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn): def _remove_room_from_group_txn(txn):
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "room_id": room_id}, keyvalues={"group_id": group_id, "room_id": room_id},
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn "remove_room_from_group", _remove_room_from_group_txn
) )
@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
def get_group_stream_token(self): def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token()
def delete_group(self, group_id): async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database. """Deletes a group fully from the database.
Args: Args:
group_id (str) group_id: The group ID to delete.
Returns:
Deferred
""" """
def _delete_group_txn(txn): def _delete_group_txn(txn):
@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
txn, table=table, keyvalues={"group_id": group_id} txn, table=table, keyvalues={"group_id": group_id}
) )
return self.db_pool.runInteraction("delete_group", _delete_group_txn) await self.db_pool.runInteraction("delete_group", _delete_group_txn)

View File

@ -16,7 +16,7 @@
import itertools import itertools
import logging import logging
from typing import Iterable, Tuple from typing import Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore):
@cachedList( @cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
) )
def get_server_verify_keys(self, server_name_and_key_ids): async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
""" """
Args: Args:
server_name_and_key_ids (iterable[Tuple[str, str]]): server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for iterable of (server_name, key-id) tuples to fetch keys for
Returns: Returns:
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: A map from (server_name, key_id) -> FetchKeyResult, or None if the
map from (server_name, key_id) -> FetchKeyResult, or None if the key is key is unknown
unknown
""" """
keys = {} keys = {}
@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch) _get_keys(txn, batch)
return keys return keys
return self.db_pool.runInteraction("get_server_verify_keys", _txn) return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
async def store_server_verify_keys( async def store_server_verify_keys(
self, self,
@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json", desc="store_server_keys_json",
) )
def get_server_keys_json(self, server_keys): async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
"""Retrive the key json for a list of server_keys and key ids. """Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list. that server, key_id, and source triplet entry will be an empty list.
@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore):
Args: Args:
server_keys (list): List of (server_name, key_id, source) triplets. server_keys (list): List of (server_name, key_id, source) triplets.
Returns: Returns:
Deferred[dict[Tuple[str, str, str|None], list[dict]]]: A mapping from (server_name, key_id, source) triplets to a list of dicts
Dict mapping (server_name, key_id, source) triplets to lists of dicts
""" """
def _get_server_keys_json_txn(txn): def _get_server_keys_json_txn(txn):
@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows results[(server_name, key_id, from_server)] = rows
return results return results
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn "get_server_keys_json", _get_server_keys_json_txn
) )

View File

@ -15,6 +15,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore):
expiry_ms=5 * 60 * 1000, expiry_ms=5 * 60 * 1000,
) )
def get_received_txn_response(self, transaction_id, origin): async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
"""For an incoming transaction from a given origin, check if we have """For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response already responded to it. If so, return the response code and response
body (as a dict). body (as a dict).
Args: Args:
transaction_id (str) transaction_id
origin(str) origin
Returns: Returns:
tuple: None if we have not previously responded to None if we have not previously responded to this transaction or a
this transaction or a 2-tuple of (int, dict) 2-tuple of (int, dict)
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_received_txn_response", "get_received_txn_response",
self._get_received_txn_response, self._get_received_txn_response,
transaction_id, transaction_id,
@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore):
else: else:
return None return None
def set_destination_retry_timings( async def set_destination_retry_timings(
self, destination, failure_ts, retry_last_ts, retry_interval self,
): destination: str,
failure_ts: Optional[int],
retry_last_ts: int,
retry_interval: int,
) -> None:
"""Sets the current retry timings for a given destination. """Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring. Both timings should be zero if retrying is no longer occuring.
Args: Args:
destination (str) destination
failure_ts (int|None) - when the server started failing (ms since epoch) failure_ts: when the server started failing (ms since epoch)
retry_last_ts (int) - time of last retry attempt in unix epoch ms retry_last_ts: time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms retry_interval: how long until next retry in ms
""" """
self._destination_retry_cache.pop(destination, None) self._destination_retry_cache.pop(destination, None)
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"set_destination_retry_timings", "set_destination_retry_timings",
self._set_destination_retry_timings, self._set_destination_retry_timings,
destination, destination,
@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore):
"cleanup_transactions", self._cleanup_transactions "cleanup_transactions", self._cleanup_transactions
) )
def _cleanup_transactions(self): async def _cleanup_transactions(self) -> None:
now = self._clock.time_msec() now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000 month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn): def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn "_cleanup_transactions", _cleanup_transactions_txn
) )