Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/new_push_rules

Brendan Abolivier 2020-07-30 19:00:29 +01:00
commit 69158e554f
230 changed files with 5343 additions and 4040 deletions


@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:


@@ -15,11 +15,10 @@
# limitations under the License.
import logging
from typing import List
from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database
@@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return {"notify_count": notify_count, "highlight_count": highlight_count}
@defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
):
def f(txn):
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
@@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
ret = await self.db.runInteraction("get_push_action_users_in_range", f)
return ret
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_http(
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
):
async def get_unread_push_actions_for_user_in_range_for_http(
self,
user_id: str,
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
Args:
user_id (str): The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the
user_id: The user to fetch push actions for.
min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the
max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return.
limit: The maximum number of rows to return.
Returns:
A promise which resolves to a list of dicts with the keys "event_id",
"room_id", "stream_ordering", "actions".
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
@@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.db.runInteraction(
after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.db.runInteraction(
no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# one of the subqueries may have hit the limit.
return notifs[:limit]
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
):
async def get_unread_push_actions_for_user_in_range_for_email(
self,
user_id: str,
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
Args:
user_id (str): The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the
user_id: The user to fetch push actions for.
min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the
max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return.
limit: The maximum number of rows to return.
Returns:
A promise which resolves to a list of dicts with the keys "event_id",
"room_id", "stream_ordering", "actions", "received_ts".
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
@@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.db.runInteraction(
after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.db.runInteraction(
no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -411,7 +415,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn,
)
def add_push_actions_to_staging(self, event_id, user_id_actions):
async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
Args:
@@ -457,21 +461,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
return self.db.runInteraction(
return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@defer.inlineCallbacks
def remove_push_actions_from_staging(self, event_id):
async def remove_push_actions_from_staging(self, event_id: str) -> None:
"""Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB
Args:
event_id (str)
"""
try:
res = yield self.db.simple_delete(
res = await self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
async def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
@@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000
)
@defer.inlineCallbacks
def get_push_actions_for_user(
async def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
):
def f(txn):
@@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(sql, args)
return self.db.cursor_to_dict(txn)
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
async def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
result = yield self.db.runInteraction(
result = await self.db.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
@@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs)
@defer.inlineCallbacks
def _rotate_notifs(self):
async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
@@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
caught_up = yield self.db.runInteraction(
caught_up = await self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
break
yield self.hs.get_clock().sleep(self._rotate_delay)
await self.hs.get_clock().sleep(self._rotate_delay)
finally:
self._doing_notif_rotation = False
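
The changes in this file all follow one mechanical pattern: Twisted's @defer.inlineCallbacks generators become native coroutines, and each yield on a Deferred becomes an await. A minimal before/after sketch of that conversion (illustrative only; self.db and _txn are hypothetical stand-ins, not code from this commit):

from twisted.internet import defer

# Before: a generator driven by the inlineCallbacks decorator.
@defer.inlineCallbacks
def get_value(self):
    result = yield self.db.runInteraction("get_value", _txn)
    return result[0] if result else None

# After: a native coroutine. Callers switch from yield to await, or wrap
# the call in defer.ensureDeferred where a Deferred is still required.
async def get_value(self):
    result = await self.db.runInteraction("get_value", _txn)
    return result[0] if result else None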


@@ -53,6 +53,47 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
EventTypes.Name,
EventTypes.RoomAvatar,
EventTypes.Tombstone,
}
def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
return False
# Exclude notices.
if (
not event.is_state()
and event.type == EventTypes.Message
and event.content.get("msgtype") == "m.notice"
):
return False
# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
body = event.content.get("body")
if isinstance(body, str) and body:
return True
# Mark some state events as unread.
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
return True
# Mark encrypted events as unread.
if not event.is_state() and event.type == EventTypes.Encrypted:
return True
return False
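
A hypothetical walk-through of the heuristic above (the dict literals stand in for real EventBase objects and are not code from this commit):

# {"type": "m.room.message", "content": {"msgtype": "m.text", "body": "hi"}}
#     -> True: non-empty string body.
# {"type": "m.room.message", "content": {"msgtype": "m.notice", "body": "x"}}
#     -> False: notices are excluded.
# {"type": "m.room.message", "content": {"body": "* fixed",
#   "m.relates_to": {"rel_type": "m.replace"}}}
#     -> False: edits are excluded before the body check.
# {"type": "m.room.topic", "state_key": ""}
#     -> True: state type listed in STATE_EVENT_TYPES_TO_MARK_UNREAD.
# {"type": "m.room.encrypted"} (non-state)
#     -> True: encrypted events are counted unconditionally.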
def encode_json(json_object):
"""
@@ -196,6 +237,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
self.store.get_unread_message_count_for_user.invalidate_many(
(event.room_id,),
)
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -817,8 +862,9 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
"count_as_unread": should_count_as_unread(event, context),
}
for event, _ in events_and_contexts
for event, context in events_and_contexts
],
)


@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import (
Cache,
_CacheContext,
cached,
cachedInlineCallbacks,
)
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
@cached(tree=True, cache_context=True)
async def get_unread_message_count_for_user(
self, room_id: str, user_id: str, cache_context: _CacheContext,
) -> int:
"""Retrieve the count of unread messages for the given room and user.
Args:
room_id: The ID of the room to count unread messages in.
user_id: The ID of the user to count unread messages for.
Returns:
The number of unread messages for the given user in the given room.
"""
with Measure(self._clock, "get_unread_message_count_for_user"):
last_read_event_id = await self.get_last_receipt_event_id_for_user(
user_id=user_id,
room_id=room_id,
receipt_type="m.read",
on_invalidate=cache_context.invalidate,
)
return await self.db.runInteraction(
"get_unread_message_count_for_user",
self._get_unread_message_count_for_user_txn,
user_id,
room_id,
last_read_event_id,
)
def _get_unread_message_count_for_user_txn(
self,
txn: Cursor,
user_id: str,
room_id: str,
last_read_event_id: Optional[str],
) -> int:
if last_read_event_id:
# Get the stream ordering for the last read event.
stream_ordering = self.db.simple_select_one_onecol_txn(
txn=txn,
table="events",
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
retcol="stream_ordering",
)
else:
# If there's no read receipt for that room, it probably means the user hasn't
# opened it yet, in which case use the stream ID of their join event.
# We can't just set it to 0 otherwise messages from other local users from
# before this user joined will be counted as well.
txn.execute(
"""
SELECT stream_ordering FROM local_current_membership
LEFT JOIN events USING (event_id, room_id)
WHERE membership = 'join'
AND user_id = ?
AND room_id = ?
""",
(user_id, room_id),
)
row = txn.fetchone()
if row is None:
return 0
stream_ordering = row[0]
# Count the messages that qualify as unread after the stream ordering we've just
# retrieved.
sql = """
SELECT COUNT(*) FROM events
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
"""
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
return row[0] if row else 0
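
The cache_context plumbing above is what ties this cache to the receipts cache: passing on_invalidate chains the entries, so invalidating get_last_receipt_event_id_for_user also evicts the dependent unread count. A minimal sketch of the pattern with a hypothetical store (assuming the cached descriptor imported in this module):

from synapse.util.caches.descriptors import cached

class SketchStore:
    @cached(cache_context=True)
    async def derived(self, room_id, cache_context):
        # When the entry behind self.base() is invalidated, this
        # derived entry is evicted along with it.
        base = await self.base(room_id, on_invalidate=cache_context.invalidate)
        return len(base)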
AllNewEventsResult = namedtuple(
"AllNewEventsResult",


@@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# event_json
# event_push_actions
# event_reference_hashes
# event_relations
# event_search
# event_to_state_groups
# events
@@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
"event_relations",
"event_search",
"rejections",
):


@@ -284,7 +284,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
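
This is the bridge in the opposite direction from most of this commit: get_current_state_ids is now a coroutine, and a bare coroutine cannot be yielded inside an @defer.inlineCallbacks generator, so it is wrapped first. A sketch of the pattern (hypothetical caller, not code from this commit):

from twisted.internet import defer

@defer.inlineCallbacks
def legacy_caller(context):
    # ensureDeferred turns the coroutine into a Deferred that the
    # inlineCallbacks machinery knows how to drive.
    state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
    return state_ids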


@@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks
def get_largest_public_rooms(
async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
@@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
return results
ret_val = yield self.db.runInteraction(
ret_val = await self.db.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
defer.returnValue(ret_val)
return ret_val
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
@@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
"get_rooms_paginate", _get_rooms_paginate_txn,
)
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
@cached(max_entries=10000)
async def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
@@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self.db.simple_select_one(
row = await self.db.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
else:
return None
@cachedInlineCallbacks()
def get_retention_policy_for_room(self, room_id):
@cached()
async def get_retention_policy_for_room(self, room_id):
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.cursor_to_dict(txn)
ret = yield self.db.runInteraction(
ret = await self.db.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
)
# If we don't know this room ID, ret will be None, in this case return the default
# policy.
if not ret:
defer.returnValue(
{
"min_lifetime": self.config.retention_default_min_lifetime,
"max_lifetime": self.config.retention_default_max_lifetime,
}
)
return {
"min_lifetime": self.config.retention_default_min_lifetime,
"max_lifetime": self.config.retention_default_max_lifetime,
}
row = ret[0]
@@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.retention_default_max_lifetime
defer.returnValue(row)
return row
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column,
)
@defer.inlineCallbacks
def _background_insert_retention(self, progress, batch_size):
async def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else:
return False
end = yield self.db.runInteraction(
end = await self.db.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
)
if end:
yield self.db.updates._end_background_update("insert_room_retention")
await self.db.updates._end_background_update("insert_room_retention")
defer.returnValue(batch_size)
return batch_size
async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int
@@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False,
)
@defer.inlineCallbacks
def store_room(
async def store_room(
self,
room_id: str,
room_creator_user_id: str,
@@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False,
)
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
async def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
self.db.simple_update_one_txn(
txn,
@@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction(
await self.db.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks
def set_room_is_public_appservice(
async def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public
):
"""Edit the appservice/network specific public room list.
@@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction(
await self.db.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
next_id,
@@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked. Can be called multiple times.
Args:
room_id (str): Room to block
user_id (str): Who blocked it
Returns:
Deferred
room_id: Room to block
user_id: Who blocked it
"""
yield self.db.simple_upsert(
await self.db.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"user_id": user_id},
desc="block_room",
)
yield self.db.runInteraction(
await self.db.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
(room_id,),
)
@defer.inlineCallbacks
def get_rooms_for_retention_period_in_range(
self, min_ms, max_ms, include_null=False
):
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
) -> Dict[str, dict]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
Args:
min_ms (int|None): Duration in milliseconds that define the lower limit of
min_ms: Duration in milliseconds that define the lower limit of
the range to handle (exclusive). If None, doesn't set a lower limit.
max_ms (int|None): Duration in milliseconds that define the upper limit of
max_ms: Duration in milliseconds that define the upper limit of
the range to handle (inclusive). If None, doesn't set an upper limit.
include_null (bool): Whether to include rooms which retention policy is NULL
include_null: Whether to include rooms which retention policy is NULL
in the returned set.
Returns:
dict[str, dict]: The rooms within this range, along with their retention
policy. The key is "room_id", and maps to a dict describing the retention
policy associated with this room ID. The keys for this nested dict are
"min_lifetime" (int|None), and "max_lifetime" (int|None).
The rooms within this range, along with their retention
policy. The key is "room_id", and maps to a dict describing the retention
policy associated with this room ID. The keys for this nested dict are
"min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
def get_rooms_for_retention_period_in_range_txn(txn):
@@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict
rooms = yield self.db.runInteraction(
rooms = await self.db.runInteraction(
"get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn,
)
defer.returnValue(rooms)
return rooms


@@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)


@@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Store a boolean value in the events table for whether the event should be counted in
-- the unread_count property of sync responses.
ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
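
The column is nullable and no backfill appears in this commit, so rows written before the migration carry NULL; since NULL is not TRUE, a bare boolean predicate simply skips them. A hypothetical sketch of the consuming transaction, mirroring _get_unread_message_count_for_user_txn above (not code from this commit):

def count_unread_txn(txn, user_id, room_id, stream_ordering):
    txn.execute(
        "SELECT COUNT(*) FROM events"
        " WHERE sender != ? AND room_id = ? AND stream_ordering > ?"
        " AND count_as_unread",  # NULL rows (pre-migration) never match
        (user_id, room_id, stream_ordering),
    )
    row = txn.fetchone()
    return row[0] if row else 0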


@@ -16,12 +16,12 @@
import collections.abc
import logging
from collections import namedtuple
from twisted.internet import defer
from typing import Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_create_event_for_room(room_id)
return create_event.content.get("room_version", "1")
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
room_id (str)
room_id: The room ID.
Returns:
Deferred[dict|None]: A dictionary containing the structure of the predecessor
field from the room's create event. The structure is subject to other servers,
but it is expected to be:
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
A dictionary containing the structure of the predecessor
field from the room's create event. The structure is subject to other servers,
but it is expected to be:
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
None if a predecessor key is not found, or is not a dictionary.
None if a predecessor key is not found, or is not a dictionary.
Raises:
NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
create_event = await self.get_create_event_for_room(room_id)
# Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None)
@@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return predecessor
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
async def get_create_event_for_room(self, room_id: str) -> EventBase:
"""Get the create state event for a room.
Args:
room_id (str)
room_id: The room ID.
Returns:
Deferred[EventBase]: The room creation event.
The room creation event.
Raises:
NotFoundError if the room is unknown
"""
state_ids = yield self.get_current_state_ids(room_id)
state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
@@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
create_event = await self.get_event(create_id)
return create_event
@cached(max_entries=100000, iterable=True)
@@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@defer.inlineCallbacks
def get_canonical_alias_for_room(self, room_id):
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
room_id (str)
room_id: The room ID
Returns:
Deferred[str|None]: The canonical alias, if any
The canonical alias, if any
"""
state = yield self.get_filtered_current_state_ids(
state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
@@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_id:
return
event = yield self.get_event(event_id, allow_none=True)
event = await self.get_event(event_id, allow_none=True)
if not event:
return
@@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {row["event_id"]: row["state_group"] for row in rows}
@defer.inlineCallbacks
def get_referenced_state_groups(self, state_groups):
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
) -> Set[int]:
"""Check if the state groups are referenced by events.
Args:
state_groups (Iterable[int])
state_groups
Returns:
Deferred[set[int]]: The subset of state groups that are
referenced.
The subset of state groups that are referenced.
"""
rows = yield self.db.simple_select_many_batch(
rows = await self.db.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,


@@ -16,8 +16,8 @@
import logging
from itertools import chain
from typing import Tuple
from twisted.internet import defer
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
"""
return (ts // self.stats_bucket_size) * self.stats_bucket_size
@defer.inlineCallbacks
def _populate_stats_process_users(self, progress, batch_size):
async def _populate_stats_process_users(self, progress, batch_size):
"""
This is a background update which regenerates statistics for users.
"""
if not self.stats_enabled:
yield self.db.updates._end_background_update("populate_stats_process_users")
await self.db.updates._end_background_update("populate_stats_process_users")
return 1
last_user_id = progress.get("last_user_id", "")
@@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn]
users_to_work_on = yield self.db.runInteraction(
users_to_work_on = await self.db.runInteraction(
"_populate_stats_process_users", _get_next_batch
)
# No more rooms -- complete the transaction.
if not users_to_work_on:
yield self.db.updates._end_background_update("populate_stats_process_users")
await self.db.updates._end_background_update("populate_stats_process_users")
return 1
for user_id in users_to_work_on:
yield self._calculate_and_set_initial_state_for_user(user_id)
await self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id
yield self.db.runInteraction(
await self.db.runInteraction(
"populate_stats_process_users",
self.db.updates._background_update_progress_txn,
"populate_stats_process_users",
@@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
@defer.inlineCallbacks
def _populate_stats_process_rooms(self, progress, batch_size):
async def _populate_stats_process_rooms(self, progress, batch_size):
"""
This is a background update which regenerates statistics for rooms.
"""
if not self.stats_enabled:
yield self.db.updates._end_background_update("populate_stats_process_rooms")
await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1
last_room_id = progress.get("last_room_id", "")
@@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn]
rooms_to_work_on = yield self.db.runInteraction(
rooms_to_work_on = await self.db.runInteraction(
"populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self.db.updates._end_background_update("populate_stats_process_rooms")
await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1
for room_id in rooms_to_work_on:
yield self._calculate_and_set_initial_state_for_room(room_id)
await self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id
yield self.db.runInteraction(
await self.db.runInteraction(
"_populate_stats_process_rooms",
self.db.updates._background_update_progress_txn,
"populate_stats_process_rooms",
@@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
return room_deltas, user_deltas
@defer.inlineCallbacks
def _calculate_and_set_initial_state_for_room(self, room_id):
async def _calculate_and_set_initial_state_for_room(
self, room_id: str
) -> Tuple[dict, dict, int]:
"""Calculate and insert an entry into room_stats_current.
Args:
room_id (str)
room_id: The room ID under calculation.
Returns:
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
counts and stream position.
A tuple of room state, membership counts and stream position.
"""
def _fetch_current_state_stats(txn):
@@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
current_state_events_count,
users_in_room,
pos,
) = yield self.db.runInteraction(
) = await self.db.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats
)
state_event_map = yield self.get_events(event_ids, get_prev_content=False)
state_event_map = await self.get_events(event_ids, get_prev_content=False)
room_state = {
"join_rules": None,
@@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
event.content.get("m.federate", True) is True
)
yield self.update_room_state(room_id, room_state)
await self.update_room_state(room_id, room_state)
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
yield self.update_stats_delta(
await self.update_stats_delta(
ts=self.clock.time_msec(),
stats_type="room",
stats_id=room_id,
@@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
},
)
@defer.inlineCallbacks
def _calculate_and_set_initial_state_for_user(self, user_id):
async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn):
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
@@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone()
return count, pos
joined_rooms, pos = yield self.db.runInteraction(
joined_rooms, pos = await self.db.runInteraction(
"calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn,
)
yield self.update_stats_delta(
await self.update_stats_delta(
ts=self.clock.time_msec(),
stats_type="user",
stats_id=user_id,


@@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._instance_name = hs.get_instance_name()
self._send_federation = hs.should_send_federation()
self._federation_shard_config = hs.config.federation.federation_shard_config
self._federation_shard_config = hs.config.worker.federation_shard_config
# If we're a process that sends federation we may need to reset the
# `federation_stream_position` table to match the current sharding


@@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
users_with_profile = yield state.get_current_users_in_room(room_id)
users_with_profile = yield defer.ensureDeferred(
state.get_current_users_in_room(room_id)
)
user_ids = set(users_with_profile)
# Update each user in the user directory.


@@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
def mark_user_erased(self, user_id):
def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
user_id (str): full user_id to be erased
user_id: full user_id to be erased
"""
def f(txn):
@@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db.runInteraction("mark_user_erased", f)
def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
user_id: full user_id to be un-erased
"""
def f(txn):
# first check if they are already in the list
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if not txn.fetchone():
return
# They are there, delete them.
self.simple_delete_one_txn(
txn, "erased_users", keyvalues={"user_id": user_id}
)
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db.runInteraction("mark_user_not_erased", f)


@@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"get_state_group_delta", _get_state_group_delta_txn
)
@defer.inlineCallbacks
def _get_state_groups_from_groups(
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
):
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
Dict of state group to state map.
"""
results = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.db.runInteraction(
res = await self.db.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
def _get_state_for_groups(
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
):
) -> Dict[int, StateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
(
non_member_state,
incomplete_groups_nm,
) = yield self._get_state_for_groups_using_cache(
) = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
(
member_state,
incomplete_groups_m,
) = yield self._get_state_for_groups_using_cache(
(member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
@@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups(
group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
)
@@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
((sg,) for sg in state_groups_to_delete),
)
@defer.inlineCallbacks
def get_previous_state_groups(self, state_groups):
async def get_previous_state_groups(
self, state_groups: Iterable[int]
) -> Dict[int, int]:
"""Fetch the previous groups of the given state groups.
Args:
state_groups (Iterable[int])
state_groups
Returns:
Deferred[dict[int, int]]: mapping from state group to previous
state group.
A mapping from state group to previous state group.
"""
rows = yield self.db.simple_select_many_batch(
rows = await self.db.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,


@@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
logger = logging.getLogger(__name__)
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -233,7 +233,7 @@ class LoggingTransaction:
try:
return func(sql, *args)
except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
secs = time.time() - start
@@ -419,7 +419,7 @@ class Database(object):
except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warning(
transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
@@ -427,18 +427,20 @@ class Database(object):
try:
conn.rollback()
except self.engine.module.Error as e1:
logger.warning("[TXN EROLL] {%s} %s", name, e1)
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e):
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
transaction_logger.warning(
"[TXN DEADLOCK] {%s} %d/%d", name, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
logger.warning(
transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1,
)
continue
@@ -478,7 +480,7 @@ class Database(object):
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close()
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = monotonic_time()
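
Splitting the module logger into synapse.storage.SQL and synapse.storage.txn lets deployments tune the two noise sources independently. A hypothetical logging override illustrating the point (not part of this commit):

import logging

# Hide per-statement debug noise from the SQL logger while keeping
# everything, including [TXN FAIL] debug lines, from the txn logger.
logging.getLogger("synapse.storage.SQL").setLevel(logging.WARNING)
logging.getLogger("synapse.storage.txn").setLevel(logging.DEBUG)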


@@ -25,11 +25,10 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import FrozenEvent
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
@@ -193,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def persist_events(
async def persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
) -> int:
"""
Write events to the database
Args:
@@ -208,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc.
Returns:
Deferred[int]: the stream ordering of the latest persisted event
the stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
@@ -224,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
max_persisted_id = yield self.main_store.get_current_events_token()
return self.main_store.get_current_events_token()
return max_persisted_id
@defer.inlineCallbacks
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[int, int]:
"""
Returns:
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
and the stream ordering of the latest persisted event
The stream ordering of `event`, and the stream ordering of the
latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@@ -247,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id)
yield make_deferred_yieldable(deferred)
await make_deferred_yieldable(deferred)
max_persisted_id = yield self.main_store.get_current_events_token()
max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
@@ -263,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
@@ -440,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities(
self,
room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]],
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
@@ -498,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events(
self,
room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]],
events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -648,6 +643,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
# Avoid a circular import.
from synapse.state import StateResolutionStore
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
@@ -680,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined(
self,
room_id: str,
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
current_state: Optional[StateMap[str]],
potentially_left_users: Set[str],
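
The conversion in this file keeps the persister's queue-and-gather shape. A condensed sketch of that shape (hypothetical function; the imports mirror the ones already used in this file):

from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable

async def wait_for_batches(deferreds):
    # Fan in the per-room persistence Deferreds; make_deferred_yieldable
    # keeps the Twisted logcontext correct while a coroutine awaits them.
    await make_deferred_yieldable(
        defer.gatherResults(deferreds, consumeErrors=True)
    )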


@@ -15,8 +15,7 @@
import itertools
import logging
from twisted.internet import defer
from typing import Set
logger = logging.getLogger(__name__)
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores):
self.stores = stores
@defer.inlineCallbacks
def purge_room(self, room_id: str):
async def purge_room(self, room_id: str):
"""Deletes all record of a room
"""
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@defer.inlineCallbacks
def purge_history(self, room_id, token, delete_local_events):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Deletes room history before a certain point
Args:
room_id (str):
room_id: The room ID
token (str): A topological token to delete events before
token: A topological token to delete events before
delete_local_events (bool):
delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
"""
state_groups = yield self.stores.main.purge_history(
state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] finding state groups that can be deleted")
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
sg_to_delete = await self._find_unreferenced_groups(state_groups)
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks
def _find_unreferenced_groups(self, state_groups):
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
"""Used when purging history to figure out which state groups can be
deleted.
Args:
state_groups (set[int]): Set of state groups referenced by events
state_groups: Set of state groups referenced by events
that are going to be deleted.
Returns:
Deferred[set[int]] The set of state groups that can be deleted.
The set of state groups that can be deleted.
"""
# Graph of state group -> previous group
graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search
referenced = yield self.stores.main.get_referenced_state_groups(
referenced = await self.stores.main.get_referenced_state_groups(
current_search
)
referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
edges = yield self.stores.state.get_previous_state_groups(current_search)
edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen
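
Condensed, the loop above is a batched frontier search over the prev-group graph. An abstract sketch under that reading (the async helpers are hypothetical stand-ins for the store calls, and details such as accumulating the referenced groups are omitted):

import itertools

async def find_deletable(next_to_search, get_referenced, get_prev_groups):
    deletable = set()
    seen = set(next_to_search)
    while next_to_search:
        batch = set(itertools.islice(next_to_search, 100))
        next_to_search -= batch
        # Groups still referenced by events must be kept.
        batch -= await get_referenced(batch)
        deletable |= batch
        # Follow edges to previous groups not yet visited.
        prevs = set((await get_prev_groups(batch)).values())
        next_to_search |= prevs - seen
        seen |= prevs
    return deletable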


@@ -14,13 +14,12 @@
# limitations under the License.
import logging
from typing import Iterable, List, TypeVar
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state.
Attributes:
types (dict[str, set[str]|None]): Map from type to set of state keys (or
None). This specifies which state_keys for the given type to fetch
from the DB. If None then all events with that type are fetched. If
the set is empty then no events with that type are fetched.
include_others (bool): Whether to fetch events with types that do not
types: Map from type to set of state keys (or None). This specifies
which state_keys for the given type to fetch from the DB. If None
then all events with that type are fetched. If the set is empty
then no events with that type are fetched.
include_others: Whether to fetch events with types that do not
appear in `types`.
"""
types = attr.ib()
include_others = attr.ib(default=False)
types = attr.ib(type=Dict[str, Optional[Set[str]]])
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
@@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
def all():
def all() -> "StateFilter":
"""Creates a filter that fetches everything.
Returns:
StateFilter
The new state filter.
"""
return StateFilter(types={}, include_others=True)
@staticmethod
def none():
def none() -> "StateFilter":
"""Creates a filter that fetches nothing.
Returns:
StateFilter
The new state filter.
"""
return StateFilter(types={}, include_others=False)
@staticmethod
def from_types(types):
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types
Args:
types (Iterable[tuple[str, str|None]]): A list of type and state
keys to fetch. A state_key of None fetches everything for
that type
types: A list of type and state keys to fetch. A state_key of None
fetches everything for that type
Returns:
StateFilter
The new state filter.
"""
type_dict = {}
type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types:
if typ in type_dict:
if type_dict[typ] is None:
@@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None
continue
type_dict.setdefault(typ, set()).add(s)
type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict)
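
An illustrative use of the now-typed constructor (hypothetical user IDs; EventTypes as imported in this module, not code from this commit):

# Fetch the room name event plus two specific members' state.
f = StateFilter.from_types(
    [
        (EventTypes.Name, ""),              # exactly the m.room.name event
        (EventTypes.Member, "@a:example.com"),
        (EventTypes.Member, "@b:example.com"),
    ]
)

# A state_key of None matches every event of that type instead.
all_members = StateFilter.from_types([(EventTypes.Member, None)])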
@staticmethod
def from_lazy_load_member_list(members):
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
events for the given users
Args:
members (iterable[str]): Set of user IDs
members: Set of user IDs
Returns:
StateFilter
The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass
@@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events
Returns:
StateFilter
The new state filter.
"""
if self.is_full():
@@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True,
)
def make_sql_filter_clause(self):
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
For example:
@@ -179,13 +177,12 @@ class StateFilter(object):
Returns:
tuple[str, list]: The SQL string (may be empty) and arguments. An
empty SQL string is returned when the filter matches everything
(i.e. is "full").
The SQL string (may be empty) and arguments. An empty SQL string is
returned when the filter matches everything (i.e. is "full").
"""
where_clause = ""
where_args = []
where_args = [] # type: List[str]
if self.is_full():
return where_clause, where_args
@@ -221,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args
def max_entries_returned(self):
def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if
known, otherwise returns None.
@@ -260,33 +257,33 @@ class StateFilter(object):
return filtered_state
def is_full(self):
def is_full(self) -> bool:
"""Whether this filter fetches everything or not
Returns:
bool
True if the filter fetches everything.
"""
return self.include_others and not self.types
def has_wildcards(self):
def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch
specific state.
Returns:
bool
True if the filter includes wildcards.
"""
return self.include_others or any(
state_keys is None for state_keys in self.types.values()
)
def concrete_types(self):
def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty).
Returns:
list[tuple[str,str]]
A list of type/state_keys tuples.
"""
return [
(t, s)
@@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys
]
def get_member_split(self):
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
against non member state.
@@ -307,7 +304,7 @@ class StateFilter(object):
state caches).
Returns:
tuple[StateFilter, StateFilter]: The member and non member filters
The member and non member filters
"""
if EventTypes.Member in self.types:
@@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
@@ -347,55 +347,59 @@ class StateGroupStorage(object):
return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group (int)
state_group: A state group for which we want to get the state IDs.
Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
@ -415,7 +419,7 @@ class StateGroupStorage(object):
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
):
) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.
@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
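Note the annotation here: the method is not declared `async def`; it forwards the store's coroutine unawaited, so its return type is `Awaitable[...]` rather than the method itself being a coroutine. A self-contained illustration of the pattern:

    from typing import Awaitable

    async def _fetch(x: int) -> int:
        return x * 2

    def forward(x: int) -> Awaitable[int]:
        # Not async: returns the coroutine unawaited, avoiding an extra frame.
        return _fetch(x)

    # A caller still simply does:  result = await forward(21)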
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
@ -463,24 +470,24 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
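Unlike the `_ids` variants below, the values returned here are full event objects fetched via `get_events`. A hedged usage sketch (the handle and event IDs are illustrative):

    # Hypothetical caller: read an event's room state as full events.
    state_by_event = await storage.state.get_state_for_events(["$event_a"])
    power_levels_event = state_by_event["$event_a"].get(("m.room.power_levels", ""))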
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
A dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
@ -491,67 +498,72 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
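Because this variant stops at event IDs, it is cheap enough to use for comparisons. A hedged sketch (illustrative handle) that finds which (type, state_key) pairs differ between two events' states:

    # Hypothetical: diff the state of two events in the same room.
    state_ids = await storage.state.get_state_ids_for_events(["$before", "$after"])
    before, after = state_ids["$before"], state_ids["$after"]
    changed = {
        key
        for key in before.keys() | after.keys()
        if before.get(key) != after.get(key)
    }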
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
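Both single-event helpers are thin wrappers over their list counterparts, indexing the resulting map by the one event ID. A hedged sketch, e.g. reading a room's name via one event's state (`StateFilter.from_types` is assumed to build a filter from concrete (type, state_key) pairs):

    # Hypothetical: fetch only the m.room.name state seen by one event.
    name_filter = StateFilter.from_types([("m.room.name", "")])
    state_ids = await storage.state.get_state_ids_for_event("$event_a", name_filter)
    name_event_id = state_ids.get(("m.room.name", ""))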
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
):
) -> Awaitable[Dict[int, StateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
from the database.
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
):
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns: