Merge remote-tracking branch 'upstream/release-v1.72'

Tulir Asokan 2022-11-16 17:51:05 +02:00
commit 108a80e9b0
113 changed files with 2633 additions and 1258 deletions


@@ -50,7 +50,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self.external_cached_functions: Dict[str, CachedFunction] = {}
def process_replication_rows(
def process_replication_rows( # noqa: B027 (no-op by design)
self,
stream_name: str,
instance_name: str,


@@ -716,8 +716,6 @@ class EventsPersistenceStorageController:
)
if not is_still_joined:
logger.info("Server no longer in room %s", room_id)
latest_event_ids = set()
current_state = {}
delta.no_longer_in_room = True
state_delta_for_room[room_id] = delta


@@ -26,9 +26,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -138,41 +136,8 @@ class DataStore(
self._clock = hs.get_clock()
self.database_engine = database.engine
self._device_list_id_gen = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_device_stream_token(self) -> int:
# TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()
async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.


@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import (
TYPE_CHECKING,
@@ -39,6 +38,8 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -49,6 +50,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -80,9 +86,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
else:
self._device_list_id_gen = SlavedIdTracker(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
device_list_max = self._device_list_id_gen.get_current_token()
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -136,6 +165,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -677,11 +739,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
},
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
"""Get the current stream id from the _device_list_id_gen"""
...
@trace
@cancellable
async def get_user_devices_from_cache(
@@ -1481,6 +1538,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Because we have write access, this will be a StreamIdGenerator
# (see DeviceWorkerStore.__init__)
_device_list_id_gen: AbstractStreamIdGenerator
def __init__(
self,
database: DatabasePool,
@@ -1805,7 +1866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context,
)
async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -2044,7 +2105,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[],
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
@@ -2058,7 +2119,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates during partial joins.
"""
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.simple_upsert(
table="device_lists_remote_pending",
keyvalues={
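The devices.py hunks above move `_device_list_id_gen` down into `DeviceWorkerStore`: on the main (writer) process it is a `StreamIdGenerator` that can allocate new stream ids, on workers it is a `SlavedIdTracker` that is only advanced from replication, and `DeviceStore` narrows the annotation to `AbstractStreamIdGenerator` because it has write access, which is what lets the `# type: ignore[attr-defined]` comments be dropped. The same writer/worker split reappears in the pusher store further down. A minimal standalone sketch of the pattern (illustrative stand-ins, not Synapse's real classes):

import itertools


class SketchStreamIdGenerator:
    """Writer side: allocates new stream ids and therefore knows the latest one."""

    def __init__(self, start: int = 0) -> None:
        self._counter = itertools.count(start + 1)
        self._current = start

    def get_next(self) -> int:
        self._current = next(self._counter)
        return self._current

    def get_current_token(self) -> int:
        return self._current


class SketchStreamIdTracker:
    """Worker side: never allocates ids, just follows the writer's position."""

    def __init__(self, start: int = 0) -> None:
        self._current = start

    def advance(self, instance_name: str, token: int) -> None:
        self._current = max(self._current, token)

    def get_current_token(self) -> int:
        return self._current


if __name__ == "__main__":
    writer = SketchStreamIdGenerator()
    worker = SketchStreamIdTracker()
    token = writer.get_next()        # writer persists a device-list change
    worker.advance("master", token)  # worker learns about it via replication
    assert worker.get_current_token() == writer.get_current_token()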


@@ -355,9 +355,9 @@ class PersistEventsStore:
txn: LoggingTransaction,
*,
events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremities: Optional[Dict[str, Set[str]]] = None,
inhibit_local_membership_updates: bool,
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremities: Dict[str, Set[str]],
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -384,9 +384,6 @@ class PersistEventsStore:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
state_delta_for_room = state_delta_for_room or {}
new_forward_extremities = new_forward_extremities or {}
all_events_and_contexts = events_and_contexts
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
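With the `= None` defaults removed above, the keyword-only arguments of `_persist_events_txn` must always be supplied, even when empty, which is why the `... or {}` normalisation lines are deleted as well. A tiny self-contained illustration of that signature style (hypothetical function, not Synapse code):

from typing import Dict, Set


def persist(
    *,
    inhibit_local_membership_updates: bool,
    state_delta_for_room: Dict[str, object],
    new_forward_extremities: Dict[str, Set[str]],
) -> None:
    # Keyword-only parameters without defaults: every caller must spell them out.
    print(len(state_delta_for_room), len(new_forward_extremities))


persist(
    inhibit_local_membership_updates=False,
    state_delta_for_room={},
    new_forward_extremities={},
)
# persist()  # TypeError: missing required keyword-only arguments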


@@ -1435,16 +1435,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
),
)
endpoint = None
row = txn.fetchone()
if row:
endpoint = row[0]
else:
# if the query didn't return a row, we must be almost done. We just
# need to go up to the recorded max_stream_ordering.
endpoint = max_stream_ordering_inclusive
where_clause = "stream_ordering > ?"
args = [min_stream_ordering_exclusive]
if endpoint:
where_clause += " AND stream_ordering <= ?"
args.append(endpoint)
where_clause = "stream_ordering > ? AND stream_ordering <= ?"
args = [min_stream_ordering_exclusive, endpoint]
# now do the updates.
txn.execute(
@@ -1458,13 +1458,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
logger.info(
"populated new `events` columns up to %s/%i: updated %i rows",
"populated new `events` columns up to %i/%i: updated %i rows",
endpoint,
max_stream_ordering_inclusive,
txn.rowcount,
)
if endpoint is None:
if endpoint >= max_stream_ordering_inclusive:
# we're done
return True
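The background-update hunks above remove the `Optional` endpoint: every batch now gets a concrete inclusive endpoint, either the `stream_ordering` returned by the query or, if no row comes back, the recorded `max_stream_ordering_inclusive`, and the update finishes once the endpoint reaches that maximum rather than when the query returns nothing. A plain-Python sketch of that batching shape, with a sorted list standing in for the table:

from typing import List


def run_batches(stream_orderings: List[int], batch_size: int) -> List[range]:
    """Return the inclusive ranges a batched update would walk through."""
    min_exclusive = min(stream_orderings) - 1
    max_inclusive = max(stream_orderings)
    batches = []
    while True:
        remaining = sorted(o for o in stream_orderings if o > min_exclusive)
        # Endpoint of this batch: the batch_size-th remaining row, or the
        # recorded maximum when fewer rows than batch_size are left.
        if len(remaining) >= batch_size:
            endpoint = remaining[batch_size - 1]
        else:
            endpoint = max_inclusive
        batches.append(range(min_exclusive + 1, endpoint + 1))
        if endpoint >= max_inclusive:
            return batches
        min_exclusive = endpoint


print(run_batches([1, 3, 4, 7, 9, 10], batch_size=2))
# [range(1, 4), range(4, 8), range(8, 11)]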


@@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -233,6 +234,21 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache(
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
if hs.config.worker.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
@@ -2219,7 +2235,15 @@ class EventsWorkerStore(SQLBaseStore):
return result is not None
async def get_partial_state_events_batch(self, room_id: str) -> List[str]:
"""Get a list of events in the given room that have partial state"""
"""
Get a list of events in the given room that:
- have partial state; and
- are ready to be resynced (because they have no prev_events that are
partial-stated)
See the docstring on `_get_partial_state_events_batch_txn` for more
information.
"""
return await self.db_pool.runInteraction(
"get_partial_state_events_batch",
self._get_partial_state_events_batch_txn,
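The events_worker.py hunk above relocates the `_curr_state_delta_stream_cache` setup from `DataStore.__init__` (removed earlier in this commit): `get_cache_dict` reads the most recent `current_state_delta_stream` rows so the `StreamChangeCache` starts warm. A minimal standalone sketch (not Synapse's `StreamChangeCache`) of what such a cache answers and why prefilling helps:

from typing import Dict


class SketchStreamChangeCache:
    """Remembers the latest stream position at which each entity changed."""

    def __init__(self, earliest_known_pos: int, prefilled: Dict[str, int]) -> None:
        # Below this position we have no information and must answer "maybe".
        self._earliest = earliest_known_pos
        self._latest_change = dict(prefilled)

    def entity_has_changed(self, entity: str, stream_pos: int) -> None:
        self._latest_change[entity] = max(self._latest_change.get(entity, 0), stream_pos)

    def has_entity_changed(self, entity: str, stream_pos: int) -> bool:
        if stream_pos < self._earliest:
            return True  # too old to know for sure: caller must hit the database
        return self._latest_change.get(entity, self._earliest) > stream_pos


# Prefilling from recent rows means most "has this room's state changed since
# my token?" checks can be answered without a query right after startup.
cache = SketchStreamChangeCache(100, prefilled={"!room:example.org": 105})
assert cache.has_entity_changed("!room:example.org", 101)
assert not cache.has_entity_changed("!quiet:example.org", 101)
assert cache.has_entity_changed("!quiet:example.org", 50)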


@@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
@@ -46,6 +46,8 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
class FilteringStore(FilteringWorkerStore):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)


@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
@@ -31,6 +31,7 @@ from typing import (
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -90,8 +91,6 @@ def _load_rules(
return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
PusherWorkerStore,
@@ -99,7 +98,6 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
EventsWorkerStore,
SQLBaseStore,
metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
@@ -136,14 +134,23 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
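The push_rule.py hunks above drop the `abc.ABCMeta` machinery, give `get_max_push_rules_stream_id` a concrete body, and add a `process_replication_rows` override that invalidates the per-user caches before chaining to `super()`. That chaining is what lets every store mixin in the combined `DataStore` MRO see every replication batch; a standalone sketch of the pattern (illustrative classes, not Synapse's):

from typing import Any, Iterable


class BaseStore:
    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        pass  # end of the chain


class PushRulesMixin(BaseStore):
    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == "push_rules":
            print(f"push rules stream advanced to {token}")
        return super().process_replication_rows(stream_name, instance_name, token, rows)


class DevicesMixin(BaseStore):
    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == "device_lists":
            print(f"device lists stream advanced to {token}")
        return super().process_replication_rows(stream_name, instance_name, token, rows)


class CombinedStore(PushRulesMixin, DevicesMixin, BaseStore):
    pass


# Each mixin handles the streams it cares about and defers to the next one.
CombinedStore().process_replication_rows("push_rules", "master", 42, [])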


@@ -27,13 +27,19 @@ from typing import (
)
from synapse.push import PusherConfig, ThrottleParams
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushersStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -52,9 +58,21 @@ class PusherWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
if hs.config.worker.worker_app is None:
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
)
else:
self._pushers_id_gen = SlavedIdTracker(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
)
self.db_pool.updates.register_background_update_handler(
"remove_deactivated_pushers",
@@ -96,6 +114,16 @@ class PusherWorkerStore(SQLBaseStore):
yield PusherConfig(**r)
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
) -> Iterator[PusherConfig]:
@@ -545,8 +573,9 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
# Because we have write access, this will be a StreamIdGenerator
# (see PusherWorkerStore.__init__)
_pushers_id_gen: AbstractStreamIdGenerator
async def add_pusher(
self,


@@ -113,24 +113,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
self.db_pool.updates.register_background_index_update(
"receipts_linearized_unique_index",
index_name="receipts_linearized_unique_index",
table="receipts_linearized",
columns=["room_id", "receipt_type", "user_id"],
where_clause="thread_id IS NULL",
unique=True,
)
self.db_pool.updates.register_background_index_update(
"receipts_graph_unique_index",
index_name="receipts_graph_unique_index",
table="receipts_graph",
columns=["room_id", "receipt_type", "user_id"],
where_clause="thread_id IS NULL",
unique=True,
)
def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@@ -702,9 +684,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
"data": json_encoder.encode(data),
},
where_clause=where_clause,
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
)
return rx_ts
@@ -862,14 +841,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
"data": json_encoder.encode(data),
},
where_clause=where_clause,
# receipts_graph has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
)
class ReceiptsBackgroundUpdateStore(SQLBaseStore):
POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering"
RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index"
RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index"
def __init__(
self,
@@ -883,6 +861,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
self._populate_receipt_event_stream_ordering,
)
self.db_pool.updates.register_background_update_handler(
self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME,
self._background_receipts_linearized_unique_index,
)
self.db_pool.updates.register_background_update_handler(
self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME,
self._background_receipts_graph_unique_index,
)
async def _populate_receipt_event_stream_ordering(
self, progress: JsonDict, batch_size: int
@@ -938,6 +924,143 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _create_receipts_index(self, index_name: str, table: str) -> None:
"""Adds a unique index on `(room_id, receipt_type, user_id)` to the given
receipts table, for non-thread receipts."""
def _create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
# we have to set autocommit, because postgres refuses to
# CREATE INDEX CONCURRENTLY without it.
if isinstance(self.database_engine, PostgresEngine):
conn.set_session(autocommit=True)
try:
c = conn.cursor()
# Now that the duplicates are gone, we can create the index.
concurrently = (
"CONCURRENTLY"
if isinstance(self.database_engine, PostgresEngine)
else ""
)
sql = f"""
CREATE UNIQUE INDEX {concurrently} {index_name}
ON {table}(room_id, receipt_type, user_id)
WHERE thread_id IS NULL
"""
c.execute(sql)
finally:
if isinstance(self.database_engine, PostgresEngine):
conn.set_session(autocommit=False)
await self.db_pool.runWithConnection(_create_index)
async def _background_receipts_linearized_unique_index(
self, progress: dict, batch_size: int
) -> int:
"""Removes duplicate receipts and adds a unique index on
`(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread
receipts."""
def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
# Identify any duplicate receipts arising from
# https://github.com/matrix-org/synapse/issues/14406.
# We expect the following query to use the per-thread receipt index and take
# less than a minute.
sql = """
SELECT MAX(stream_id), room_id, receipt_type, user_id
FROM receipts_linearized
WHERE thread_id IS NULL
GROUP BY room_id, receipt_type, user_id
HAVING COUNT(*) > 1
"""
txn.execute(sql)
duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn))
# Then remove duplicate receipts, keeping the one with the highest
# `stream_id`. There should only be a single receipt with any given
# `stream_id`.
for max_stream_id, room_id, receipt_type, user_id in duplicate_keys:
sql = """
DELETE FROM receipts_linearized
WHERE
room_id = ? AND
receipt_type = ? AND
user_id = ? AND
thread_id IS NULL AND
stream_id < ?
"""
txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id))
await self.db_pool.runInteraction(
self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME,
_remote_duplicate_receipts_txn,
)
await self._create_receipts_index(
"receipts_linearized_unique_index",
"receipts_linearized",
)
await self.db_pool.updates._end_background_update(
self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME
)
return 1
async def _background_receipts_graph_unique_index(
self, progress: dict, batch_size: int
) -> int:
"""Removes duplicate receipts and adds a unique index on
`(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread
receipts."""
def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
# Identify any duplicate receipts arising from
# https://github.com/matrix-org/synapse/issues/14406.
# We expect the following query to use the per-thread receipt index and take
# less than a minute.
sql = """
SELECT room_id, receipt_type, user_id FROM receipts_graph
WHERE thread_id IS NULL
GROUP BY room_id, receipt_type, user_id
HAVING COUNT(*) > 1
"""
txn.execute(sql)
duplicate_keys = cast(List[Tuple[str, str, str]], list(txn))
# Then remove all duplicate receipts.
# We could be clever and try to keep the latest receipt out of every set of
# duplicates, but it's far simpler to remove them all.
for room_id, receipt_type, user_id in duplicate_keys:
sql = """
DELETE FROM receipts_graph
WHERE
room_id = ? AND
receipt_type = ? AND
user_id = ? AND
thread_id IS NULL
"""
txn.execute(sql, (room_id, receipt_type, user_id))
await self.db_pool.runInteraction(
self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME,
_remote_duplicate_receipts_txn,
)
await self._create_receipts_index(
"receipts_graph_unique_index",
"receipts_graph",
)
await self.db_pool.updates._end_background_update(
self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME
)
return 1
class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore):
pass
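The receipts changes above turn the two unique indexes into background updates that first delete duplicate non-threaded receipts (keeping the row with the highest `stream_id` for `receipts_linearized`, and dropping all of them for `receipts_graph`) and only then build a partial unique index, using `CREATE INDEX CONCURRENTLY` under autocommit on Postgres. The sqlite3 sketch below demonstrates the same dedupe-then-index technique on a toy table; it borrows the column names but is an illustration, not the real migration:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE receipts_linearized "
    "(stream_id INTEGER, room_id TEXT, receipt_type TEXT, user_id TEXT, thread_id TEXT)"
)
conn.executemany(
    "INSERT INTO receipts_linearized VALUES (?, ?, ?, ?, ?)",
    [
        (1, "!r:x", "m.read", "@a:x", None),
        (2, "!r:x", "m.read", "@a:x", None),  # duplicate of the row above
        (3, "!r:x", "m.read", "@b:x", None),
    ],
)

# Find the surviving (max stream_id) row for each duplicated non-threaded key...
dupes = conn.execute(
    "SELECT MAX(stream_id), room_id, receipt_type, user_id FROM receipts_linearized "
    "WHERE thread_id IS NULL GROUP BY room_id, receipt_type, user_id HAVING COUNT(*) > 1"
).fetchall()
# ...and delete everything older than it.
for max_stream_id, room_id, receipt_type, user_id in dupes:
    conn.execute(
        "DELETE FROM receipts_linearized "
        "WHERE room_id = ? AND receipt_type = ? AND user_id = ? "
        "AND thread_id IS NULL AND stream_id < ?",
        (room_id, receipt_type, user_id, max_stream_id),
    )

# With duplicates gone the partial unique index can be created
# (Postgres would add CONCURRENTLY and needs autocommit for that).
conn.execute(
    "CREATE UNIQUE INDEX receipts_linearized_unique_index "
    "ON receipts_linearized (room_id, receipt_type, user_id) "
    "WHERE thread_id IS NULL"
)
print(conn.execute("SELECT * FROM receipts_linearized ORDER BY stream_id").fetchall())
# [(2, '!r:x', 'm.read', '@a:x', None), (3, '!r:x', 'm.read', '@b:x', None)]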


@@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
async def get_all_relations_for_event_with_types(
self,
event_id: str,
relation_types: List[str],
) -> List[str]:
"""Get the event IDs of all events that have a relation to the given event with
one of the given relation types.
Args:
event_id: The event for which to look for related events.
relation_types: The types of relations to look for.
Returns:
A list of the IDs of the events that relate to the given event with one of
the given relation types.
"""
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
rows = self.db_pool.simple_select_many_txn(
txn=txn,
table="event_relations",
column="relation_type",
iterable=relation_types,
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
)
return [row["event_id"] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",
func=get_all_relation_ids_for_event_with_types_txn,
)
async def event_includes_relation(self, event_id: str) -> bool:
"""Check if the given event relates to another event.


@@ -1517,6 +1517,36 @@ class RoomMemberStore(
await self.db_pool.runInteraction("forget_membership", f)
def extract_heroes_from_room_summary(
details: Mapping[str, MemberSummary], me: str
) -> List[str]:
"""Determine the users that represent a room, from the perspective of the `me` user.
The rules which say which users we select are specified in the "Room Summary"
section of
https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync
Returns a list (possibly empty) of heroes' mxids.
"""
empty_ms = MemberSummary([], 0)
joined_user_ids = [
r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
]
invited_user_ids = [
r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
]
gone_user_ids = [
r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
# FIXME: order by stream ordering rather than as returned by SQL
if joined_user_ids or invited_user_ids:
return sorted(joined_user_ids + invited_user_ids)[0:5]
else:
return sorted(gone_user_ids)[0:5]
@attr.s(slots=True, auto_attribs=True)
class _JoinedHostsCache:
"""The cached data used by the `_get_joined_hosts_cache`."""


@@ -463,18 +463,17 @@ class SearchStore(SearchBackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine):
search_query = search_term
tsquery_func = self.database_engine.tsquery_func
sql = f"""
SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank,
room_id, event_id
FROM event_search
WHERE vector @@ {tsquery_func}('english', ?)
WHERE vector @@ websearch_to_tsquery('english', ?)
"""
args = [search_query, search_query] + args
count_sql = f"""
count_sql = """
SELECT room_id, count(*) as count FROM event_search
WHERE vector @@ {tsquery_func}('english', ?)
WHERE vector @@ websearch_to_tsquery('english', ?)
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
@@ -523,9 +522,7 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(
search_query, events, tsquery_func
)
highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
@@ -604,18 +601,17 @@ class SearchStore(SearchBackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine):
search_query = search_term
tsquery_func = self.database_engine.tsquery_func
sql = f"""
SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
origin_server_ts, stream_ordering, room_id, event_id
FROM event_search
WHERE vector @@ {tsquery_func}('english', ?) AND
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
args = [search_query, search_query] + args
count_sql = f"""
count_sql = """
SELECT room_id, count(*) as count FROM event_search
WHERE vector @@ {tsquery_func}('english', ?) AND
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
@@ -686,9 +682,7 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(
search_query, events, tsquery_func
)
highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
@@ -714,7 +708,7 @@ class SearchStore(SearchBackgroundUpdateStore):
}
async def _find_highlights_in_postgres(
self, search_query: str, events: List[EventBase], tsquery_func: str
self, search_query: str, events: List[EventBase]
) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -725,7 +719,6 @@ class SearchStore(SearchBackgroundUpdateStore):
Args:
search_query
events: A list of events
tsquery_func: The tsquery_* function to use when making queries
Returns:
A set of strings.
@@ -758,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore):
while stop_sel in value:
stop_sel += ">"
query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % (
_to_postgres_options(
{
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
}
query = (
"SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)"
% (
_to_postgres_options(
{
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
}
)
)
)
txn.execute(query, (value, search_query))
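The search changes above drop the runtime `tsquery_func` selection and call `websearch_to_tsquery` unconditionally, which is safe now that PostgreSQL 11 is the minimum supported version (see the engine change below). A hedged psycopg2 snippet showing what that function does with web-style query syntax; the DSN is a placeholder and the snippet does not touch Synapse's schema:

import psycopg2

conn = psycopg2.connect("dbname=test")  # placeholder DSN
with conn, conn.cursor() as cur:
    # websearch_to_tsquery understands web-style syntax such as quoted phrases
    # and "-term" exclusions, which plainto_tsquery (the old PG 10 fallback)
    # does not.
    cur.execute(
        "SELECT websearch_to_tsquery('english', %s)",
        ('"fox jumps" -lazy',),
    )
    print(cur.fetchone()[0])  # e.g. 'fox' <-> 'jump' & !'lazi'
conn.close()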


@@ -415,6 +415,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_room_max_stream_ordering(self) -> int:
"""Get the stream_ordering of regular events that we have committed up to


@@ -81,8 +81,8 @@ class PostgresEngine(
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000:
raise RuntimeError("Synapse requires PostgreSQL 10 or above.")
if not allow_outdated_version and self._version < 110000:
raise RuntimeError("Synapse requires PostgreSQL 11 or above.")
with db_conn.cursor() as txn:
txn.execute("SHOW SERVER_ENCODING")
@@ -170,22 +170,6 @@ class PostgresEngine(
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True
@property
def tsquery_func(self) -> str:
"""
Selects a tsquery_* func to use.
Ref: https://www.postgresql.org/docs/current/textsearch-controls.html
Returns:
The function name.
"""
# Postgres 11 added support for websearch_to_tsquery.
assert self._version is not None
if self._version >= 110000:
return "websearch_to_tsquery"
return "plainto_tsquery"
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html


@@ -0,0 +1,33 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- By default the postgres statistics collector massively underestimates the
-- number of distinct rooms in `event_search`, which can cause postgres to use
-- table scans for queries for multiple rooms.
--
-- To work around this we can manually tell postgres the number of distinct rooms
-- by setting `n_distinct` (a negative value here is the number of distinct values
-- divided by the number of rows, so -0.01 means on average there are 100 rows per
-- distinct value). We don't need a particularly accurate number here, as a) we just
-- want it to always use index scans and b) our estimate is going to be better than the
-- one made by the statistics collector.
ALTER TABLE event_search ALTER COLUMN room_id SET (n_distinct = -0.01);
-- Ideally we'd do an `ANALYZE event_search (room_id)` here so that
-- the above gets picked up immediately, but that can take a bit of time so we
-- rely on the autovacuum eventually getting run and doing that in the
-- background for us.
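To confirm the override took effect, or to run the `ANALYZE` that the comment above leaves to autovacuum, something like the following works (assumes direct psycopg2 access to the database; the DSN is a placeholder):

import psycopg2

conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
conn.autocommit = True
with conn.cursor() as cur:
    # Per-column options set via ALTER TABLE ... SET (n_distinct = ...) are
    # stored in pg_attribute.attoptions.
    cur.execute(
        "SELECT attoptions FROM pg_attribute "
        "WHERE attrelid = 'event_search'::regclass AND attname = 'room_id'"
    )
    print(cur.fetchone())  # e.g. (['n_distinct=-0.01'],)
    # Refresh the planner statistics now instead of waiting for autovacuum.
    cur.execute("ANALYZE event_search (room_id)")
conn.close()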