mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Convert appservice, group server, profile and more databases to async (#8066)
This commit is contained in:
parent
9d1e4942ab
commit
a3a59bab7b
1
changelog.d/8066.misc
Normal file
1
changelog.d/8066.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -18,8 +18,6 @@ import re
|
|||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.appservice import AppServiceTransaction
|
from synapse.appservice import AppServiceTransaction
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
|||||||
class ApplicationServiceTransactionWorkerStore(
|
class ApplicationServiceTransactionWorkerStore(
|
||||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||||
):
|
):
|
||||||
@defer.inlineCallbacks
|
async def get_appservices_by_state(self, state):
|
||||||
def get_appservices_by_state(self, state):
|
|
||||||
"""Get a list of application services based on their state.
|
"""Get a list of application services based on their state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state(ApplicationServiceState): The state to filter on.
|
state(ApplicationServiceState): The state to filter on.
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred which resolves to a list of ApplicationServices, which
|
A list of ApplicationServices, which may be empty.
|
||||||
may be empty.
|
|
||||||
"""
|
"""
|
||||||
results = yield self.db_pool.simple_select_list(
|
results = await self.db_pool.simple_select_list(
|
||||||
"application_services_state", {"state": state}, ["as_id"]
|
"application_services_state", {"state": state}, ["as_id"]
|
||||||
)
|
)
|
||||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||||
@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
services.append(service)
|
services.append(service)
|
||||||
return services
|
return services
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_appservice_state(self, service):
|
||||||
def get_appservice_state(self, service):
|
|
||||||
"""Get the application service state.
|
"""Get the application service state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The service whose state to set.
|
service(ApplicationService): The service whose state to set.
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred which resolves to ApplicationServiceState.
|
An ApplicationServiceState.
|
||||||
"""
|
"""
|
||||||
result = yield self.db_pool.simple_select_one(
|
result = await self.db_pool.simple_select_one(
|
||||||
"application_services_state",
|
"application_services_state",
|
||||||
{"as_id": service.id},
|
{"as_id": service.id},
|
||||||
["state"],
|
["state"],
|
||||||
@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
"complete_appservice_txn", _complete_appservice_txn
|
"complete_appservice_txn", _complete_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_oldest_unsent_txn(self, service):
|
||||||
def get_oldest_unsent_txn(self, service):
|
|
||||||
"""Get the oldest transaction which has not been sent for this
|
"""Get the oldest transaction which has not been sent for this
|
||||||
service.
|
service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service(ApplicationService): The app service to get the oldest txn.
|
service(ApplicationService): The app service to get the oldest txn.
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred which resolves to an AppServiceTransaction or
|
An AppServiceTransaction or None.
|
||||||
None.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_oldest_unsent_txn(txn):
|
def _get_oldest_unsent_txn(txn):
|
||||||
@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
entry = yield self.db_pool.runInteraction(
|
entry = await self.db_pool.runInteraction(
|
||||||
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
|
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
|
|
||||||
event_ids = db_to_json(entry["event_ids"])
|
event_ids = db_to_json(entry["event_ids"])
|
||||||
|
|
||||||
events = yield self.get_events_as_list(event_ids)
|
events = await self.get_events_as_list(event_ids)
|
||||||
|
|
||||||
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
|
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
|
||||||
|
|
||||||
@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_new_events_for_appservice(self, current_id, limit):
|
||||||
def get_new_events_for_appservice(self, current_id, limit):
|
|
||||||
"""Get all new evnets"""
|
"""Get all new evnets"""
|
||||||
|
|
||||||
def get_new_events_for_appservice_txn(txn):
|
def get_new_events_for_appservice_txn(txn):
|
||||||
@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
|
|
||||||
return upper_bound, [row[1] for row in rows]
|
return upper_bound, [row[1] for row in rows]
|
||||||
|
|
||||||
upper_bound, event_ids = yield self.db_pool.runInteraction(
|
upper_bound, event_ids = await self.db_pool.runInteraction(
|
||||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn
|
"get_new_events_for_appservice", get_new_events_for_appservice_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
events = yield self.get_events_as_list(event_ids)
|
events = await self.get_events_as_list(event_ids)
|
||||||
|
|
||||||
return upper_bound, events
|
return upper_bound, events
|
||||||
|
|
||||||
|
@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json
|
|||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
|
||||||
class FilteringStore(SQLBaseStore):
|
class FilteringStore(SQLBaseStore):
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_user_filter(self, user_localpart, filter_id):
|
async def get_user_filter(self, user_localpart, filter_id):
|
||||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||||
try:
|
try:
|
||||||
@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
def_json = yield self.db_pool.simple_select_one_onecol(
|
def_json = await self.db_pool.simple_select_one_onecol(
|
||||||
table="user_filters",
|
table="user_filters",
|
||||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||||
retcol="filter_json",
|
retcol="filter_json",
|
||||||
|
@ -14,12 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
# The category ID for the "default" category. We don't store as null in the
|
# The category ID for the "default" category. We don't store as null in the
|
||||||
@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
"get_rooms_for_summary", _get_rooms_for_summary_txn
|
"get_rooms_for_summary", _get_rooms_for_summary_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_group_categories(self, group_id):
|
||||||
def get_group_categories(self, group_id):
|
rows = await self.db_pool.simple_select_list(
|
||||||
rows = yield self.db_pool.simple_select_list(
|
|
||||||
table="group_room_categories",
|
table="group_room_categories",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
retcols=("category_id", "is_public", "profile"),
|
retcols=("category_id", "is_public", "profile"),
|
||||||
@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
for row in rows
|
for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_group_category(self, group_id, category_id):
|
||||||
def get_group_category(self, group_id, category_id):
|
category = await self.db_pool.simple_select_one(
|
||||||
category = yield self.db_pool.simple_select_one(
|
|
||||||
table="group_room_categories",
|
table="group_room_categories",
|
||||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||||
retcols=("is_public", "profile"),
|
retcols=("is_public", "profile"),
|
||||||
@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return category
|
return category
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_group_roles(self, group_id):
|
||||||
def get_group_roles(self, group_id):
|
rows = await self.db_pool.simple_select_list(
|
||||||
rows = yield self.db_pool.simple_select_list(
|
|
||||||
table="group_roles",
|
table="group_roles",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
retcols=("role_id", "is_public", "profile"),
|
retcols=("role_id", "is_public", "profile"),
|
||||||
@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
for row in rows
|
for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_group_role(self, group_id, role_id):
|
||||||
def get_group_role(self, group_id, role_id):
|
role = await self.db_pool.simple_select_one(
|
||||||
role = yield self.db_pool.simple_select_one(
|
|
||||||
table="group_roles",
|
table="group_roles",
|
||||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||||
retcols=("is_public", "profile"),
|
retcols=("is_public", "profile"),
|
||||||
@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
|
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_remote_attestation(self, group_id, user_id):
|
||||||
def get_remote_attestation(self, group_id, user_id):
|
|
||||||
"""Get the attestation that proves the remote agrees that the user is
|
"""Get the attestation that proves the remote agrees that the user is
|
||||||
in the group.
|
in the group.
|
||||||
"""
|
"""
|
||||||
row = yield self.db_pool.simple_select_one(
|
row = await self.db_pool.simple_select_one(
|
||||||
table="group_attestations_remote",
|
table="group_attestations_remote",
|
||||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||||
retcols=("valid_until_ms", "attestation_json"),
|
retcols=("valid_until_ms", "attestation_json"),
|
||||||
@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
"get_all_groups_for_user", _get_all_groups_for_user_txn
|
"get_all_groups_for_user", _get_all_groups_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_groups_changes_for_user(self, user_id, from_token, to_token):
|
async def get_groups_changes_for_user(self, user_id, from_token, to_token):
|
||||||
from_token = int(from_token)
|
from_token = int(from_token)
|
||||||
has_changed = self._group_updates_stream_cache.has_entity_changed(
|
has_changed = self._group_updates_stream_cache.has_entity_changed(
|
||||||
user_id, from_token
|
user_id, from_token
|
||||||
)
|
)
|
||||||
if not has_changed:
|
if not has_changed:
|
||||||
return defer.succeed([])
|
return []
|
||||||
|
|
||||||
def _get_groups_changes_for_user_txn(txn):
|
def _get_groups_changes_for_user_txn(txn):
|
||||||
sql = """
|
sql = """
|
||||||
@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
for group_id, membership, gtype, content_json in txn
|
for group_id, membership, gtype, content_json in txn
|
||||||
]
|
]
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
|
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||||||
desc="update_group_publicity",
|
desc="update_group_publicity",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def register_user_group_membership(
|
||||||
def register_user_group_membership(
|
|
||||||
self,
|
self,
|
||||||
group_id,
|
group_id: str,
|
||||||
user_id,
|
user_id: str,
|
||||||
membership,
|
membership: str,
|
||||||
is_admin=False,
|
is_admin: bool = False,
|
||||||
content={},
|
content: JsonDict = {},
|
||||||
local_attestation=None,
|
local_attestation: Optional[dict] = None,
|
||||||
remote_attestation=None,
|
remote_attestation: Optional[dict] = None,
|
||||||
is_publicised=False,
|
is_publicised: bool = False,
|
||||||
):
|
) -> int:
|
||||||
"""Registers that a local user is a member of a (local or remote) group.
|
"""Registers that a local user is a member of a (local or remote) group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_id (str)
|
group_id: The group the member is being added to.
|
||||||
user_id (str)
|
user_id: THe user ID to add to the group.
|
||||||
membership (str)
|
membership: The type of group membership.
|
||||||
is_admin (bool)
|
is_admin: Whether the user should be added as a group admin.
|
||||||
content (dict): Content of the membership, e.g. includes the inviter
|
content: Content of the membership, e.g. includes the inviter
|
||||||
if the user has been invited.
|
if the user has been invited.
|
||||||
local_attestation (dict): If remote group then store the fact that we
|
local_attestation: If remote group then store the fact that we
|
||||||
have given out an attestation, else None.
|
have given out an attestation, else None.
|
||||||
remote_attestation (dict): If remote group then store the remote
|
remote_attestation: If remote group then store the remote
|
||||||
attestation from the group, else None.
|
attestation from the group, else None.
|
||||||
|
is_publicised: Whether this should be publicised.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _register_user_group_membership_txn(txn, next_id):
|
def _register_user_group_membership_txn(txn, next_id):
|
||||||
@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
with self._group_updates_id_gen.get_next() as next_id:
|
with self._group_updates_id_gen.get_next() as next_id:
|
||||||
res = yield self.db_pool.runInteraction(
|
res = await self.db_pool.runInteraction(
|
||||||
"register_user_group_membership",
|
"register_user_group_membership",
|
||||||
_register_user_group_membership_txn,
|
_register_user_group_membership_txn,
|
||||||
next_id,
|
next_id,
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_group(
|
||||||
def create_group(
|
|
||||||
self, group_id, user_id, name, avatar_url, short_description, long_description
|
self, group_id, user_id, name, avatar_url, short_description, long_description
|
||||||
):
|
) -> None:
|
||||||
yield self.db_pool.simple_insert(
|
await self.db_pool.simple_insert(
|
||||||
table="groups",
|
table="groups",
|
||||||
values={
|
values={
|
||||||
"group_id": group_id,
|
"group_id": group_id,
|
||||||
@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||||||
desc="create_group",
|
desc="create_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_group_profile(self, group_id, profile):
|
||||||
def update_group_profile(self, group_id, profile):
|
await self.db_pool.simple_update_one(
|
||||||
yield self.db_pool.simple_update_one(
|
|
||||||
table="groups",
|
table="groups",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
updatevalues=profile,
|
updatevalues=profile,
|
||||||
|
@ -15,8 +15,6 @@
|
|||||||
|
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter
|
|||||||
|
|
||||||
|
|
||||||
class PresenceStore(SQLBaseStore):
|
class PresenceStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
async def update_presence(self, presence_states):
|
||||||
def update_presence(self, presence_states):
|
|
||||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||||
len(presence_states)
|
len(presence_states)
|
||||||
)
|
)
|
||||||
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
with stream_ordering_manager as stream_orderings:
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"update_presence",
|
"update_presence",
|
||||||
self._update_presence_txn,
|
self._update_presence_txn,
|
||||||
stream_orderings,
|
stream_orderings,
|
||||||
|
@ -13,18 +13,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||||
|
|
||||||
|
|
||||||
class ProfileWorkerStore(SQLBaseStore):
|
class ProfileWorkerStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
async def get_profileinfo(self, user_localpart):
|
||||||
def get_profileinfo(self, user_localpart):
|
|
||||||
try:
|
try:
|
||||||
profile = yield self.db_pool.simple_select_one(
|
profile = await self.db_pool.simple_select_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
retcols=("displayname", "avatar_url"),
|
retcols=("displayname", "avatar_url"),
|
||||||
@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore):
|
|||||||
desc="update_remote_profile_cache",
|
desc="update_remote_profile_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def maybe_delete_remote_profile_cache(self, user_id):
|
||||||
def maybe_delete_remote_profile_cache(self, user_id):
|
|
||||||
"""Check if we still care about the remote user's profile, and if we
|
"""Check if we still care about the remote user's profile, and if we
|
||||||
don't then remove their profile from the cache
|
don't then remove their profile from the cache
|
||||||
"""
|
"""
|
||||||
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
|
subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
|
||||||
if not subscribed:
|
if not subscribed:
|
||||||
yield self.db_pool.simple_delete(
|
await self.db_pool.simple_delete(
|
||||||
table="remote_profile_cache",
|
table="remote_profile_cache",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
desc="delete_remote_profile_cache",
|
desc="delete_remote_profile_cache",
|
||||||
@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore):
|
|||||||
_get_remote_profile_cache_entries_that_expire_txn,
|
_get_remote_profile_cache_entries_that_expire_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def is_subscribed_remote_profile_for_user(self, user_id):
|
||||||
def is_subscribed_remote_profile_for_user(self, user_id):
|
|
||||||
"""Check whether we are interested in a remote user's profile.
|
"""Check whether we are interested in a remote user's profile.
|
||||||
"""
|
"""
|
||||||
res = yield self.db_pool.simple_select_one_onecol(
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
table="group_users",
|
table="group_users",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcol="user_id",
|
retcol="user_id",
|
||||||
@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore):
|
|||||||
if res:
|
if res:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
res = yield self.db_pool.simple_select_one_onecol(
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
table="group_invites",
|
table="group_invites",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcol="user_id",
|
retcol="user_id",
|
||||||
|
@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import RelationTypes
|
from synapse.api.constants import RelationTypes
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||||
from synapse.storage.relations import (
|
from synapse.storage.relations import (
|
||||||
@ -25,7 +27,7 @@ from synapse.storage.relations import (
|
|||||||
PaginationChunk,
|
PaginationChunk,
|
||||||
RelationPaginationToken,
|
RelationPaginationToken,
|
||||||
)
|
)
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cached()
|
||||||
def get_applicable_edit(self, event_id):
|
async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
|
||||||
"""Get the most recent edit (if any) that has happened for the given
|
"""Get the most recent edit (if any) that has happened for the given
|
||||||
event.
|
event.
|
||||||
|
|
||||||
Correctly handles checking whether edits were allowed to happen.
|
Correctly handles checking whether edits were allowed to happen.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id (str): The original event ID
|
event_id: The original event ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[EventBase|None]: Returns the most recent edit, if any.
|
The most recent edit, if any.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We only allow edits for `m.room.message` events that have the same sender
|
# We only allow edits for `m.room.message` events that have the same sender
|
||||||
@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
if row:
|
if row:
|
||||||
return row[0]
|
return row[0]
|
||||||
|
|
||||||
edit_id = yield self.db_pool.runInteraction(
|
edit_id = await self.db_pool.runInteraction(
|
||||||
"get_applicable_edit", _get_applicable_edit_txn
|
"get_applicable_edit", _get_applicable_edit_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
if not edit_id:
|
if not edit_id:
|
||||||
return
|
return None
|
||||||
|
|
||||||
edit_event = yield self.get_event(edit_id, allow_none=True)
|
return await self.get_event(edit_id, allow_none=True)
|
||||||
return edit_event
|
|
||||||
|
|
||||||
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
|
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
|
||||||
"""Check if a user has already annotated an event with the same key
|
"""Check if a user has already annotated an event with the same key
|
||||||
|
@ -18,8 +18,6 @@ from collections import namedtuple
|
|||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore):
|
|||||||
desc="set_received_txn_response",
|
desc="set_received_txn_response",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_destination_retry_timings(self, destination):
|
||||||
def get_destination_retry_timings(self, destination):
|
|
||||||
"""Gets the current retry timings (if any) for a given destination.
|
"""Gets the current retry timings (if any) for a given destination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore):
|
|||||||
if result is not SENTINEL:
|
if result is not SENTINEL:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = yield self.db_pool.runInteraction(
|
result = await self.db_pool.runInteraction(
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
self._get_destination_retry_timings,
|
self._get_destination_retry_timings,
|
||||||
destination,
|
destination,
|
||||||
|
@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_appservice_state_none(self):
|
def test_get_appservice_state_none(self):
|
||||||
service = Mock(id="999")
|
service = Mock(id="999")
|
||||||
state = yield self.store.get_appservice_state(service)
|
state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
|
||||||
self.assertEquals(None, state)
|
self.assertEquals(None, state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_appservice_state_up(self):
|
def test_get_appservice_state_up(self):
|
||||||
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
|
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
|
||||||
service = Mock(id=self.as_list[0]["id"])
|
service = Mock(id=self.as_list[0]["id"])
|
||||||
state = yield self.store.get_appservice_state(service)
|
state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
|
||||||
self.assertEquals(ApplicationServiceState.UP, state)
|
self.assertEquals(ApplicationServiceState.UP, state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
|
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
|
||||||
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
|
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
|
||||||
service = Mock(id=self.as_list[1]["id"])
|
service = Mock(id=self.as_list[1]["id"])
|
||||||
state = yield self.store.get_appservice_state(service)
|
state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
|
||||||
self.assertEquals(ApplicationServiceState.DOWN, state)
|
self.assertEquals(ApplicationServiceState.DOWN, state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_appservices_by_state_none(self):
|
def test_get_appservices_by_state_none(self):
|
||||||
services = yield self.store.get_appservices_by_state(
|
services = yield defer.ensureDeferred(
|
||||||
ApplicationServiceState.DOWN
|
self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
|
||||||
)
|
)
|
||||||
self.assertEquals(0, len(services))
|
self.assertEquals(0, len(services))
|
||||||
|
|
||||||
@ -339,7 +339,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
def test_get_oldest_unsent_txn_none(self):
|
def test_get_oldest_unsent_txn_none(self):
|
||||||
service = Mock(id=self.as_list[0]["id"])
|
service = Mock(id=self.as_list[0]["id"])
|
||||||
|
|
||||||
txn = yield self.store.get_oldest_unsent_txn(service)
|
txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
|
||||||
self.assertEquals(None, txn)
|
self.assertEquals(None, txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
|
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
|
||||||
|
|
||||||
# we aren't testing store._base stuff here, so mock this out
|
# we aren't testing store._base stuff here, so mock this out
|
||||||
self.store.get_events_as_list = Mock(return_value=events)
|
self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
|
||||||
|
|
||||||
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
|
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
|
||||||
yield self._insert_txn(service.id, 10, events)
|
yield self._insert_txn(service.id, 10, events)
|
||||||
yield self._insert_txn(service.id, 11, other_events)
|
yield self._insert_txn(service.id, 11, other_events)
|
||||||
yield self._insert_txn(service.id, 12, other_events)
|
yield self._insert_txn(service.id, 12, other_events)
|
||||||
|
|
||||||
txn = yield self.store.get_oldest_unsent_txn(service)
|
txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
|
||||||
self.assertEquals(service, txn.service)
|
self.assertEquals(service, txn.service)
|
||||||
self.assertEquals(10, txn.id)
|
self.assertEquals(10, txn.id)
|
||||||
self.assertEquals(events, txn.events)
|
self.assertEquals(events, txn.events)
|
||||||
@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
|
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
|
||||||
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
|
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
|
||||||
|
|
||||||
services = yield self.store.get_appservices_by_state(
|
services = yield defer.ensureDeferred(
|
||||||
ApplicationServiceState.DOWN
|
self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
|
||||||
)
|
)
|
||||||
self.assertEquals(1, len(services))
|
self.assertEquals(1, len(services))
|
||||||
self.assertEquals(self.as_list[0]["id"], services[0].id)
|
self.assertEquals(self.as_list[0]["id"], services[0].id)
|
||||||
@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
|
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
|
||||||
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
|
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
|
||||||
|
|
||||||
services = yield self.store.get_appservices_by_state(
|
services = yield defer.ensureDeferred(
|
||||||
ApplicationServiceState.DOWN
|
self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
|
||||||
)
|
)
|
||||||
self.assertEquals(2, len(services))
|
self.assertEquals(2, len(services))
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
|
Loading…
Reference in New Issue
Block a user