mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-04 22:20:49 -05:00
Convert additional databases to async/await part 2 (#8200)
This commit is contained in:
parent
bbb3c8641c
commit
da77520cd1
1
changelog.d/8200.misc
Normal file
1
changelog.d/8200.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Optional
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey
|
||||||
@ -97,14 +97,14 @@ class EventBuilder(object):
|
|||||||
def is_state(self):
|
def is_state(self):
|
||||||
return self._state_key is not None
|
return self._state_key is not None
|
||||||
|
|
||||||
async def build(self, prev_event_ids):
|
async def build(self, prev_event_ids: List[str]) -> EventBase:
|
||||||
"""Transform into a fully signed and hashed event
|
"""Transform into a fully signed and hashed event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
prev_event_ids: The event IDs to use as the prev events
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FrozenEvent
|
The signed and hashed event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state_ids = await self._state.get_current_state_ids(
|
state_ids = await self._state.get_current_state_ids(
|
||||||
@ -114,8 +114,13 @@ class EventBuilder(object):
|
|||||||
|
|
||||||
format_version = self.room_version.event_format
|
format_version = self.room_version.event_format
|
||||||
if format_version == EventFormatVersions.V1:
|
if format_version == EventFormatVersions.V1:
|
||||||
auth_events = await self._store.add_event_hashes(auth_ids)
|
# The types of auth/prev events changes between event versions.
|
||||||
prev_events = await self._store.add_event_hashes(prev_event_ids)
|
auth_events = await self._store.add_event_hashes(
|
||||||
|
auth_ids
|
||||||
|
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||||
|
prev_events = await self._store.add_event_hashes(
|
||||||
|
prev_event_ids
|
||||||
|
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||||
else:
|
else:
|
||||||
auth_events = auth_ids
|
auth_events = auth_ids
|
||||||
prev_events = prev_event_ids
|
prev_events = prev_event_ids
|
||||||
@ -138,7 +143,7 @@ class EventBuilder(object):
|
|||||||
"unsigned": self.unsigned,
|
"unsigned": self.unsigned,
|
||||||
"depth": depth,
|
"depth": depth,
|
||||||
"prev_state": [],
|
"prev_state": [],
|
||||||
}
|
} # type: Dict[str, Any]
|
||||||
|
|
||||||
if self.is_state():
|
if self.is_state():
|
||||||
event_dict["state_key"] = self._state_key
|
event_dict["state_key"] = self._state_key
|
||||||
|
@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||||
Collection,
|
|
||||||
Requester,
|
|
||||||
RoomAlias,
|
|
||||||
StreamToken,
|
|
||||||
UserID,
|
|
||||||
create_requester,
|
|
||||||
)
|
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
from synapse.util.frozenutils import frozendict_json_encoder
|
||||||
@ -446,7 +439,7 @@ class EventCreationHandler(object):
|
|||||||
event_dict: dict,
|
event_dict: dict,
|
||||||
token_id: Optional[str] = None,
|
token_id: Optional[str] = None,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
require_consent: bool = True,
|
require_consent: bool = True,
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""
|
"""
|
||||||
@ -786,7 +779,7 @@ class EventCreationHandler(object):
|
|||||||
self,
|
self,
|
||||||
builder: EventBuilder,
|
builder: EventBuilder,
|
||||||
requester: Optional[Requester] = None,
|
requester: Optional[Requester] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""Create a new event for a local client
|
"""Create a new event for a local client
|
||||||
|
|
||||||
|
@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
|
|||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.storage.roommember import RoomsForUser
|
from synapse.storage.roommember import RoomsForUser
|
||||||
from synapse.types import (
|
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
|
||||||
Collection,
|
|
||||||
JsonDict,
|
|
||||||
Requester,
|
|
||||||
RoomAlias,
|
|
||||||
RoomID,
|
|
||||||
StateMap,
|
|
||||||
UserID,
|
|
||||||
)
|
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.distributor import user_joined_room, user_left_room
|
from synapse.util.distributor import user_joined_room, user_left_room
|
||||||
|
|
||||||
@ -184,7 +176,7 @@ class RoomMemberHandler(object):
|
|||||||
target: UserID,
|
target: UserID,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
membership: str,
|
membership: str,
|
||||||
prev_event_ids: Collection[str],
|
prev_event_ids: List[str],
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str] = None,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
content: Optional[dict] = None,
|
content: Optional[dict] = None,
|
||||||
|
@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||||||
self._batch_row_update[key] = (user_agent, device_id, now)
|
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||||
|
|
||||||
@wrap_as_background_process("update_client_ips")
|
@wrap_as_background_process("update_client_ips")
|
||||||
def _update_client_ips_batch(self):
|
async def _update_client_ips_batch(self) -> None:
|
||||||
|
|
||||||
# If the DB pool has already terminated, don't try updating
|
# If the DB pool has already terminated, don't try updating
|
||||||
if not self.db_pool.is_running():
|
if not self.db_pool.is_running():
|
||||||
@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||||||
to_update = self._batch_row_update
|
to_update = self._batch_row_update
|
||||||
self._batch_row_update = {}
|
self._batch_row_update = {}
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||||||
|
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
def update_aliases_for_room(
|
async def update_aliases_for_room(
|
||||||
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
|
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Repoint all of the aliases for a given room, to a different room.
|
"""Repoint all of the aliases for a given room, to a different room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||||||
txn, self.get_aliases_for_room, (new_room_id,)
|
txn, self.get_aliases_for_room, (new_room_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
|
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
|
||||||
)
|
)
|
||||||
|
@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
|
|||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
|
|||||||
|
|
||||||
return db_to_json(def_json)
|
return db_to_json(def_json)
|
||||||
|
|
||||||
def add_user_filter(self, user_localpart, user_filter):
|
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
|
||||||
def_json = encode_canonical_json(user_filter)
|
def_json = encode_canonical_json(user_filter)
|
||||||
|
|
||||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||||
@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
|
|||||||
|
|
||||||
return filter_id
|
return filter_id
|
||||||
|
|
||||||
return self.db_pool.runInteraction("add_user_filter", _do_txn)
|
return await self.db_pool.runInteraction("add_user_filter", _do_txn)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
|
|||||||
desc="insert_open_id_token",
|
desc="insert_open_id_token",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_user_id_for_open_id_token(self, token, ts_now_ms):
|
async def get_user_id_for_open_id_token(
|
||||||
|
self, token: str, ts_now_ms: int
|
||||||
|
) -> Optional[str]:
|
||||||
def get_user_id_for_token_txn(txn):
|
def get_user_id_for_token_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT user_id FROM open_id_tokens"
|
"SELECT user_id FROM open_id_tokens"
|
||||||
@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
|
|||||||
else:
|
else:
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_user_id_for_token", get_user_id_for_token_txn
|
"get_user_id_for_token", get_user_id_for_token_txn
|
||||||
)
|
)
|
||||||
|
@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
|
|||||||
desc="delete_remote_profile_cache",
|
desc="delete_remote_profile_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
async def get_remote_profile_cache_entries_that_expire(
|
||||||
|
self, last_checked: int
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Get all users who haven't been checked since `last_checked`
|
"""Get all users who haven't been checked since `last_checked`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
|
|||||||
|
|
||||||
return self.db_pool.cursor_to_dict(txn)
|
return self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_remote_profile_cache_entries_that_expire",
|
"get_remote_profile_cache_entries_that_expire",
|
||||||
_get_remote_profile_cache_entries_that_expire_txn,
|
_get_remote_profile_cache_entries_that_expire_txn,
|
||||||
)
|
)
|
||||||
|
@ -18,8 +18,6 @@ import abc
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.push.baserules import list_with_base_rules
|
from synapse.push.baserules import list_with_base_rules
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
@ -149,9 +147,11 @@ class PushRulesWorkerStore(
|
|||||||
)
|
)
|
||||||
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
|
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
|
||||||
|
|
||||||
def have_push_rules_changed_for_user(self, user_id, last_id):
|
async def have_push_rules_changed_for_user(
|
||||||
|
self, user_id: str, last_id: int
|
||||||
|
) -> bool:
|
||||||
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
||||||
return defer.succeed(False)
|
return False
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def have_push_rules_changed_txn(txn):
|
def have_push_rules_changed_txn(txn):
|
||||||
@ -163,7 +163,7 @@ class PushRulesWorkerStore(
|
|||||||
(count,) = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return bool(count)
|
return bool(count)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"have_push_rules_changed", have_push_rules_changed_txn
|
"have_push_rules_changed", have_push_rules_changed_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_room_with_stats(self, room_id: str):
|
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Retrieve room with statistics.
|
"""Retrieve room with statistics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
res["public"] = bool(res["public"])
|
res["public"] = bool(res["public"])
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_room_with_stats", get_room_with_stats_txn, room_id
|
"get_room_with_stats", get_room_with_stats_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
desc="get_public_room_ids",
|
desc="get_public_room_ids",
|
||||||
)
|
)
|
||||||
|
|
||||||
def count_public_rooms(self, network_tuple, ignore_non_federatable):
|
async def count_public_rooms(
|
||||||
|
self,
|
||||||
|
network_tuple: Optional[ThirdPartyInstanceID],
|
||||||
|
ignore_non_federatable: bool,
|
||||||
|
) -> int:
|
||||||
"""Counts the number of public rooms as tracked in the room_stats_current
|
"""Counts the number of public rooms as tracked in the room_stats_current
|
||||||
and room_stats_state table.
|
and room_stats_state table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
network_tuple (ThirdPartyInstanceID|None)
|
network_tuple
|
||||||
ignore_non_federatable (bool): If true filters out non-federatable rooms
|
ignore_non_federatable: If true filters out non-federatable rooms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_public_rooms_txn(txn):
|
def _count_public_rooms_txn(txn):
|
||||||
@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
txn.execute(sql, query_args)
|
txn.execute(sql, query_args)
|
||||||
return txn.fetchone()[0]
|
return txn.fetchone()[0]
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"count_public_rooms", _count_public_rooms_txn
|
"count_public_rooms", _count_public_rooms_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return row
|
return row
|
||||||
|
|
||||||
def get_media_mxcs_in_room(self, room_id):
|
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
|
||||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The local and remote media as a lists of tuples where the key is
|
The local and remote media as a lists of the media IDs.
|
||||||
the hostname and the value is the media ID.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_media_mxcs_in_room_txn(txn):
|
def _get_media_mxcs_in_room_txn(txn):
|
||||||
@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return local_media_mxcs, remote_media_mxcs
|
return local_media_mxcs, remote_media_mxcs
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
|
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
|
async def quarantine_media_ids_in_room(
|
||||||
|
self, room_id: str, quarantined_by: str
|
||||||
|
) -> int:
|
||||||
"""For a room loops through all events with media and quarantines
|
"""For a room loops through all events with media and quarantines
|
||||||
the associated media
|
the associated media
|
||||||
"""
|
"""
|
||||||
@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return local_media_mxcs, remote_media_mxcs
|
return local_media_mxcs, remote_media_mxcs
|
||||||
|
|
||||||
def quarantine_media_by_id(
|
async def quarantine_media_by_id(
|
||||||
self, server_name: str, media_id: str, quarantined_by: str,
|
self, server_name: str, media_id: str, quarantined_by: str,
|
||||||
):
|
) -> int:
|
||||||
"""quarantines a single local or remote media id
|
"""quarantines a single local or remote media id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"quarantine_media_by_user", _quarantine_media_by_id_txn
|
"quarantine_media_by_user", _quarantine_media_by_id_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
|
async def quarantine_media_ids_by_user(
|
||||||
|
self, user_id: str, quarantined_by: str
|
||||||
|
) -> int:
|
||||||
"""quarantines all local media associated with a single user
|
"""quarantines all local media associated with a single user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
||||||
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
)
|
)
|
||||||
self.hs.get_notifier().on_new_replication_data()
|
self.hs.get_notifier().on_new_replication_data()
|
||||||
|
|
||||||
def get_room_count(self):
|
async def get_room_count(self) -> int:
|
||||||
"""Retrieve a list of all rooms
|
"""Retrieve the total number of rooms.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
return row[0] or 0
|
return row[0] or 0
|
||||||
|
|
||||||
return self.db_pool.runInteraction("get_rooms", f)
|
return await self.db_pool.runInteraction("get_rooms", f)
|
||||||
|
|
||||||
async def add_event_report(
|
async def add_event_report(
|
||||||
self,
|
self,
|
||||||
|
@ -13,9 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict, Iterable, List, Tuple
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
from unpaddedbase64 import encode_base64
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
|
||||||
@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
|
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
|
||||||
)
|
)
|
||||||
def get_event_reference_hashes(self, event_ids):
|
async def get_event_reference_hashes(
|
||||||
|
self, event_ids: Iterable[str]
|
||||||
|
) -> Dict[str, Dict[str, bytes]]:
|
||||||
|
"""Get all hashes for given events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids: The event IDs to get hashes for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A mapping of event ID to a mapping of algorithm to hash.
|
||||||
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
return {
|
return {
|
||||||
event_id: self._get_event_reference_hashes_txn(txn, event_id)
|
event_id: self._get_event_reference_hashes_txn(txn, event_id)
|
||||||
for event_id in event_ids
|
for event_id in event_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.db_pool.runInteraction("get_event_reference_hashes", f)
|
return await self.db_pool.runInteraction("get_event_reference_hashes", f)
|
||||||
|
|
||||||
async def add_event_hashes(self, event_ids):
|
async def add_event_hashes(
|
||||||
|
self, event_ids: Iterable[str]
|
||||||
|
) -> List[Tuple[str, Dict[str, str]]]:
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids: The event IDs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
|
||||||
|
"""
|
||||||
hashes = await self.get_event_reference_hashes(event_ids)
|
hashes = await self.get_event_reference_hashes(event_ids)
|
||||||
hashes = {
|
hashes = {
|
||||||
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
||||||
@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return list(hashes.items())
|
return list(hashes.items())
|
||||||
|
|
||||||
def _get_event_reference_hashes_txn(self, txn, event_id):
|
def _get_event_reference_hashes_txn(
|
||||||
|
self, txn: Cursor, event_id: str
|
||||||
|
) -> Dict[str, bytes]:
|
||||||
"""Get all the hashes for a given PDU.
|
"""Get all the hashes for a given PDU.
|
||||||
Args:
|
Args:
|
||||||
txn (cursor):
|
txn:
|
||||||
event_id (str): Id for the Event.
|
event_id: Id for the Event.
|
||||||
Returns:
|
Returns:
|
||||||
A dict[unicode, bytes] of algorithm -> hash.
|
A mapping of algorithm -> hash.
|
||||||
"""
|
"""
|
||||||
query = (
|
query = (
|
||||||
"SELECT algorithm, hash"
|
"SELECT algorithm, hash"
|
||||||
|
@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
|
|
||||||
class UIAuthStore(UIAuthWorkerStore):
|
class UIAuthStore(UIAuthWorkerStore):
|
||||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
|
||||||
"""
|
"""
|
||||||
Remove sessions which were last used earlier than the expiration time.
|
Remove sessions which were last used earlier than the expiration time.
|
||||||
|
|
||||||
@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
|
|||||||
This is an epoch time in milliseconds.
|
This is an epoch time in milliseconds.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_old_ui_auth_sessions",
|
"delete_old_ui_auth_sessions",
|
||||||
self._delete_old_ui_auth_sessions_txn,
|
self._delete_old_ui_auth_sessions_txn,
|
||||||
expiration_time,
|
expiration_time,
|
||||||
|
@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
|
|
||||||
class UserErasureStore(UserErasureWorkerStore):
|
class UserErasureStore(UserErasureWorkerStore):
|
||||||
def mark_user_erased(self, user_id: str) -> None:
|
async def mark_user_erased(self, user_id: str) -> None:
|
||||||
"""Indicate that user_id wishes their message history to be erased.
|
"""Indicate that user_id wishes their message history to be erased.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||||||
|
|
||||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction("mark_user_erased", f)
|
await self.db_pool.runInteraction("mark_user_erased", f)
|
||||||
|
|
||||||
def mark_user_not_erased(self, user_id: str) -> None:
|
async def mark_user_not_erased(self, user_id: str) -> None:
|
||||||
"""Indicate that user_id is no longer erased.
|
"""Indicate that user_id is no longer erased.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||||||
|
|
||||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction("mark_user_not_erased", f)
|
await self.db_pool.runInteraction("mark_user_not_erased", f)
|
||||||
|
@ -13,14 +13,13 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import synapse.server
|
import synapse.server
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.types import Collection
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Utility functions for poking events into the storage of the server under test.
|
Utility functions for poking events into the storage of the server under test.
|
||||||
@ -58,7 +57,7 @@ async def inject_member_event(
|
|||||||
async def inject_event(
|
async def inject_event(
|
||||||
hs: synapse.server.HomeServer,
|
hs: synapse.server.HomeServer,
|
||||||
room_version: Optional[str] = None,
|
room_version: Optional[str] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
"""Inject a generic event into a room
|
"""Inject a generic event into a room
|
||||||
@ -80,7 +79,7 @@ async def inject_event(
|
|||||||
async def create_event(
|
async def create_event(
|
||||||
hs: synapse.server.HomeServer,
|
hs: synapse.server.HomeServer,
|
||||||
room_version: Optional[str] = None,
|
room_version: Optional[str] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
if room_version is None:
|
if room_version is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user