mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-03 08:06:03 -04:00
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/new_push_rules
This commit is contained in:
commit
69158e554f
230 changed files with 5343 additions and 4040 deletions
|
@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
|
||||
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
||||
|
||||
self.get_unread_message_count_for_user.invalidate_many((room_id,))
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
|
||||
|
||||
if not backfilled:
|
||||
|
|
|
@ -15,11 +15,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import Database
|
||||
|
@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
return {"notify_count": notify_count, "highlight_count": highlight_count}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
|
||||
async def get_push_action_users_in_range(
|
||||
self, min_stream_ordering, max_stream_ordering
|
||||
):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
|
||||
|
@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
||||
return [r[0] for r in txn]
|
||||
|
||||
ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
|
||||
ret = await self.db.runInteraction("get_push_action_users_in_range", f)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range_for_http(
|
||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||
):
|
||||
async def get_unread_push_actions_for_user_in_range_for_http(
|
||||
self,
|
||||
user_id: str,
|
||||
min_stream_ordering: int,
|
||||
max_stream_ordering: int,
|
||||
limit: int = 20,
|
||||
) -> List[dict]:
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the httppusher.
|
||||
|
||||
Args:
|
||||
user_id (str): The user to fetch push actions for.
|
||||
min_stream_ordering(int): The exclusive lower bound on the
|
||||
user_id: The user to fetch push actions for.
|
||||
min_stream_ordering: The exclusive lower bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
max_stream_ordering(int): The inclusive upper bound on the
|
||||
max_stream_ordering: The inclusive upper bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
limit (int): The maximum number of rows to return.
|
||||
limit: The maximum number of rows to return.
|
||||
Returns:
|
||||
A promise which resolves to a list of dicts with the keys "event_id",
|
||||
"room_id", "stream_ordering", "actions".
|
||||
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
|
||||
The list will be ordered by ascending stream_ordering.
|
||||
The list will have between 0~limit entries.
|
||||
"""
|
||||
|
@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
after_read_receipt = yield self.db.runInteraction(
|
||||
after_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||
)
|
||||
|
||||
|
@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
no_read_receipt = yield self.db.runInteraction(
|
||||
no_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||
)
|
||||
|
||||
|
@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
# one of the subqueries may have hit the limit.
|
||||
return notifs[:limit]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range_for_email(
|
||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||
):
|
||||
async def get_unread_push_actions_for_user_in_range_for_email(
|
||||
self,
|
||||
user_id: str,
|
||||
min_stream_ordering: int,
|
||||
max_stream_ordering: int,
|
||||
limit: int = 20,
|
||||
) -> List[dict]:
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the emailpusher
|
||||
|
||||
Args:
|
||||
user_id (str): The user to fetch push actions for.
|
||||
min_stream_ordering(int): The exclusive lower bound on the
|
||||
user_id: The user to fetch push actions for.
|
||||
min_stream_ordering: The exclusive lower bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
max_stream_ordering(int): The inclusive upper bound on the
|
||||
max_stream_ordering: The inclusive upper bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
limit (int): The maximum number of rows to return.
|
||||
limit: The maximum number of rows to return.
|
||||
Returns:
|
||||
A promise which resolves to a list of dicts with the keys "event_id",
|
||||
"room_id", "stream_ordering", "actions", "received_ts".
|
||||
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
|
||||
The list will be ordered by descending received_ts.
|
||||
The list will have between 0~limit entries.
|
||||
"""
|
||||
|
@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
after_read_receipt = yield self.db.runInteraction(
|
||||
after_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||
)
|
||||
|
||||
|
@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
no_read_receipt = yield self.db.runInteraction(
|
||||
no_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||
)
|
||||
|
||||
|
@ -411,7 +415,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
_get_if_maybe_push_in_range_for_user_txn,
|
||||
)
|
||||
|
||||
def add_push_actions_to_staging(self, event_id, user_id_actions):
|
||||
async def add_push_actions_to_staging(self, event_id, user_id_actions):
|
||||
"""Add the push actions for the event to the push action staging area.
|
||||
|
||||
Args:
|
||||
|
@ -457,21 +461,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
),
|
||||
)
|
||||
|
||||
return self.db.runInteraction(
|
||||
return await self.db.runInteraction(
|
||||
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_push_actions_from_staging(self, event_id):
|
||||
async def remove_push_actions_from_staging(self, event_id: str) -> None:
|
||||
"""Called if we failed to persist the event to ensure that stale push
|
||||
actions don't build up in the DB
|
||||
|
||||
Args:
|
||||
event_id (str)
|
||||
"""
|
||||
|
||||
try:
|
||||
res = yield self.db.simple_delete(
|
||||
res = await self.db.simple_delete(
|
||||
table="event_push_actions_staging",
|
||||
keyvalues={"event_id": event_id},
|
||||
desc="remove_push_actions_from_staging",
|
||||
|
@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
return range_end
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
async def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.received_ts"
|
||||
|
@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (stream_ordering,))
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||
result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
|
@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
self._start_rotate_notifs, 30 * 60 * 1000
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_actions_for_user(
|
||||
async def get_push_actions_for_user(
|
||||
self, user_id, before=None, limit=50, only_highlight=False
|
||||
):
|
||||
def f(txn):
|
||||
|
@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
txn.execute(sql, args)
|
||||
return self.db.cursor_to_dict(txn)
|
||||
|
||||
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
|
||||
push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
|
||||
for pa in push_actions:
|
||||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
||||
return push_actions
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_latest_push_action_stream_ordering(self):
|
||||
async def get_latest_push_action_stream_ordering(self):
|
||||
def f(txn):
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.db.runInteraction(
|
||||
result = await self.db.runInteraction(
|
||||
"get_latest_push_action_stream_ordering", f
|
||||
)
|
||||
return result[0] or 0
|
||||
|
@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
def _start_rotate_notifs(self):
|
||||
return run_as_background_process("rotate_notifs", self._rotate_notifs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _rotate_notifs(self):
|
||||
async def _rotate_notifs(self):
|
||||
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
|
||||
return
|
||||
self._doing_notif_rotation = True
|
||||
|
@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
while True:
|
||||
logger.info("Rotating notifications")
|
||||
|
||||
caught_up = yield self.db.runInteraction(
|
||||
caught_up = await self.db.runInteraction(
|
||||
"_rotate_notifs", self._rotate_notifs_txn
|
||||
)
|
||||
if caught_up:
|
||||
break
|
||||
yield self.hs.get_clock().sleep(self._rotate_delay)
|
||||
await self.hs.get_clock().sleep(self._rotate_delay)
|
||||
finally:
|
||||
self._doing_notif_rotation = False
|
||||
|
||||
|
|
|
@ -53,6 +53,47 @@ event_counter = Counter(
|
|||
["type", "origin_type", "origin_entity"],
|
||||
)
|
||||
|
||||
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
|
||||
EventTypes.Topic,
|
||||
EventTypes.Name,
|
||||
EventTypes.RoomAvatar,
|
||||
EventTypes.Tombstone,
|
||||
}
|
||||
|
||||
|
||||
def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
|
||||
# Exclude rejected and soft-failed events.
|
||||
if context.rejected or event.internal_metadata.is_soft_failed():
|
||||
return False
|
||||
|
||||
# Exclude notices.
|
||||
if (
|
||||
not event.is_state()
|
||||
and event.type == EventTypes.Message
|
||||
and event.content.get("msgtype") == "m.notice"
|
||||
):
|
||||
return False
|
||||
|
||||
# Exclude edits.
|
||||
relates_to = event.content.get("m.relates_to", {})
|
||||
if relates_to.get("rel_type") == RelationTypes.REPLACE:
|
||||
return False
|
||||
|
||||
# Mark events that have a non-empty string body as unread.
|
||||
body = event.content.get("body")
|
||||
if isinstance(body, str) and body:
|
||||
return True
|
||||
|
||||
# Mark some state events as unread.
|
||||
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
|
||||
return True
|
||||
|
||||
# Mark encrypted events as unread.
|
||||
if not event.is_state() and event.type == EventTypes.Encrypted:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def encode_json(json_object):
|
||||
"""
|
||||
|
@ -196,6 +237,10 @@ class PersistEventsStore:
|
|||
|
||||
event_counter.labels(event.type, origin_type, origin_entity).inc()
|
||||
|
||||
self.store.get_unread_message_count_for_user.invalidate_many(
|
||||
(event.room_id,),
|
||||
)
|
||||
|
||||
for room_id, new_state in current_state_for_room.items():
|
||||
self.store.get_current_state_ids.prefill((room_id,), new_state)
|
||||
|
||||
|
@ -817,8 +862,9 @@ class PersistEventsStore:
|
|||
"contains_url": (
|
||||
"url" in event.content and isinstance(event.content["url"], str)
|
||||
),
|
||||
"count_as_unread": should_count_as_unread(event, context),
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
for event, context in events_and_contexts
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
|
|||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import (
|
||||
Cache,
|
||||
_CacheContext,
|
||||
cached,
|
||||
cachedInlineCallbacks,
|
||||
)
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
@cached(tree=True, cache_context=True)
|
||||
async def get_unread_message_count_for_user(
|
||||
self, room_id: str, user_id: str, cache_context: _CacheContext,
|
||||
) -> int:
|
||||
"""Retrieve the count of unread messages for the given room and user.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to count unread messages in.
|
||||
user_id: The ID of the user to count unread messages for.
|
||||
|
||||
Returns:
|
||||
The number of unread messages for the given user in the given room.
|
||||
"""
|
||||
with Measure(self._clock, "get_unread_message_count_for_user"):
|
||||
last_read_event_id = await self.get_last_receipt_event_id_for_user(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
receipt_type="m.read",
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_unread_message_count_for_user",
|
||||
self._get_unread_message_count_for_user_txn,
|
||||
user_id,
|
||||
room_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
|
||||
def _get_unread_message_count_for_user_txn(
|
||||
self,
|
||||
txn: Cursor,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> int:
|
||||
if last_read_event_id:
|
||||
# Get the stream ordering for the last read event.
|
||||
stream_ordering = self.db.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="events",
|
||||
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
|
||||
retcol="stream_ordering",
|
||||
)
|
||||
else:
|
||||
# If there's no read receipt for that room, it probably means the user hasn't
|
||||
# opened it yet, in which case use the stream ID of their join event.
|
||||
# We can't just set it to 0 otherwise messages from other local users from
|
||||
# before this user joined will be counted as well.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT stream_ordering FROM local_current_membership
|
||||
LEFT JOIN events USING (event_id, room_id)
|
||||
WHERE membership = 'join'
|
||||
AND user_id = ?
|
||||
AND room_id = ?
|
||||
""",
|
||||
(user_id, room_id),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return 0
|
||||
|
||||
stream_ordering = row[0]
|
||||
|
||||
# Count the messages that qualify as unread after the stream ordering we've just
|
||||
# retrieved.
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM events
|
||||
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
|
|
@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
# event_json
|
||||
# event_push_actions
|
||||
# event_reference_hashes
|
||||
# event_relations
|
||||
# event_search
|
||||
# event_to_state_groups
|
||||
# events
|
||||
|
@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
"event_edges",
|
||||
"event_forward_extremities",
|
||||
"event_reference_hashes",
|
||||
"event_relations",
|
||||
"event_search",
|
||||
"rejections",
|
||||
):
|
||||
|
|
|
@ -284,7 +284,7 @@ class PushRulesWorkerStore(
|
|||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
result = yield self._bulk_get_push_rules_for_room(
|
||||
event.room_id, state_group, current_state_ids, event=event
|
||||
)
|
||||
|
|
|
@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
|
@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
|
|||
from synapse.storage.data_stores.main.search import SearchStore
|
||||
from synapse.storage.database import Database, LoggingTransaction
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_largest_public_rooms(
|
||||
async def get_largest_public_rooms(
|
||||
self,
|
||||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
search_filter: Optional[dict],
|
||||
|
@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
ret_val = yield self.db.runInteraction(
|
||||
ret_val = await self.db.runInteraction(
|
||||
"get_largest_public_rooms", _get_largest_public_rooms_txn
|
||||
)
|
||||
defer.returnValue(ret_val)
|
||||
return ret_val
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def is_room_blocked(self, room_id):
|
||||
|
@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
"get_rooms_paginate", _get_rooms_paginate_txn,
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=10000)
|
||||
def get_ratelimit_for_user(self, user_id):
|
||||
@cached(max_entries=10000)
|
||||
async def get_ratelimit_for_user(self, user_id):
|
||||
"""Check if there are any overrides for ratelimiting for the given
|
||||
user
|
||||
|
||||
|
@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
of RatelimitOverride are None or 0 then ratelimitng has been
|
||||
disabled for that user entirely.
|
||||
"""
|
||||
row = yield self.db.simple_select_one(
|
||||
row = await self.db.simple_select_one(
|
||||
table="ratelimit_override",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("messages_per_second", "burst_count"),
|
||||
|
@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
else:
|
||||
return None
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_retention_policy_for_room(self, room_id):
|
||||
@cached()
|
||||
async def get_retention_policy_for_room(self, room_id):
|
||||
"""Get the retention policy for a given room.
|
||||
|
||||
If no retention policy has been found for this room, returns a policy defined
|
||||
|
@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return self.db.cursor_to_dict(txn)
|
||||
|
||||
ret = yield self.db.runInteraction(
|
||||
ret = await self.db.runInteraction(
|
||||
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
|
||||
)
|
||||
|
||||
# If we don't know this room ID, ret will be None, in this case return the default
|
||||
# policy.
|
||||
if not ret:
|
||||
defer.returnValue(
|
||||
{
|
||||
"min_lifetime": self.config.retention_default_min_lifetime,
|
||||
"max_lifetime": self.config.retention_default_max_lifetime,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"min_lifetime": self.config.retention_default_min_lifetime,
|
||||
"max_lifetime": self.config.retention_default_max_lifetime,
|
||||
}
|
||||
|
||||
row = ret[0]
|
||||
|
||||
|
@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
if row["max_lifetime"] is None:
|
||||
row["max_lifetime"] = self.config.retention_default_max_lifetime
|
||||
|
||||
defer.returnValue(row)
|
||||
return row
|
||||
|
||||
def get_media_mxcs_in_room(self, room_id):
|
||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||
|
@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_add_rooms_room_version_column,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_insert_retention(self, progress, batch_size):
|
||||
async def _background_insert_retention(self, progress, batch_size):
|
||||
"""Retrieves a list of all rooms within a range and inserts an entry for each of
|
||||
them into the room_retention table.
|
||||
NULLs the property's columns if missing from the retention event in the room's
|
||||
|
@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
else:
|
||||
return False
|
||||
|
||||
end = yield self.db.runInteraction(
|
||||
end = await self.db.runInteraction(
|
||||
"insert_room_retention", _background_insert_retention_txn,
|
||||
)
|
||||
|
||||
if end:
|
||||
yield self.db.updates._end_background_update("insert_room_retention")
|
||||
await self.db.updates._end_background_update("insert_room_retention")
|
||||
|
||||
defer.returnValue(batch_size)
|
||||
return batch_size
|
||||
|
||||
async def _background_add_rooms_room_version_column(
|
||||
self, progress: dict, batch_size: int
|
||||
|
@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room(
|
||||
async def store_room(
|
||||
self,
|
||||
room_id: str,
|
||||
room_creator_user_id: str,
|
||||
|
@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
|
||||
await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
|
||||
except Exception as e:
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public(self, room_id, is_public):
|
||||
async def set_room_is_public(self, room_id, is_public):
|
||||
def set_room_is_public_txn(txn, next_id):
|
||||
self.db.simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public_appservice(
|
||||
async def set_room_is_public_appservice(
|
||||
self, room_id, appservice_id, network_id, is_public
|
||||
):
|
||||
"""Edit the appservice/network specific public room list.
|
||||
|
@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
next_id,
|
||||
|
@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def block_room(self, room_id, user_id):
|
||||
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||
"""Marks the room as blocked. Can be called multiple times.
|
||||
|
||||
Args:
|
||||
room_id (str): Room to block
|
||||
user_id (str): Who blocked it
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
room_id: Room to block
|
||||
user_id: Who blocked it
|
||||
"""
|
||||
yield self.db.simple_upsert(
|
||||
await self.db.simple_upsert(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
values={},
|
||||
insertion_values={"user_id": user_id},
|
||||
desc="block_room",
|
||||
)
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_retention_period_in_range(
|
||||
self, min_ms, max_ms, include_null=False
|
||||
):
|
||||
async def get_rooms_for_retention_period_in_range(
|
||||
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
||||
) -> Dict[str, dict]:
|
||||
"""Retrieves all of the rooms within the given retention range.
|
||||
|
||||
Optionally includes the rooms which don't have a retention policy.
|
||||
|
||||
Args:
|
||||
min_ms (int|None): Duration in milliseconds that define the lower limit of
|
||||
min_ms: Duration in milliseconds that define the lower limit of
|
||||
the range to handle (exclusive). If None, doesn't set a lower limit.
|
||||
max_ms (int|None): Duration in milliseconds that define the upper limit of
|
||||
max_ms: Duration in milliseconds that define the upper limit of
|
||||
the range to handle (inclusive). If None, doesn't set an upper limit.
|
||||
include_null (bool): Whether to include rooms which retention policy is NULL
|
||||
include_null: Whether to include rooms which retention policy is NULL
|
||||
in the returned set.
|
||||
|
||||
Returns:
|
||||
dict[str, dict]: The rooms within this range, along with their retention
|
||||
policy. The key is "room_id", and maps to a dict describing the retention
|
||||
policy associated with this room ID. The keys for this nested dict are
|
||||
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
||||
The rooms within this range, along with their retention
|
||||
policy. The key is "room_id", and maps to a dict describing the retention
|
||||
policy associated with this room ID. The keys for this nested dict are
|
||||
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
||||
"""
|
||||
|
||||
def get_rooms_for_retention_period_in_range_txn(txn):
|
||||
|
@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
|
||||
return rooms_dict
|
||||
|
||||
rooms = yield self.db.runInteraction(
|
||||
rooms = await self.db.runInteraction(
|
||||
"get_rooms_for_retention_period_in_range",
|
||||
get_rooms_for_retention_period_in_range_txn,
|
||||
)
|
||||
|
||||
defer.returnValue(rooms)
|
||||
return rooms
|
||||
|
|
|
@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
result = yield self._get_joined_users_from_context(
|
||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||
)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Store a boolean value in the events table for whether the event should be counted in
|
||||
-- the unread_count property of sync responses.
|
||||
ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
|
|
@ -16,12 +16,12 @@
|
|||
import collections.abc
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Iterable, Optional, Set
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
|
||||
|
@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
create_event = await self.get_create_event_for_room(room_id)
|
||||
return create_event.content.get("room_version", "1")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_predecessor(self, room_id):
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
|
||||
"""Get the predecessor of an upgraded room if it exists.
|
||||
Otherwise return None.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID.
|
||||
|
||||
Returns:
|
||||
Deferred[dict|None]: A dictionary containing the structure of the predecessor
|
||||
field from the room's create event. The structure is subject to other servers,
|
||||
but it is expected to be:
|
||||
* room_id (str): The room ID of the predecessor room
|
||||
* event_id (str): The ID of the tombstone event in the predecessor room
|
||||
A dictionary containing the structure of the predecessor
|
||||
field from the room's create event. The structure is subject to other servers,
|
||||
but it is expected to be:
|
||||
* room_id (str): The room ID of the predecessor room
|
||||
* event_id (str): The ID of the tombstone event in the predecessor room
|
||||
|
||||
None if a predecessor key is not found, or is not a dictionary.
|
||||
None if a predecessor key is not found, or is not a dictionary.
|
||||
|
||||
Raises:
|
||||
NotFoundError if the given room is unknown
|
||||
"""
|
||||
# Retrieve the room's create event
|
||||
create_event = yield self.get_create_event_for_room(room_id)
|
||||
create_event = await self.get_create_event_for_room(room_id)
|
||||
|
||||
# Retrieve the predecessor key of the create event
|
||||
predecessor = create_event.content.get("predecessor", None)
|
||||
|
@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return predecessor
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_create_event_for_room(self, room_id):
|
||||
async def get_create_event_for_room(self, room_id: str) -> EventBase:
|
||||
"""Get the create state event for a room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID.
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase]: The room creation event.
|
||||
The room creation event.
|
||||
|
||||
Raises:
|
||||
NotFoundError if the room is unknown
|
||||
"""
|
||||
state_ids = yield self.get_current_state_ids(room_id)
|
||||
state_ids = await self.get_current_state_ids(room_id)
|
||||
create_id = state_ids.get((EventTypes.Create, ""))
|
||||
|
||||
# If we can't find the create event, assume we've hit a dead end
|
||||
|
@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
raise NotFoundError("Unknown room %s" % (room_id,))
|
||||
|
||||
# Retrieve the room's create event and return
|
||||
create_event = yield self.get_event(create_id)
|
||||
create_event = await self.get_event(create_id)
|
||||
return create_event
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
|
@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_canonical_alias_for_room(self, room_id):
|
||||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||
"""Get canonical alias for room, if any
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID
|
||||
|
||||
Returns:
|
||||
Deferred[str|None]: The canonical alias, if any
|
||||
The canonical alias, if any
|
||||
"""
|
||||
|
||||
state = yield self.get_filtered_current_state_ids(
|
||||
state = await self.get_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||
)
|
||||
|
||||
|
@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
if not event_id:
|
||||
return
|
||||
|
||||
event = yield self.get_event(event_id, allow_none=True)
|
||||
event = await self.get_event(event_id, allow_none=True)
|
||||
if not event:
|
||||
return
|
||||
|
||||
|
@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return {row["event_id"]: row["state_group"] for row in rows}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_referenced_state_groups(self, state_groups):
|
||||
async def get_referenced_state_groups(
|
||||
self, state_groups: Iterable[int]
|
||||
) -> Set[int]:
|
||||
"""Check if the state groups are referenced by events.
|
||||
|
||||
Args:
|
||||
state_groups (Iterable[int])
|
||||
state_groups
|
||||
|
||||
Returns:
|
||||
Deferred[set[int]]: The subset of state groups that are
|
||||
referenced.
|
||||
The subset of state groups that are referenced.
|
||||
"""
|
||||
|
||||
rows = yield self.db.simple_select_many_batch(
|
||||
rows = await self.db.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="state_group",
|
||||
iterable=state_groups,
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
import logging
|
||||
from itertools import chain
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import DeferredLock
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
|
|||
"""
|
||||
return (ts // self.stats_bucket_size) * self.stats_bucket_size
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_stats_process_users(self, progress, batch_size):
|
||||
async def _populate_stats_process_users(self, progress, batch_size):
|
||||
"""
|
||||
This is a background update which regenerates statistics for users.
|
||||
"""
|
||||
if not self.stats_enabled:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_users")
|
||||
await self.db.updates._end_background_update("populate_stats_process_users")
|
||||
return 1
|
||||
|
||||
last_user_id = progress.get("last_user_id", "")
|
||||
|
@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
|
|||
txn.execute(sql, (last_user_id, batch_size))
|
||||
return [r for r, in txn]
|
||||
|
||||
users_to_work_on = yield self.db.runInteraction(
|
||||
users_to_work_on = await self.db.runInteraction(
|
||||
"_populate_stats_process_users", _get_next_batch
|
||||
)
|
||||
|
||||
# No more rooms -- complete the transaction.
|
||||
if not users_to_work_on:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_users")
|
||||
await self.db.updates._end_background_update("populate_stats_process_users")
|
||||
return 1
|
||||
|
||||
for user_id in users_to_work_on:
|
||||
yield self._calculate_and_set_initial_state_for_user(user_id)
|
||||
await self._calculate_and_set_initial_state_for_user(user_id)
|
||||
progress["last_user_id"] = user_id
|
||||
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"populate_stats_process_users",
|
||||
self.db.updates._background_update_progress_txn,
|
||||
"populate_stats_process_users",
|
||||
|
@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
|
|||
|
||||
return len(users_to_work_on)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_stats_process_rooms(self, progress, batch_size):
|
||||
async def _populate_stats_process_rooms(self, progress, batch_size):
|
||||
"""
|
||||
This is a background update which regenerates statistics for rooms.
|
||||
"""
|
||||
if not self.stats_enabled:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
await self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
return 1
|
||||
|
||||
last_room_id = progress.get("last_room_id", "")
|
||||
|
@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
|
|||
txn.execute(sql, (last_room_id, batch_size))
|
||||
return [r for r, in txn]
|
||||
|
||||
rooms_to_work_on = yield self.db.runInteraction(
|
||||
rooms_to_work_on = await self.db.runInteraction(
|
||||
"populate_stats_rooms_get_batch", _get_next_batch
|
||||
)
|
||||
|
||||
# No more rooms -- complete the transaction.
|
||||
if not rooms_to_work_on:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
await self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
return 1
|
||||
|
||||
for room_id in rooms_to_work_on:
|
||||
yield self._calculate_and_set_initial_state_for_room(room_id)
|
||||
await self._calculate_and_set_initial_state_for_room(room_id)
|
||||
progress["last_room_id"] = room_id
|
||||
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"_populate_stats_process_rooms",
|
||||
self.db.updates._background_update_progress_txn,
|
||||
"populate_stats_process_rooms",
|
||||
|
@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
|
|||
|
||||
return room_deltas, user_deltas
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_and_set_initial_state_for_room(self, room_id):
|
||||
async def _calculate_and_set_initial_state_for_room(
|
||||
self, room_id: str
|
||||
) -> Tuple[dict, dict, int]:
|
||||
"""Calculate and insert an entry into room_stats_current.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID under calculation.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
|
||||
counts and stream position.
|
||||
A tuple of room state, membership counts and stream position.
|
||||
"""
|
||||
|
||||
def _fetch_current_state_stats(txn):
|
||||
|
@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
|
|||
current_state_events_count,
|
||||
users_in_room,
|
||||
pos,
|
||||
) = yield self.db.runInteraction(
|
||||
) = await self.db.runInteraction(
|
||||
"get_initial_state_for_room", _fetch_current_state_stats
|
||||
)
|
||||
|
||||
state_event_map = yield self.get_events(event_ids, get_prev_content=False)
|
||||
state_event_map = await self.get_events(event_ids, get_prev_content=False)
|
||||
|
||||
room_state = {
|
||||
"join_rules": None,
|
||||
|
@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
|
|||
event.content.get("m.federate", True) is True
|
||||
)
|
||||
|
||||
yield self.update_room_state(room_id, room_state)
|
||||
await self.update_room_state(room_id, room_state)
|
||||
|
||||
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
|
||||
|
||||
yield self.update_stats_delta(
|
||||
await self.update_stats_delta(
|
||||
ts=self.clock.time_msec(),
|
||||
stats_type="room",
|
||||
stats_id=room_id,
|
||||
|
@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
|
|||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_and_set_initial_state_for_user(self, user_id):
|
||||
async def _calculate_and_set_initial_state_for_user(self, user_id):
|
||||
def _calculate_and_set_initial_state_for_user_txn(txn):
|
||||
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
|
||||
|
||||
|
@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
|
|||
(count,) = txn.fetchone()
|
||||
return count, pos
|
||||
|
||||
joined_rooms, pos = yield self.db.runInteraction(
|
||||
joined_rooms, pos = await self.db.runInteraction(
|
||||
"calculate_and_set_initial_state_for_user",
|
||||
_calculate_and_set_initial_state_for_user_txn,
|
||||
)
|
||||
|
||||
yield self.update_stats_delta(
|
||||
await self.update_stats_delta(
|
||||
ts=self.clock.time_msec(),
|
||||
stats_type="user",
|
||||
stats_id=user_id,
|
||||
|
|
|
@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._send_federation = hs.should_send_federation()
|
||||
self._federation_shard_config = hs.config.federation.federation_shard_config
|
||||
self._federation_shard_config = hs.config.worker.federation_shard_config
|
||||
|
||||
# If we're a process that sends federation we may need to reset the
|
||||
# `federation_stream_position` table to match the current sharding
|
||||
|
|
|
@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
room_id
|
||||
)
|
||||
|
||||
users_with_profile = yield state.get_current_users_in_room(room_id)
|
||||
users_with_profile = yield defer.ensureDeferred(
|
||||
state.get_current_users_in_room(room_id)
|
||||
)
|
||||
user_ids = set(users_with_profile)
|
||||
|
||||
# Update each user in the user directory.
|
||||
|
|
|
@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class UserErasureStore(UserErasureWorkerStore):
|
||||
def mark_user_erased(self, user_id):
|
||||
def mark_user_erased(self, user_id: str) -> None:
|
||||
"""Indicate that user_id wishes their message history to be erased.
|
||||
|
||||
Args:
|
||||
user_id (str): full user_id to be erased
|
||||
user_id: full user_id to be erased
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
|
@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
return self.db.runInteraction("mark_user_erased", f)
|
||||
|
||||
def mark_user_not_erased(self, user_id: str) -> None:
|
||||
"""Indicate that user_id is no longer erased.
|
||||
|
||||
Args:
|
||||
user_id: full user_id to be un-erased
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
# first check if they are already in the list
|
||||
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
|
||||
if not txn.fetchone():
|
||||
return
|
||||
|
||||
# They are there, delete them.
|
||||
self.simple_delete_one_txn(
|
||||
txn, "erased_users", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
return self.db.runInteraction("mark_user_not_erased", f)
|
||||
|
|
|
@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
"get_state_group_delta", _get_state_group_delta_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_groups_from_groups(
|
||||
async def _get_state_groups_from_groups(
|
||||
self, groups: List[int], state_filter: StateFilter
|
||||
):
|
||||
) -> Dict[int, StateMap[str]]:
|
||||
"""Returns the state groups for a given set of groups from the
|
||||
database, filtering on types of state events.
|
||||
|
||||
|
@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
results = {}
|
||||
|
||||
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
|
||||
for chunk in chunks:
|
||||
res = yield self.db.runInteraction(
|
||||
res = await self.db.runInteraction(
|
||||
"_get_state_groups_from_groups",
|
||||
self._get_state_groups_from_groups_txn,
|
||||
chunk,
|
||||
|
@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
|
||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_for_groups(
|
||||
async def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
) -> Dict[int, StateMap[str]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
|
@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
|
||||
member_filter, non_member_filter = state_filter.get_member_split()
|
||||
|
@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
(
|
||||
non_member_state,
|
||||
incomplete_groups_nm,
|
||||
) = yield self._get_state_for_groups_using_cache(
|
||||
) = self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_cache, state_filter=non_member_filter
|
||||
)
|
||||
|
||||
(
|
||||
member_state,
|
||||
incomplete_groups_m,
|
||||
) = yield self._get_state_for_groups_using_cache(
|
||||
(member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_members_cache, state_filter=member_filter
|
||||
)
|
||||
|
||||
|
@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
# Help the cache hit ratio by expanding the filter a bit
|
||||
db_state_filter = state_filter.return_expanded()
|
||||
|
||||
group_to_state_dict = yield self._get_state_groups_from_groups(
|
||||
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||
list(incomplete_groups), state_filter=db_state_filter
|
||||
)
|
||||
|
||||
|
@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
((sg,) for sg in state_groups_to_delete),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_previous_state_groups(self, state_groups):
|
||||
async def get_previous_state_groups(
|
||||
self, state_groups: Iterable[int]
|
||||
) -> Dict[int, int]:
|
||||
"""Fetch the previous groups of the given state groups.
|
||||
|
||||
Args:
|
||||
state_groups (Iterable[int])
|
||||
state_groups
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, int]]: mapping from state group to previous
|
||||
state group.
|
||||
A mapping from state group to previous state group.
|
||||
"""
|
||||
|
||||
rows = yield self.db.simple_select_many_batch(
|
||||
rows = await self.db.simple_select_many_batch(
|
||||
table="state_group_edges",
|
||||
column="prev_state_group",
|
||||
iterable=state_groups,
|
||||
|
|
|
@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
|
|||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.types import Collection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
perf_logger = logging.getLogger("synapse.storage.TIME")
|
||||
|
@ -233,7 +233,7 @@ class LoggingTransaction:
|
|||
try:
|
||||
return func(sql, *args)
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
raise
|
||||
finally:
|
||||
secs = time.time() - start
|
||||
|
@ -419,7 +419,7 @@ class Database(object):
|
|||
except self.engine.module.OperationalError as e:
|
||||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
logger.warning(
|
||||
transaction_logger.warning(
|
||||
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
|
||||
)
|
||||
if i < N:
|
||||
|
@ -427,18 +427,20 @@ class Database(object):
|
|||
try:
|
||||
conn.rollback()
|
||||
except self.engine.module.Error as e1:
|
||||
logger.warning("[TXN EROLL] {%s} %s", name, e1)
|
||||
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
|
||||
continue
|
||||
raise
|
||||
except self.engine.module.DatabaseError as e:
|
||||
if self.engine.is_deadlock(e):
|
||||
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
||||
transaction_logger.warning(
|
||||
"[TXN DEADLOCK] {%s} %d/%d", name, i, N
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.engine.module.Error as e1:
|
||||
logger.warning(
|
||||
transaction_logger.warning(
|
||||
"[TXN EROLL] {%s} %s", name, e1,
|
||||
)
|
||||
continue
|
||||
|
@ -478,7 +480,7 @@ class Database(object):
|
|||
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
raise
|
||||
finally:
|
||||
end = monotonic_time()
|
||||
|
|
|
@ -25,11 +25,10 @@ from prometheus_client import Counter, Histogram
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.state import StateResolutionStore
|
||||
from synapse.storage.data_stores import DataStores
|
||||
from synapse.storage.data_stores.main.events import DeltaState
|
||||
from synapse.types import StateMap
|
||||
|
@ -193,12 +192,11 @@ class EventsPersistenceStorage(object):
|
|||
self._event_persist_queue = _EventPeristenceQueue()
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_events(
|
||||
async def persist_events(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
):
|
||||
) -> int:
|
||||
"""
|
||||
Write events to the database
|
||||
Args:
|
||||
|
@ -208,7 +206,7 @@ class EventsPersistenceStorage(object):
|
|||
which might update the current state etc.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: the stream ordering of the latest persisted event
|
||||
the stream ordering of the latest persisted event
|
||||
"""
|
||||
partitioned = {}
|
||||
for event, ctx in events_and_contexts:
|
||||
|
@ -224,22 +222,19 @@ class EventsPersistenceStorage(object):
|
|||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
max_persisted_id = yield self.main_store.get_current_events_token()
|
||||
return self.main_store.get_current_events_token()
|
||||
|
||||
return max_persisted_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_event(
|
||||
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
|
||||
):
|
||||
async def persist_event(
|
||||
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns:
|
||||
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
|
||||
and the stream ordering of the latest persisted event
|
||||
The stream ordering of `event`, and the stream ordering of the
|
||||
latest persisted event
|
||||
"""
|
||||
deferred = self._event_persist_queue.add_to_queue(
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
|
@ -247,9 +242,9 @@ class EventsPersistenceStorage(object):
|
|||
|
||||
self._maybe_start_persisting(event.room_id)
|
||||
|
||||
yield make_deferred_yieldable(deferred)
|
||||
await make_deferred_yieldable(deferred)
|
||||
|
||||
max_persisted_id = yield self.main_store.get_current_events_token()
|
||||
max_persisted_id = self.main_store.get_current_events_token()
|
||||
return (event.internal_metadata.stream_ordering, max_persisted_id)
|
||||
|
||||
def _maybe_start_persisting(self, room_id: str):
|
||||
|
@ -263,7 +258,7 @@ class EventsPersistenceStorage(object):
|
|||
|
||||
async def _persist_events(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
):
|
||||
"""Calculates the change to current state and forward extremities, and
|
||||
|
@ -440,7 +435,7 @@ class EventsPersistenceStorage(object):
|
|||
async def _calculate_new_extremities(
|
||||
self,
|
||||
room_id: str,
|
||||
event_contexts: List[Tuple[FrozenEvent, EventContext]],
|
||||
event_contexts: List[Tuple[EventBase, EventContext]],
|
||||
latest_event_ids: List[str],
|
||||
):
|
||||
"""Calculates the new forward extremities for a room given events to
|
||||
|
@ -498,7 +493,7 @@ class EventsPersistenceStorage(object):
|
|||
async def _get_new_state_after_events(
|
||||
self,
|
||||
room_id: str,
|
||||
events_context: List[Tuple[FrozenEvent, EventContext]],
|
||||
events_context: List[Tuple[EventBase, EventContext]],
|
||||
old_latest_event_ids: Iterable[str],
|
||||
new_latest_event_ids: Iterable[str],
|
||||
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
|
||||
|
@ -648,6 +643,10 @@ class EventsPersistenceStorage(object):
|
|||
room_version = await self.main_store.get_room_version_id(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from preserve_events")
|
||||
|
||||
# Avoid a circular import.
|
||||
from synapse.state import StateResolutionStore
|
||||
|
||||
res = await self._state_resolution_handler.resolve_state_groups(
|
||||
room_id,
|
||||
room_version,
|
||||
|
@ -680,7 +679,7 @@ class EventsPersistenceStorage(object):
|
|||
async def _is_server_still_joined(
|
||||
self,
|
||||
room_id: str,
|
||||
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
|
||||
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
|
||||
delta: DeltaState,
|
||||
current_state: Optional[StateMap[str]],
|
||||
potentially_left_users: Set[str],
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
|
|||
def __init__(self, hs, stores):
|
||||
self.stores = stores
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_room(self, room_id: str):
|
||||
async def purge_room(self, room_id: str):
|
||||
"""Deletes all record of a room
|
||||
"""
|
||||
|
||||
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
|
||||
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
||||
state_groups_to_delete = await self.stores.main.purge_room(room_id)
|
||||
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, token, delete_local_events):
|
||||
async def purge_history(
|
||||
self, room_id: str, token: str, delete_local_events: bool
|
||||
) -> None:
|
||||
"""Deletes room history before a certain point
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
room_id: The room ID
|
||||
|
||||
token (str): A topological token to delete events before
|
||||
token: A topological token to delete events before
|
||||
|
||||
delete_local_events (bool):
|
||||
delete_local_events:
|
||||
if True, we will delete local events as well as remote ones
|
||||
(instead of just marking them as outliers and deleting their
|
||||
state groups).
|
||||
"""
|
||||
state_groups = yield self.stores.main.purge_history(
|
||||
state_groups = await self.stores.main.purge_history(
|
||||
room_id, token, delete_local_events
|
||||
)
|
||||
|
||||
logger.info("[purge] finding state groups that can be deleted")
|
||||
|
||||
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
|
||||
sg_to_delete = await self._find_unreferenced_groups(state_groups)
|
||||
|
||||
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
|
||||
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_unreferenced_groups(self, state_groups):
|
||||
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
|
||||
"""Used when purging history to figure out which state groups can be
|
||||
deleted.
|
||||
|
||||
Args:
|
||||
state_groups (set[int]): Set of state groups referenced by events
|
||||
state_groups: Set of state groups referenced by events
|
||||
that are going to be deleted.
|
||||
|
||||
Returns:
|
||||
Deferred[set[int]] The set of state groups that can be deleted.
|
||||
The set of state groups that can be deleted.
|
||||
"""
|
||||
# Graph of state group -> previous group
|
||||
graph = {}
|
||||
|
@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
|
|||
current_search = set(itertools.islice(next_to_search, 100))
|
||||
next_to_search -= current_search
|
||||
|
||||
referenced = yield self.stores.main.get_referenced_state_groups(
|
||||
referenced = await self.stores.main.get_referenced_state_groups(
|
||||
current_search
|
||||
)
|
||||
referenced_groups |= referenced
|
||||
|
@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
|
|||
# groups that are referenced.
|
||||
current_search -= referenced
|
||||
|
||||
edges = yield self.stores.state.get_previous_state_groups(current_search)
|
||||
edges = await self.stores.state.get_previous_state_groups(current_search)
|
||||
|
||||
prevs = set(edges.values())
|
||||
# We don't bother re-handling groups we've already seen
|
||||
|
|
|
@ -14,13 +14,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, TypeVar
|
||||
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -34,16 +33,16 @@ class StateFilter(object):
|
|||
"""A filter used when querying for state.
|
||||
|
||||
Attributes:
|
||||
types (dict[str, set[str]|None]): Map from type to set of state keys (or
|
||||
None). This specifies which state_keys for the given type to fetch
|
||||
from the DB. If None then all events with that type are fetched. If
|
||||
the set is empty then no events with that type are fetched.
|
||||
include_others (bool): Whether to fetch events with types that do not
|
||||
types: Map from type to set of state keys (or None). This specifies
|
||||
which state_keys for the given type to fetch from the DB. If None
|
||||
then all events with that type are fetched. If the set is empty
|
||||
then no events with that type are fetched.
|
||||
include_others: Whether to fetch events with types that do not
|
||||
appear in `types`.
|
||||
"""
|
||||
|
||||
types = attr.ib()
|
||||
include_others = attr.ib(default=False)
|
||||
types = attr.ib(type=Dict[str, Optional[Set[str]]])
|
||||
include_others = attr.ib(default=False, type=bool)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# If `include_others` is set we canonicalise the filter by removing
|
||||
|
@ -52,36 +51,35 @@ class StateFilter(object):
|
|||
self.types = {k: v for k, v in self.types.items() if v is not None}
|
||||
|
||||
@staticmethod
|
||||
def all():
|
||||
def all() -> "StateFilter":
|
||||
"""Creates a filter that fetches everything.
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
return StateFilter(types={}, include_others=True)
|
||||
|
||||
@staticmethod
|
||||
def none():
|
||||
def none() -> "StateFilter":
|
||||
"""Creates a filter that fetches nothing.
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
return StateFilter(types={}, include_others=False)
|
||||
|
||||
@staticmethod
|
||||
def from_types(types):
|
||||
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
|
||||
"""Creates a filter that only fetches the given types
|
||||
|
||||
Args:
|
||||
types (Iterable[tuple[str, str|None]]): A list of type and state
|
||||
keys to fetch. A state_key of None fetches everything for
|
||||
that type
|
||||
types: A list of type and state keys to fetch. A state_key of None
|
||||
fetches everything for that type
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
type_dict = {}
|
||||
type_dict = {} # type: Dict[str, Optional[Set[str]]]
|
||||
for typ, s in types:
|
||||
if typ in type_dict:
|
||||
if type_dict[typ] is None:
|
||||
|
@ -91,24 +89,24 @@ class StateFilter(object):
|
|||
type_dict[typ] = None
|
||||
continue
|
||||
|
||||
type_dict.setdefault(typ, set()).add(s)
|
||||
type_dict.setdefault(typ, set()).add(s) # type: ignore
|
||||
|
||||
return StateFilter(types=type_dict)
|
||||
|
||||
@staticmethod
|
||||
def from_lazy_load_member_list(members):
|
||||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
||||
"""Creates a filter that returns all non-member events, plus the member
|
||||
events for the given users
|
||||
|
||||
Args:
|
||||
members (iterable[str]): Set of user IDs
|
||||
members: Set of user IDs
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter
|
||||
"""
|
||||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
|
||||
|
||||
def return_expanded(self):
|
||||
def return_expanded(self) -> "StateFilter":
|
||||
"""Creates a new StateFilter where type wild cards have been removed
|
||||
(except for memberships). The returned filter is a superset of the
|
||||
current one, i.e. anything that passes the current filter will pass
|
||||
|
@ -130,7 +128,7 @@ class StateFilter(object):
|
|||
return all non-member events
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
|
||||
if self.is_full():
|
||||
|
@ -167,7 +165,7 @@ class StateFilter(object):
|
|||
include_others=True,
|
||||
)
|
||||
|
||||
def make_sql_filter_clause(self):
|
||||
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
|
||||
"""Converts the filter to an SQL clause.
|
||||
|
||||
For example:
|
||||
|
@ -179,13 +177,12 @@ class StateFilter(object):
|
|||
|
||||
|
||||
Returns:
|
||||
tuple[str, list]: The SQL string (may be empty) and arguments. An
|
||||
empty SQL string is returned when the filter matches everything
|
||||
(i.e. is "full").
|
||||
The SQL string (may be empty) and arguments. An empty SQL string is
|
||||
returned when the filter matches everything (i.e. is "full").
|
||||
"""
|
||||
|
||||
where_clause = ""
|
||||
where_args = []
|
||||
where_args = [] # type: List[str]
|
||||
|
||||
if self.is_full():
|
||||
return where_clause, where_args
|
||||
|
@ -221,7 +218,7 @@ class StateFilter(object):
|
|||
|
||||
return where_clause, where_args
|
||||
|
||||
def max_entries_returned(self):
|
||||
def max_entries_returned(self) -> Optional[int]:
|
||||
"""Returns the maximum number of entries this filter will return if
|
||||
known, otherwise returns None.
|
||||
|
||||
|
@ -260,33 +257,33 @@ class StateFilter(object):
|
|||
|
||||
return filtered_state
|
||||
|
||||
def is_full(self):
|
||||
def is_full(self) -> bool:
|
||||
"""Whether this filter fetches everything or not
|
||||
|
||||
Returns:
|
||||
bool
|
||||
True if the filter fetches everything.
|
||||
"""
|
||||
return self.include_others and not self.types
|
||||
|
||||
def has_wildcards(self):
|
||||
def has_wildcards(self) -> bool:
|
||||
"""Whether the filter includes wildcards or is attempting to fetch
|
||||
specific state.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
True if the filter includes wildcards.
|
||||
"""
|
||||
|
||||
return self.include_others or any(
|
||||
state_keys is None for state_keys in self.types.values()
|
||||
)
|
||||
|
||||
def concrete_types(self):
|
||||
def concrete_types(self) -> List[Tuple[str, str]]:
|
||||
"""Returns a list of concrete type/state_keys (i.e. not None) that
|
||||
will be fetched. This will be a complete list if `has_wildcards`
|
||||
returns False, but otherwise will be a subset (or even empty).
|
||||
|
||||
Returns:
|
||||
list[tuple[str,str]]
|
||||
A list of type/state_keys tuples.
|
||||
"""
|
||||
return [
|
||||
(t, s)
|
||||
|
@ -295,7 +292,7 @@ class StateFilter(object):
|
|||
for s in state_keys
|
||||
]
|
||||
|
||||
def get_member_split(self):
|
||||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
|
||||
"""Return the filter split into two: one which assumes it's exclusively
|
||||
matching against member state, and one which assumes it's matching
|
||||
against non member state.
|
||||
|
@ -307,7 +304,7 @@ class StateFilter(object):
|
|||
state caches).
|
||||
|
||||
Returns:
|
||||
tuple[StateFilter, StateFilter]: The member and non member filters
|
||||
The member and non member filters
|
||||
"""
|
||||
|
||||
if EventTypes.Member in self.types:
|
||||
|
@ -340,6 +337,9 @@ class StateGroupStorage(object):
|
|||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
Args:
|
||||
state_group: The state group used to retrieve state deltas.
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
|
||||
(prev_group, delta_ids)
|
||||
|
@ -347,55 +347,59 @@ class StateGroupStorage(object):
|
|||
|
||||
return self.stores.state.get_state_group_delta(state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups_ids(self, _room_id, event_ids):
|
||||
async def get_state_groups_ids(
|
||||
self, _room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[int, StateMap[str]]:
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
Args:
|
||||
_room_id (str): id of the room for these events
|
||||
event_ids (iterable[str]): ids of the events
|
||||
_room_id: id of the room for these events
|
||||
event_ids: ids of the events
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, StateMap[str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self.stores.state._get_state_for_groups(groups)
|
||||
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||
|
||||
return group_to_state
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_group(self, state_group):
|
||||
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group (int)
|
||||
state_group: A state group for which we want to get the state IDs.
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
|
||||
Resolves to a map of (type, state_key) -> event_id
|
||||
"""
|
||||
group_to_state = yield self._get_state_for_groups((state_group,))
|
||||
group_to_state = await self._get_state_for_groups((state_group,))
|
||||
|
||||
return group_to_state[state_group]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
async def get_state_groups(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[int, List[EventBase]]:
|
||||
""" Get the state groups for the given list of event_ids
|
||||
|
||||
Args:
|
||||
room_id: ID of the room for these events.
|
||||
event_ids: The event IDs to retrieve state for.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, list[EventBase]]]:
|
||||
dict of state_group_id -> list of state events.
|
||||
dict of state_group_id -> list of state events.
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
||||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
||||
|
||||
state_event_map = yield self.stores.main.get_events(
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[
|
||||
ev_id
|
||||
for group_ids in group_to_ids.values()
|
||||
|
@ -415,7 +419,7 @@ class StateGroupStorage(object):
|
|||
|
||||
def _get_state_groups_from_groups(
|
||||
self, groups: List[int], state_filter: StateFilter
|
||||
):
|
||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||
"""Returns the state groups for a given set of groups, filtering on
|
||||
types of state events.
|
||||
|
||||
|
@ -423,31 +427,34 @@ class StateGroupStorage(object):
|
|||
groups: list of state group IDs to query
|
||||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
|
||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
async def get_state_for_events(
|
||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
|
||||
Args:
|
||||
event_ids (list[string])
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_ids: The events to fetch the state of.
|
||||
state_filter: The state filter used to fetch state.
|
||||
|
||||
Returns:
|
||||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
"""
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self.stores.state._get_state_for_groups(
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter
|
||||
)
|
||||
|
||||
state_event_map = yield self.stores.main.get_events(
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
@ -463,24 +470,24 @@ class StateGroupStorage(object):
|
|||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
async def get_state_ids_for_events(
|
||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
of the state events (as opposed to the events themselves)
|
||||
|
||||
Args:
|
||||
event_ids(list(str)): events whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_ids: events whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from event_id -> (type, state_key) -> event_id
|
||||
A dict from event_id -> (type, state_key) -> event_id
|
||||
"""
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self.stores.state._get_state_for_groups(
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter
|
||||
)
|
||||
|
||||
|
@ -491,67 +498,72 @@ class StateGroupStorage(object):
|
|||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||
async def get_state_for_event(
|
||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
A dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_for_events([event_id], state_filter)
|
||||
state_map = await self.get_state_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||
async def get_state_ids_for_event(
|
||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
|
||||
state_map = await self.get_state_ids_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
||||
def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
Args:
|
||||
groups (iterable[int]): list of state groups for which we want
|
||||
to get the state.
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
groups: list of state groups for which we want to get the state.
|
||||
state_filter: The state filter used to fetch state.
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
||||
|
||||
def store_state_group(
|
||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
prev_group: Optional[int],
|
||||
delta_ids: Optional[dict],
|
||||
current_state_ids: dict,
|
||||
):
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
event_id (str): The event ID for which the state was calculated
|
||||
room_id (str)
|
||||
prev_group (int|None): A previous state group for the room, optional.
|
||||
delta_ids (dict|None): The delta between state at `prev_group` and
|
||||
event_id: The event ID for which the state was calculated.
|
||||
room_id: ID of the room for which the state was calculated.
|
||||
prev_group: A previous state group for the room, optional.
|
||||
delta_ids: The delta between state at `prev_group` and
|
||||
`current_state_ids`, if `prev_group` was given. Same format as
|
||||
`current_state_ids`.
|
||||
current_state_ids (dict): The state to store. Map of (type, state_key)
|
||||
current_state_ids: The state to store. Map of (type, state_key)
|
||||
to event_id.
|
||||
|
||||
Returns:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue