mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-17 23:07:09 -05:00
Convert the roommember database to async/await. (#8070)
This commit is contained in:
parent
5ecc8b5825
commit
fbe930dad2
1
changelog.d/8070.misc
Normal file
1
changelog.d/8070.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
"""
|
||||
for host in {get_domain_from_id(u) for u in members_changed}:
|
||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
||||
|
||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
|
@ -256,81 +256,6 @@ class PushRulesWorkerStore(
|
||||
):
|
||||
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_push_rules_for_room(self, event, context):
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
result = yield self._bulk_get_push_rules_for_room(
|
||||
event.room_id, state_group, current_state_ids, event=event
|
||||
)
|
||||
return result
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
def _bulk_get_push_rules_for_room(
|
||||
self, room_id, state_group, current_state_ids, cache_context, event=None
|
||||
):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
# with a state_group of None are likely to be different.
|
||||
# See bulk_get_push_rules_for_room for how we work around this.
|
||||
assert state_group is not None
|
||||
|
||||
# We also will want to generate notifs for other people in the room so
|
||||
# their unread countss are correct in the event stream, but to avoid
|
||||
# generating them for bot / AS users etc, we only do so for people who've
|
||||
# sent a read receipt into the room.
|
||||
|
||||
users_in_room = yield self._get_joined_users_from_context(
|
||||
room_id,
|
||||
state_group,
|
||||
current_state_ids,
|
||||
on_invalidate=cache_context.invalidate,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# We ignore app service users for now. This is so that we don't fill
|
||||
# up the `get_if_users_have_pushers` cache with AS entries that we
|
||||
# know don't have pushers, nor even read receipts.
|
||||
local_users_in_room = {
|
||||
u
|
||||
for u in users_in_room
|
||||
if self.hs.is_mine_id(u)
|
||||
and not self.get_if_app_services_interested_in_user(u)
|
||||
}
|
||||
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||
local_users_in_room, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
user_ids = {
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
}
|
||||
|
||||
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
# any users with pushers must be ours: they have pushers
|
||||
for uid in users_with_receipts:
|
||||
if uid in local_users_in_room:
|
||||
user_ids.add(uid)
|
||||
|
||||
rules_by_user = yield self.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||
|
||||
return rules_by_user
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids",
|
||||
|
@ -15,11 +15,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Set
|
||||
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import (
|
||||
@ -40,9 +42,12 @@ from synapse.storage.roommember import (
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.state import _StateCacheEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
def get_users_in_room(self, room_id):
|
||||
def get_users_in_room(self, room_id: str):
|
||||
return self.db_pool.runInteraction(
|
||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||
)
|
||||
|
||||
def get_users_in_room_txn(self, txn, room_id):
|
||||
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
|
||||
# If we can assume current_state_events.membership is up to date
|
||||
# then we can avoid a join, which is a Very Good Thing given how
|
||||
# frequently this function gets called.
|
||||
@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
return [r[0] for r in txn]
|
||||
|
||||
@cached(max_entries=100000)
|
||||
def get_room_summary(self, room_id):
|
||||
def get_room_summary(self, room_id: str):
|
||||
""" Get the details of a room roughly suitable for use by the room
|
||||
summary extension to /sync. Useful when lazy loading room members.
|
||||
Args:
|
||||
room_id (str): The room ID to query
|
||||
room_id: The room ID to query
|
||||
Returns:
|
||||
Deferred[dict[str, MemberSummary]:
|
||||
dict of membership states, pointing to a MemberSummary named tuple.
|
||||
@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
|
||||
|
||||
def _get_user_counts_in_room_txn(self, txn, room_id):
|
||||
"""
|
||||
Get the user count in a room by membership.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
membership (Membership)
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
sql = """
|
||||
SELECT m.membership, count(*) FROM room_memberships as m
|
||||
INNER JOIN current_state_events as c USING(event_id)
|
||||
WHERE c.type = 'm.room.member' AND c.room_id = ?
|
||||
GROUP BY m.membership
|
||||
"""
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
return {row[0]: row[1] for row in txn}
|
||||
|
||||
@cached()
|
||||
def get_invited_rooms_for_local_user(self, user_id):
|
||||
""" Get all the rooms the *local* user is invited to
|
||||
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
|
||||
"""Get all the rooms the *local* user is invited to.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
user_id: The user ID.
|
||||
|
||||
Returns:
|
||||
A deferred list of RoomsForUser.
|
||||
A awaitable list of RoomsForUser.
|
||||
"""
|
||||
|
||||
return self.get_rooms_for_local_user_where_membership_is(
|
||||
user_id, [Membership.INVITE]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_invite_for_local_user_in_room(self, user_id, room_id):
|
||||
"""Gets the invite for the given *local* user and room
|
||||
async def get_invite_for_local_user_in_room(
|
||||
self, user_id: str, room_id: str
|
||||
) -> Optional[RoomsForUser]:
|
||||
"""Gets the invite for the given *local* user and room.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
room_id (str)
|
||||
user_id: The user ID to find the invite of.
|
||||
room_id: The room to user was invited to.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves to either a RoomsForUser or None if no invite was
|
||||
found.
|
||||
Either a RoomsForUser or None if no invite was found.
|
||||
"""
|
||||
invites = yield self.get_invited_rooms_for_local_user(user_id)
|
||||
invites = await self.get_invited_rooms_for_local_user(user_id)
|
||||
for invite in invites:
|
||||
if invite.room_id == room_id:
|
||||
return invite
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
|
||||
""" Get all the rooms for this *local* user where the membership for this user
|
||||
async def get_rooms_for_local_user_where_membership_is(
|
||||
self, user_id: str, membership_list: List[str]
|
||||
) -> Optional[List[RoomsForUser]]:
|
||||
"""Get all the rooms for this *local* user where the membership for this user
|
||||
matches one in the membership list.
|
||||
|
||||
Filters out forgotten rooms.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
membership_list (list): A list of synapse.api.constants.Membership
|
||||
values which the user must be in.
|
||||
user_id: The user ID.
|
||||
membership_list: A list of synapse.api.constants.Membership
|
||||
values which the user must be in.
|
||||
|
||||
Returns:
|
||||
Deferred[list[RoomsForUser]]
|
||||
The RoomsForUser that the user matches the membership types.
|
||||
"""
|
||||
if not membership_list:
|
||||
return defer.succeed(None)
|
||||
return None
|
||||
|
||||
rooms = yield self.db_pool.runInteraction(
|
||||
rooms = await self.db_pool.runInteraction(
|
||||
"get_rooms_for_local_user_where_membership_is",
|
||||
self._get_rooms_for_local_user_where_membership_is_txn,
|
||||
user_id,
|
||||
@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
|
||||
# Now we filter out forgotten rooms
|
||||
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
|
||||
forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
|
||||
return [room for room in rooms if room.room_id not in forgotten_rooms]
|
||||
|
||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||
self, txn, user_id, membership_list
|
||||
):
|
||||
self, txn, user_id: str, membership_list: List[str]
|
||||
) -> List[RoomsForUser]:
|
||||
# Paranoia check.
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
raise Exception(
|
||||
@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
return results
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
def get_rooms_for_user_with_stream_ordering(self, user_id):
|
||||
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
participating in.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
user_id
|
||||
|
||||
Returns:
|
||||
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
||||
@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
|
||||
# We use `current_state_events` here and not `local_current_membership`
|
||||
# as a) this gets called with remote users and b) this only gets called
|
||||
# for rooms the server is participating in.
|
||||
@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
_get_users_server_still_shares_room_with_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_user(self, user_id, on_invalidate=None):
|
||||
async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
participating in.
|
||||
"""
|
||||
rooms = yield self.get_rooms_for_user_with_stream_ordering(
|
||||
rooms = await self.get_rooms_for_user_with_stream_ordering(
|
||||
user_id, on_invalidate=on_invalidate
|
||||
)
|
||||
return frozenset(r.room_id for r in rooms)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
||||
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
||||
@cached(max_entries=500000, cache_context=True, iterable=True)
|
||||
async def get_users_who_share_room_with_user(
|
||||
self, user_id: str, cache_context: _CacheContext
|
||||
) -> Set[str]:
|
||||
"""Returns the set of users who share a room with `user_id`
|
||||
"""
|
||||
room_ids = yield self.get_rooms_for_user(
|
||||
room_ids = await self.get_rooms_for_user(
|
||||
user_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
user_who_share_room = set()
|
||||
for room_id in room_ids:
|
||||
user_ids = yield self.get_users_in_room(
|
||||
user_ids = await self.get_users_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
user_who_share_room.update(user_ids)
|
||||
|
||||
return user_who_share_room
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_users_from_context(self, event, context):
|
||||
async def get_joined_users_from_context(
|
||||
self, event: EventBase, context: EventContext
|
||||
):
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
result = yield self._get_joined_users_from_context(
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
return await self._get_joined_users_from_context(
|
||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||
)
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_users_from_state(self, room_id, state_entry):
|
||||
async def get_joined_users_from_state(self, room_id, state_entry):
|
||||
state_group = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
state_group = object()
|
||||
|
||||
with Measure(self._clock, "get_joined_users_from_state"):
|
||||
return (
|
||||
yield self._get_joined_users_from_context(
|
||||
room_id, state_group, state_entry.state, context=state_entry
|
||||
)
|
||||
return await self._get_joined_users_from_context(
|
||||
room_id, state_group, state_entry.state, context=state_entry
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(
|
||||
num_args=2, cache_context=True, iterable=True, max_entries=100000
|
||||
)
|
||||
def _get_joined_users_from_context(
|
||||
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
||||
async def _get_joined_users_from_context(
|
||||
self,
|
||||
room_id,
|
||||
state_group,
|
||||
@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
# We don't use `state_group`, it's there so that we can cache based
|
||||
# on it. However, it's important that it's never None, since two current_states
|
||||
# with a state_group of None are likely to be different.
|
||||
# See bulk_get_push_rules_for_room for how we work around this.
|
||||
assert state_group is not None
|
||||
|
||||
users_in_room = {}
|
||||
@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
missing_member_event_ids.append(event_id)
|
||||
|
||||
if missing_member_event_ids:
|
||||
event_to_memberships = yield self._get_joined_profiles_from_event_ids(
|
||||
event_to_memberships = await self._get_joined_profiles_from_event_ids(
|
||||
missing_member_event_ids
|
||||
)
|
||||
users_in_room.update((row for row in event_to_memberships.values() if row))
|
||||
@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
list_name="event_ids",
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_joined_profiles_from_event_ids(self, event_ids):
|
||||
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
"""For given set of member event_ids check if they point to a join
|
||||
event and if so return the associated user and profile info.
|
||||
|
||||
Args:
|
||||
event_ids (Iterable[str]): The member event IDs to lookup
|
||||
event_ids: The member event IDs to lookup
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
||||
@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
for row in rows
|
||||
}
|
||||
|
||||
@cachedInlineCallbacks(max_entries=10000)
|
||||
def is_host_joined(self, room_id, host):
|
||||
@cached(max_entries=10000)
|
||||
async def is_host_joined(self, room_id: str, host: str) -> bool:
|
||||
if "%" in host or "_" in host:
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
# the returned user actually has the correct domain.
|
||||
like_clause = "%:" + host
|
||||
|
||||
rows = yield self.db_pool.execute(
|
||||
rows = await self.db_pool.execute(
|
||||
"is_host_joined", None, sql, room_id, like_clause
|
||||
)
|
||||
|
||||
@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return True
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def was_host_joined(self, room_id, host):
|
||||
"""Check whether the server is or ever was in the room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
host (str)
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves to True if the host is/was in the room, otherwise
|
||||
False.
|
||||
"""
|
||||
if "%" in host or "_" in host:
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
sql = """
|
||||
SELECT user_id FROM room_memberships
|
||||
WHERE room_id = ?
|
||||
AND user_id LIKE ?
|
||||
AND membership = 'join'
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# We do need to be careful to ensure that host doesn't have any wild cards
|
||||
# in it, but we checked above for known ones and we'll check below that
|
||||
# the returned user actually has the correct domain.
|
||||
like_clause = "%:" + host
|
||||
|
||||
rows = yield self.db_pool.execute(
|
||||
"was_host_joined", None, sql, room_id, like_clause
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return False
|
||||
|
||||
user_id = rows[0][0]
|
||||
if get_domain_from_id(user_id) != host:
|
||||
# This can only happen if the host name has something funky in it
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_hosts(self, room_id, state_entry):
|
||||
async def get_joined_hosts(self, room_id: str, state_entry):
|
||||
state_group = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
state_group = object()
|
||||
|
||||
with Measure(self._clock, "get_joined_hosts"):
|
||||
return (
|
||||
yield self._get_joined_hosts(
|
||||
room_id, state_group, state_entry.state, state_entry=state_entry
|
||||
)
|
||||
return await self._get_joined_hosts(
|
||||
room_id, state_group, state_entry.state, state_entry=state_entry
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
|
||||
# @defer.inlineCallbacks
|
||||
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
|
||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||
async def _get_joined_hosts(
|
||||
self, room_id, state_group, current_state_ids, state_entry
|
||||
):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
# with a state_group of None are likely to be different.
|
||||
# See bulk_get_push_rules_for_room for how we work around this.
|
||||
assert state_group is not None
|
||||
|
||||
cache = yield self._get_joined_hosts_cache(room_id)
|
||||
joined_hosts = yield cache.get_destinations(state_entry)
|
||||
|
||||
return joined_hosts
|
||||
cache = await self._get_joined_hosts_cache(room_id)
|
||||
return await cache.get_destinations(state_entry)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def _get_joined_hosts_cache(self, room_id):
|
||||
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
|
||||
return _JoinedHostsCache(self, room_id)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
def did_forget(self, user_id, room_id):
|
||||
@cached(num_args=2)
|
||||
async def did_forget(self, user_id: str, room_id: str) -> bool:
|
||||
"""Returns whether user_id has elected to discard history for room_id.
|
||||
|
||||
Returns False if they have since re-joined."""
|
||||
@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
rows = txn.fetchall()
|
||||
return rows[0][0]
|
||||
|
||||
count = yield self.db_pool.runInteraction("did_forget_membership", f)
|
||||
count = await self.db_pool.runInteraction("did_forget_membership", f)
|
||||
return count == 0
|
||||
|
||||
@cached()
|
||||
def get_forgotten_rooms_for_user(self, user_id):
|
||||
def get_forgotten_rooms_for_user(self, user_id: str):
|
||||
"""Gets all rooms the user has forgotten.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
user_id
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]
|
||||
@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_user_has_been_in(self, user_id):
|
||||
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
|
||||
"""Get all rooms that the user has ever been in.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
user_id: The user ID to get the rooms of.
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: Set of room IDs.
|
||||
Set of room IDs.
|
||||
"""
|
||||
|
||||
room_ids = yield self.db_pool.simple_select_onecol(
|
||||
room_ids = await self.db_pool.simple_select_onecol(
|
||||
table="room_memberships",
|
||||
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
|
||||
retcol="room_id",
|
||||
@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
where_clause="forgotten = 1",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_add_membership_profile(self, progress, batch_size):
|
||||
async def _background_add_membership_profile(self, progress, batch_size):
|
||||
target_min_stream_id = progress.get(
|
||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
||||
)
|
||||
@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
return len(rows)
|
||||
|
||||
result = yield self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_current_state_membership(self, progress, batch_size):
|
||||
async def _background_current_state_membership(self, progress, batch_size):
|
||||
"""Update the new membership column on current_state_events.
|
||||
|
||||
This works by iterating over all rooms in alphebetical order.
|
||||
@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
# string, which will compare before all room IDs correctly.
|
||||
last_processed_room = progress.get("last_processed_room", "")
|
||||
|
||||
row_count, finished = yield self.db_pool.runInteraction(
|
||||
row_count, finished = await self.db_pool.runInteraction(
|
||||
"_background_current_state_membership_update",
|
||||
_background_current_state_membership_txn,
|
||||
last_processed_room,
|
||||
)
|
||||
|
||||
if finished:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
|
||||
)
|
||||
|
||||
@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
def forget(self, user_id, room_id):
|
||||
def forget(self, user_id: str, room_id: str):
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
|
||||
def f(txn):
|
||||
@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
|
||||
|
||||
self._len = 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_destinations(self, state_entry):
|
||||
async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
|
||||
"""Get set of destinations for a state entry
|
||||
|
||||
Args:
|
||||
state_entry(synapse.state._StateCacheEntry)
|
||||
state_entry
|
||||
|
||||
Returns:
|
||||
The destinations as a set.
|
||||
"""
|
||||
if state_entry.state_group == self.state_group:
|
||||
return frozenset(self.hosts_to_joined_users)
|
||||
|
||||
with (yield self.linearizer.queue(())):
|
||||
with (await self.linearizer.queue(())):
|
||||
if state_entry.state_group == self.state_group:
|
||||
pass
|
||||
elif state_entry.prev_group == self.state_group:
|
||||
@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
|
||||
user_id = state_key
|
||||
known_joins = self.hosts_to_joined_users.setdefault(host, set())
|
||||
|
||||
event = yield self.store.get_event(event_id)
|
||||
event = await self.store.get_event(event_id)
|
||||
if event.membership == Membership.JOIN:
|
||||
known_joins.add(user_id)
|
||||
else:
|
||||
@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
|
||||
if not known_joins:
|
||||
self.hosts_to_joined_users.pop(host, None)
|
||||
else:
|
||||
joined_users = yield self.store.get_joined_users_from_state(
|
||||
joined_users = await self.store.get_joined_users_from_state(
|
||||
self.room_id, state_entry
|
||||
)
|
||||
|
||||
|
@ -1,3 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
|
||||
@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
|
||||
from tests.test_utils import make_awaitable
|
||||
|
||||
|
||||
class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||
@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||
# Register a mock on the store so that the incoming update doesn't fail because
|
||||
# we don't share a room with the user.
|
||||
store = self.homeserver.get_datastore()
|
||||
store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
|
||||
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
|
||||
|
||||
# Manually inject a fake device list update. We need this update to include at
|
||||
# least one prev_id so that the user's device list will need to be retried.
|
||||
|
Loading…
Reference in New Issue
Block a user