Merge remote-tracking branch 'upstream/release-v1.73'

This commit is contained in:
Tulir Asokan 2022-11-29 15:51:33 +02:00
commit bb26f5f0a9
167 changed files with 3234 additions and 1676 deletions

View file

@ -204,9 +204,8 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
process to do so, calling the per_item_callback for each item.
Args:
room_id (str):
task (_EventPersistQueueTask): A _PersistEventsTask or
_UpdateCurrentStateTask to process.
room_id:
task: A _PersistEventsTask or _UpdateCurrentStateTask to process.
Returns:
the result returned by the `_per_item_callback` passed to

View file

@ -569,15 +569,15 @@ class DatabasePool:
retcols=["update_name"],
desc="check_background_updates",
)
updates = [x["update_name"] for x in updates]
background_update_names = [x["update_name"] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in updates:
if update_name not in background_update_names:
logger.debug("Now safe to upsert in %s", table)
self._unsafe_to_upsert_tables.discard(table)
# If there's any updates still running, reschedule to run.
if updates:
if background_update_names:
self._clock.call_later(
15.0,
run_as_background_process,
@ -1129,7 +1129,6 @@ class DatabasePool:
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert",
lock: bool = True,
) -> bool:
"""Insert a row with values + insertion_values; on conflict, update with values.
@ -1154,21 +1153,12 @@ class DatabasePool:
requiring that a unique index exist on the column names used to detect a
conflict (i.e. `keyvalues.keys()`).
If there is no such index, we can "emulate" an upsert with a SELECT followed
by either an INSERT or an UPDATE. This is unsafe: we cannot make the same
atomicity guarantees that a native upsert can and are very vulnerable to races
and crashes. Therefore if we wish to upsert without an appropriate unique index,
we must either:
1. Acquire a table-level lock before the emulated upsert (`lock=True`), or
2. VERY CAREFULLY ensure that we are the only thread and worker which will be
writing to this table, in which case we can proceed without a lock
(`lock=False`).
Generally speaking, you should use `lock=True`. If the table in question has a
unique index[*], this class will use a native upsert (which is atomic and so can
ignore the `lock` argument). Otherwise this class will use an emulated upsert,
in which case we want the safer option unless we been VERY CAREFUL.
If there is no such index yet[*], we can "emulate" an upsert with a SELECT
followed by either an INSERT or an UPDATE. This is unsafe unless *all* upserters
run at the SERIALIZABLE isolation level: we cannot make the same atomicity
guarantees that a native upsert can and are very vulnerable to races and
crashes. Therefore to upsert without an appropriate unique index, we acquire a
table-level lock before the emulated upsert.
[*]: Some tables have unique indices added to them in the background. Those
tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES,
@ -1189,7 +1179,6 @@ class DatabasePool:
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
desc: description of the transaction, for logging and metrics
lock: True to lock the table when doing the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
not empty then this always returns True)
@ -1209,7 +1198,6 @@ class DatabasePool:
keyvalues,
values,
insertion_values,
lock=lock,
db_autocommit=autocommit,
)
except self.engine.module.IntegrityError as e:
@ -1232,7 +1220,6 @@ class DatabasePool:
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
Pick the UPSERT method which works best on the platform. Either the
@ -1245,8 +1232,6 @@ class DatabasePool:
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
not empty then this always returns True)
@ -1270,7 +1255,6 @@ class DatabasePool:
values,
insertion_values=insertion_values,
where_clause=where_clause,
lock=lock,
)
def simple_upsert_txn_emulated(
@ -1291,14 +1275,15 @@ class DatabasePool:
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert.
Must not be False unless the table has already been locked.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
not empty then this always returns True)
"""
insertion_values = insertion_values or {}
# We need to lock the table :(, unless we're *really* careful
if lock:
# We need to lock the table :(
self.engine.lock_table(txn, table)
def _getwhere(key: str) -> str:
@ -1406,7 +1391,6 @@ class DatabasePool:
value_names: Collection[str],
value_values: Collection[Collection[Any]],
desc: str,
lock: bool = True,
) -> None:
"""
Upsert, many times.
@ -1418,8 +1402,6 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
"""
# We can autocommit if it is safe to upsert
@ -1433,7 +1415,6 @@ class DatabasePool:
key_values,
value_names,
value_values,
lock=lock,
db_autocommit=autocommit,
)
@ -1445,7 +1426,6 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
lock: bool = True,
) -> None:
"""
Upsert, many times.
@ -1457,8 +1437,6 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
"""
if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@ -1466,7 +1444,12 @@ class DatabasePool:
)
else:
return self.simple_upsert_many_txn_emulated(
txn, table, key_names, key_values, value_names, value_values, lock=lock
txn,
table,
key_names,
key_values,
value_names,
value_values,
)
def simple_upsert_many_txn_emulated(
@ -1477,7 +1460,6 @@ class DatabasePool:
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
lock: bool = True,
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
@ -1489,18 +1471,16 @@ class DatabasePool:
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
lock: True to lock the table when doing the upsert.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
if not value_names:
value_values = [() for x in range(len(key_values))]
if lock:
# Lock the table just once, to prevent it being done once per row.
# Note that, according to Postgres' documentation, once obtained,
# the lock is held for the remainder of the current transaction.
self.engine.lock_table(txn, "user_ips")
# Lock the table just once, to prevent it being done once per row.
# Note that, according to Postgres' documentation, once obtained,
# the lock is held for the remainder of the current transaction.
self.engine.lock_table(txn, "user_ips")
for keyv, valv in zip(key_values, value_values):
_keys = {x: y for x, y in zip(key_names, keyv)}
@ -2075,13 +2055,14 @@ class DatabasePool:
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
if keyvalues:
select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),)
txn.execute(select_sql, list(keyvalues.values()))
else:
txn.execute(select_sql)
txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:

View file

@ -27,7 +27,6 @@ from typing import (
)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# to write account data. A value of `True` implies that `_account_data_id_gen`
# is an `AbstractStreamIdGenerator` and not just a tracker.
self._account_data_id_gen: AbstractStreamIdTracker
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if self._instance_name in hs.config.worker.writers.account_data:
self._can_write_to_account_data = True
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
else:
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
@ -459,9 +449,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
await self.db_pool.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
@ -471,7 +458,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"account_data_type": account_data_type,
},
values={"stream_id": next_id, "content": content_json},
lock=False,
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
@ -527,15 +513,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) -> None:
content_json = json_encoder.encode(content)
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
self.db_pool.simple_upsert_txn(
txn,
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
values={"stream_id": next_id, "content": content_json},
lock=False,
)
# Ignored users get denormalized into a separate table as an optimisation.

View file

@ -20,7 +20,7 @@ from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.config.appservice import load_appservices
@ -260,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore(
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
@ -273,7 +273,7 @@ class ApplicationServiceTransactionWorkerStore(
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
one_time_keys_count: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
@ -299,7 +299,7 @@ class ApplicationServiceTransactionWorkerStore(
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
one_time_keys_count=one_time_keys_count,
unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
)
@ -379,7 +379,7 @@ class ApplicationServiceTransactionWorkerStore(
events=events,
ephemeral=[],
to_device_messages=[],
one_time_key_counts={},
one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
@ -451,8 +451,6 @@ class ApplicationServiceTransactionWorkerStore(
table="application_services_state",
keyvalues={"as_id": service.id},
values={f"{stream_type}_stream_id": pos},
# no need to lock when emulating upsert: as_id is a unique key
lock=False,
desc="set_appservice_stream_type_pos",
)

View file

@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache(
"get_aggregation_groups_for_event", (relates_to,)
)

View file

@ -38,7 +38,6 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
else:
self._device_list_id_gen = SlavedIdTracker(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
@ -535,7 +525,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
limit: Maximum number of device updates to return
Returns:
List: List of device update tuples:
List of device update tuples:
- user_id
- device_id
- stream_id
@ -1451,6 +1441,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._remove_duplicate_outbound_pokes,
)
self.db_pool.updates.register_background_index_update(
"device_lists_changes_in_room_by_room_index",
index_name="device_lists_changes_in_room_by_room_idx",
table="device_lists_changes_in_room",
columns=["room_id", "stream_id"],
)
async def _drop_device_list_streams_non_unique_indexes(
self, progress: JsonDict, batch_size: int
) -> int:
@ -1536,6 +1533,71 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
async def check_too_many_devices_for_user(self, user_id: str) -> Collection[str]:
    """Work out which of a user's devices are eligible for pruning.

    Users with ten or fewer non-hidden devices are never pruned. Above that,
    devices last seen more than 14 days ago become candidates. When the user
    has more than 50 non-hidden devices, the cutoff may be moved later, to the
    `last_seen` of the 51st most recently seen candidate device, so that
    devices older than that one are returned as well.

    This does *not* return hidden devices or devices with E2E keys.

    Args:
        user_id: The user whose devices should be checked.

    Returns:
        The IDs of the devices that can be pruned. Empty if there is nothing
        to prune.
    """
    device_count = await self.db_pool.simple_select_one_onecol(
        table="devices",
        keyvalues={"user_id": user_id, "hidden": False},
        retcol="COALESCE(COUNT(*), 0)",
        desc="count_devices",
    )

    # Up to ten devices are always left alone.
    if device_count <= 10:
        return ()

    # Baseline cutoff: anything not seen during the last 14 days.
    fourteen_days_ms = 14 * 24 * 60 * 60 * 1000
    prune_before_ts = self._clock.time_msec() - fourteen_days_ms

    if device_count > 50:
        # Look up the `last_seen` of the 51st most recently seen prunable
        # device (the query skips the first 50 rows). If that is later than
        # the baseline cutoff, use it instead so that devices older than it
        # are pruned too, even if they were seen within the last 14 days.
        fifty_first_sql = """
            SELECT last_seen FROM devices
            LEFT JOIN e2e_device_keys_json USING (user_id, device_id)
            WHERE
                user_id = ?
                AND NOT hidden
                AND last_seen IS NOT NULL
                AND key_json IS NULL
            ORDER BY last_seen DESC
            LIMIT 1
            OFFSET 50
        """

        fifty_first_rows = await self.db_pool.execute(
            "check_too_many_devices_for_user_last_seen",
            None,
            fifty_first_sql,
            (user_id,),
        )
        if fifty_first_rows:
            prune_before_ts = max(fifty_first_rows[0][0], prune_before_ts)

    # Select every prunable device seen strictly before the cutoff. Rows with
    # a NULL `last_seen` never satisfy the comparison, so they are not
    # returned.
    prunable_sql = """
        SELECT DISTINCT device_id FROM devices
        LEFT JOIN e2e_device_keys_json USING (user_id, device_id)
        WHERE
            user_id = ?
            AND NOT hidden
            AND last_seen < ?
            AND key_json IS NULL
    """

    def fetch_prunable_device_ids_txn(
        txn: LoggingTransaction,
    ) -> Collection[str]:
        txn.execute(prunable_sql, (user_id, prune_before_ts))
        # Each row holds a single column: the device ID.
        return {row[0] for row in txn}

    return await self.db_pool.runInteraction(
        "check_too_many_devices_for_user",
        fetch_prunable_device_ids_txn,
    )
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Because we have write access, this will be a StreamIdGenerator
@ -1594,6 +1656,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values={},
insertion_values={
"display_name": initial_device_display_name,
"last_seen": self._clock.time_msec(),
"hidden": False,
},
desc="store_device",
@ -1639,7 +1702,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
raise StoreError(500, "Problem storing device.")
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
async def delete_devices(self, user_id: str, device_ids: Collection[str]) -> None:
"""Deletes several devices.
Args:
@ -1747,9 +1810,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"content": json_encoder.encode(content)},
# we don't need to lock, because we assume we are the only thread
# updating this user's devices.
lock=False,
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
@ -1763,9 +1823,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
# again, we can assume we are the only thread updating this user's
# extremity.
lock=False,
)
async def update_remote_device_list_cache(
@ -1818,9 +1875,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
# we don't need to lock, because we can assume we are the only thread
# updating this user's extremity.
lock=False,
)
async def add_device_change_to_streams(
@ -2018,27 +2072,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def get_uncoverted_outbound_room_pokes(
self, limit: int = 10
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
Args:
start_stream_id: Together with `start_room_id`, indicates the position after
which to return device list changes.
start_room_id: Together with `start_stream_id`, indicates the position after
which to return device list changes.
limit: The maximum number of device list changes to return.
Returns:
A list of user ID, device ID, room ID, stream ID and optional opentracing context.
A list of user ID, device ID, room ID, stream ID and optional opentracing
context, in order of ascending (stream ID, room ID).
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
WHERE NOT converted_to_destinations
ORDER BY stream_id
WHERE
(stream_id, room_id) > (?, ?) AND
stream_id <= ? AND
NOT converted_to_destinations
ORDER BY stream_id ASC, room_id ASC
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
txn.execute(sql, (limit,))
txn.execute(
sql,
(
start_stream_id,
start_room_id,
# Avoid returning rows if there may be uncommitted device list
# changes with smaller stream IDs.
self._device_list_id_gen.get_current_token(),
limit,
),
)
return [
(
@ -2060,49 +2135,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled,
if `stream_id` is provided.
"""
if not hosts:
return
def add_device_list_outbound_pokes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_id=device_id,
hosts=hosts,
stream_ids=stream_ids,
context=context,
)
if stream_id:
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
[],
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_id=device_id,
hosts=hosts,
stream_ids=stream_ids,
context=context,
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
@ -2166,3 +2217,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"get_pending_remote_device_list_updates_for_room",
get_pending_remote_device_list_updates_for_room_txn,
)
async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
    """Fetch the position of the last `device_lists_changes_in_room` row that
    has been converted to `device_lists_outbound_pokes`.

    Rows at a strictly greater position whose `converted_to_destinations` is
    `FALSE` have not been converted.

    Returns:
        A `(stream_id, room_id)` tuple read from the
        `device_lists_changes_converted_stream_position` table.
    """
    # No key columns are supplied, so the lookup is not filtered on any
    # column. NOTE(review): presumably the table holds a single row - confirm.
    position = await self.db_pool.simple_select_one(
        table="device_lists_changes_converted_stream_position",
        keyvalues={},
        retcols=["stream_id", "room_id"],
        desc="get_device_change_last_converted_pos",
    )
    last_stream_id = position["stream_id"]
    last_room_id = position["room_id"]
    return last_stream_id, last_room_id
async def set_device_change_last_converted_pos(
    self,
    stream_id: int,
    room_id: str,
) -> None:
    """Record the position of the last `device_lists_changes_in_room` row that
    has been converted to `device_lists_outbound_pokes`.

    Args:
        stream_id: Stream ID component of the new position.
        room_id: Room ID component of the new position.
    """
    new_position = {"stream_id": stream_id, "room_id": room_id}
    # No key columns are supplied, so the update is not filtered on any column.
    await self.db_pool.simple_update_one(
        table="device_lists_changes_converted_stream_position",
        keyvalues={},
        updatevalues=new_position,
        desc="set_device_change_last_converted_pos",
    )

View file

@ -391,10 +391,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Returns:
A dict giving the info metadata for this backup version, with
fields including:
version(str)
algorithm(str)
auth_data(object): opaque dict supplied by the client
etag(int): tag of the keys in the backup
version (str)
algorithm (str)
auth_data (object): opaque dict supplied by the client
etag (int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:

View file

@ -33,7 +33,7 @@ from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
TransactionOneTimeKeyCounts,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
@ -412,10 +412,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""Retrieve a number of one-time keys for a user
Args:
user_id(str): id of user to get keys for
device_id(str): id of device to get keys for
key_ids(list[str]): list of key ids (excluding algorithm) to
retrieve
user_id: id of user to get keys for
device_id: id of device to get keys for
key_ids: list of key ids (excluding algorithm) to retrieve
Returns:
A map from (algorithm, key_id) to json string for key
@ -515,7 +514,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def count_bulk_e2e_one_time_keys_for_as(
self, user_ids: Collection[str]
) -> TransactionOneTimeKeyCounts:
) -> TransactionOneTimeKeysCount:
"""
Counts, in bulk, the one-time keys for all the users specified.
Intended to be used by application services for populating OTK counts in
@ -529,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _count_bulk_e2e_one_time_keys_txn(
txn: LoggingTransaction,
) -> TransactionOneTimeKeyCounts:
) -> TransactionOneTimeKeysCount:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
@ -542,7 +541,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
txn.execute(sql, user_parameters)
result: TransactionOneTimeKeyCounts = {}
result: TransactionOneTimeKeysCount = {}
for user_id, device_id, algorithm, count in txn:
# We deliberately construct empty dictionaries for

View file

@ -1686,7 +1686,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
},
insertion_values={},
desc="insert_insertion_extremity",
lock=False,
)
async def insert_received_event_to_staging(

View file

@ -1279,9 +1279,10 @@ class PersistEventsStore:
Pick the earliest non-outlier if there is one, else the earliest one.
Args:
events_and_contexts (list[(EventBase, EventContext)]):
events_and_contexts:
Returns:
list[(EventBase, EventContext)]: filtered list
filtered list
"""
new_events_and_contexts: OrderedDict[
str, Tuple[EventBase, EventContext]
@ -1307,9 +1308,8 @@ class PersistEventsStore:
"""Update min_depth for each room
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
txn: db connection
events_and_contexts: events we are persisting
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
@ -1580,13 +1580,11 @@ class PersistEventsStore:
"""Update all the miscellaneous tables for new events
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
all_events_and_contexts (list[(EventBase, EventContext)]): all
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
txn: db connection
events_and_contexts: events we are persisting
all_events_and_contexts: all events that we were going to persist.
This includes events we've already persisted, etc, that wouldn't
appear in events_and_context.
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
for backfilled events because backfilled events in the past do
@ -2051,6 +2049,10 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
)
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_references_for_event, (redacted_relates_to,)
)
if rel_type == RelationTypes.REPLACE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_applicable_edit, (redacted_relates_to,)

View file

@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore):
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering"
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
@ -1589,7 +1582,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id: The room ID to query.
Returns:
dict[str:float] of complexity version to complexity.
Map of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)

View file

@ -217,7 +217,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
reserved_users (tuple): reserved users to preserve
reserved_users: reserved users to preserve
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
@ -370,8 +370,8 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
should not appear in the MAU stats).
Args:
txn (cursor):
user_id (str): user to add/update
txn:
user_id: user to add/update
"""
assert (
self._update_on_this_worker
@ -401,7 +401,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
add the user to the monthly active tables
Args:
user_id(str): the user_id to query
user_id: the user_id to query
"""
assert (
self._update_on_this_worker

View file

@ -30,7 +30,6 @@ from typing import (
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@ -85,7 +84,10 @@ def _load_rules(
push_rules = PushRules(ruleslist)
filtered_rules = FilteredPushRules(
push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled
push_rules,
enabled_map,
msc3664_enabled=experimental_config.msc3664_enabled,
msc1767_enabled=experimental_config.msc1767_enabled,
)
return filtered_rules
@ -111,14 +113,14 @@ class PushRulesWorkerStore(
):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn, "push_rules_stream", "stream_id"
)
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,
)
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
db_conn,

View file

@ -27,7 +27,6 @@ from typing import (
)
from synapse.push import PusherConfig, ThrottleParams
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushersStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
)
else:
self._pushers_id_gen = SlavedIdTracker(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
)
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
is_writer=hs.config.worker.worker_app is None,
)
self.db_pool.updates.register_background_update_handler(
"remove_deactivated_pushers",
@ -331,14 +325,11 @@ class PusherWorkerStore(SQLBaseStore):
async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams
) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
{"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
lock=False,
)
async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
@ -595,8 +586,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
device_id: Optional[str] = None,
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@ -615,7 +604,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
"device_id": device_id,
},
desc="add_pusher",
lock=False,
)
user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(

View file

@ -27,7 +27,6 @@ from typing import (
)
from synapse.api.constants import EduTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.receipts:
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
else:
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
self._receipts_id_gen = StreamIdGenerator(
db_conn,
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
)
super().__init__(database, db_conn, hs)

View file

@ -953,7 +953,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Returns user id from threepid
Args:
txn (cursor):
txn:
medium: threepid medium e.g. email
address: threepid address e.g. me@example.com
@ -1283,8 +1283,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Sets an expiration date to the account with the given user ID.
Args:
user_id (str): User ID to set an expiration date for.
use_delta (bool): If set to False, the expiration date for the user will be
user_id: User ID to set an expiration date for.
use_delta: If set to False, the expiration date for the user will be
now + validity period. If set to True, this expiration date will be a
random value in the [now + period - d ; now + period] range, d being a
delta equal to 10% of the validity period.

View file

@ -20,6 +20,7 @@ from typing import (
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
@ -81,8 +82,6 @@ class _RelatedEvent:
event_id: str
# The sender of the related event.
sender: str
topological_ordering: Optional[int]
stream_ordering: int
class RelationsWorkerStore(SQLBaseStore):
@ -245,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore):
txn.execute(sql, where_args + [limit + 1])
events = []
for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
topo_orderings: List[int] = []
stream_orderings: List[int] = []
for event_id, relation_type, sender, topo_ordering, stream_ordering in cast(
List[Tuple[str, str, str, int, int]], txn
):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or relation_type != RelationTypes.REPLACE:
events.append(
_RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
)
events.append(_RelatedEvent(event_id, sender))
topo_orderings.append(topo_ordering)
stream_orderings.append(stream_ordering)
# If there are more events, generate the next pagination key from the
# last event returned.
@ -260,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore):
# Instead of using the last row (which tells us there is more
# data), use the last row to be returned.
events = events[:limit]
topo_orderings = topo_orderings[:limit]
stream_orderings = stream_orderings[:limit]
topo = events[-1].topological_ordering
token = events[-1].stream_ordering
topo = topo_orderings[-1]
token = stream_orderings[-1]
if direction == "b":
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
@ -394,111 +399,195 @@ class RelationsWorkerStore(SQLBaseStore):
)
return result is not None
@cached(tree=True)
async def get_aggregation_groups_for_event(
self, event_id: str, room_id: str, limit: int = 5
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
@cached()
async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()
@cachedList(
cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
)
async def get_aggregation_groups_for_events(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[JsonDict]]]:
"""Get a list of annotations on the given events, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get which reactions have happened on an event, and
how many of each.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
event_ids: Fetch events that relate to these event IDs.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
A map of event IDs to a list of groups of annotations that match.
Each entry is a dict with `type`, `key` and `count` fields.
"""
# The number of entries to return per event ID.
limit = 5
clause, args = make_in_list_sql_clause(
self.database_engine, "relates_to_id", event_ids
)
args.append(RelationTypes.ANNOTATION)
sql = f"""
SELECT
relates_to_id,
annotation.type,
aggregation_key,
COUNT(DISTINCT annotation.sender)
FROM events AS annotation
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = annotation.room_id
WHERE
{clause}
AND relation_type = ?
GROUP BY relates_to_id, annotation.type, aggregation_key
ORDER BY relates_to_id, COUNT(*) DESC
"""
args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
limit,
]
sql = """
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
GROUP BY relation_type, type, aggregation_key
ORDER BY COUNT(*) DESC
LIMIT ?
"""
def _get_aggregation_groups_for_event_txn(
def _get_aggregation_groups_for_events_txn(
txn: LoggingTransaction,
) -> List[JsonDict]:
) -> Mapping[str, List[JsonDict]]:
txn.execute(sql, args)
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
result: Dict[str, List[JsonDict]] = {}
for event_id, type, key, count in cast(
List[Tuple[str, str, str, int]], txn
):
event_results = result.setdefault(event_id, [])
# Limit the number of results per event ID.
if len(event_results) == limit:
continue
event_results.append({"type": type, "key": key, "count": count})
return result
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
"get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
)
async def get_aggregation_groups_for_users(
self,
event_id: str,
room_id: str,
limit: int,
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
self, event_ids: Collection[str], users: FrozenSet[str]
) -> Dict[str, Dict[Tuple[str, str], int]]:
"""Fetch the partial aggregations for the given events for specific users.
This is used, in conjunction with get_aggregation_groups_for_event, to
remove information from the results for ignored users.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
event_ids: Fetch events that relate to these event IDs.
users: The users to fetch information for.
Returns:
A map of (event type, aggregation key) to a count of users.
A map of event ID to a map of (event type, aggregation key) to a
count of users.
"""
if not users:
return {}
args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]
events_sql, args = make_in_list_sql_clause(
self.database_engine, "relates_to_id", event_ids
)
users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "sender", users
self.database_engine, "annotation.sender", users
)
args.extend(users_args)
args.append(RelationTypes.ANNOTATION)
sql = f"""
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
ORDER BY COUNT(*) DESC
LIMIT ?
SELECT
relates_to_id,
annotation.type,
aggregation_key,
COUNT(DISTINCT annotation.sender)
FROM events AS annotation
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = annotation.room_id
WHERE {events_sql} AND {users_sql} AND relation_type = ?
GROUP BY relates_to_id, annotation.type, aggregation_key
ORDER BY relates_to_id, COUNT(*) DESC
"""
def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]:
txn.execute(sql, args + [limit])
) -> Dict[str, Dict[Tuple[str, str], int]]:
txn.execute(sql, args)
return {(row[0], row[1]): row[2] for row in txn}
result: Dict[str, Dict[Tuple[str, str], int]] = {}
for event_id, type, key, count in cast(
List[Tuple[str, str, str, int]], txn
):
result.setdefault(event_id, {})[(type, key)] = count
return result
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
# Per-event cache entry point. The body unconditionally raises
# NotImplementedError; the batched `get_references_for_events` method names
# this method in its `cachedList(cached_method_name=...)` decorator.
# NOTE(review): the return annotation here is `List[JsonDict]` while the
# batched method yields lists of `_RelatedEvent` - confirm which is intended.
@cached()
async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()
# Batched lookup of reference relations: a single query covers every ID in
# `event_ids`. The decorator links this method to the per-event
# `get_references_for_event` cache via `cached_method_name`.
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
async def get_references_for_events(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[_RelatedEvent]]]:
"""Get a list of references to the given events.
Args:
event_ids: Fetch events that relate to these event IDs.
Returns:
A map of event IDs to a list of related event IDs (and their senders).
"""
# Build the SQL filter on `relates_to_id` for the requested IDs, together
# with its bind parameters.
clause, args = make_in_list_sql_clause(
self.database_engine, "relates_to_id", event_ids
)
# The query's trailing `?` (relation_type) comes after `{clause}`, so its
# value must be appended after the clause's own parameters.
args.append(RelationTypes.REFERENCE)
# `ref` is the referencing event and `parent` is the event being referenced.
# The join on room_id only keeps references whose parent event is in the same
# room as the referencing event. Rows are ordered by the referencing event's
# topological ordering, then stream ordering.
sql = f"""
SELECT relates_to_id, ref.event_id, ref.sender
FROM events AS ref
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = ref.room_id
WHERE
{clause}
AND relation_type = ?
ORDER BY ref.topological_ordering, ref.stream_ordering
"""
# Callback handed to `runInteraction`: executes the query and groups the rows
# by the event they relate to.
def _get_references_for_events_txn(
txn: LoggingTransaction,
) -> Mapping[str, List[_RelatedEvent]]:
txn.execute(sql, args)
result: Dict[str, List[_RelatedEvent]] = {}
# The cast only tells the type checker the row shape
# (relates_to_id, event_id, sender); assuming `cast` is `typing.cast`, it has
# no runtime effect.
for relates_to_id, event_id, sender in cast(
List[Tuple[str, str, str]], txn
):
result.setdefault(relates_to_id, []).append(
_RelatedEvent(event_id, sender)
)
# Requested IDs with no matching rows never get a key in `result`.
# NOTE(review): presumably the `cachedList` wrapper accounts for the missing
# keys, given the `Optional` in the outer return type - confirm.
return result
# The first argument is the description string given to `runInteraction`.
return await self.db_pool.runInteraction(
"_get_references_for_events_txn", _get_references_for_events_txn
)
# The body unconditionally raises NotImplementedError.
# NOTE(review): presumably a cache placeholder filled by a batched lookup
# elsewhere in the class, like the other `@cached()` stubs here - confirm.
@cached()
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()

View file

@ -912,7 +912,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
info = content.get("info")
if isinstance(info, dict):
thumbnail_url = info.get("thumbnail_url")
else:
thumbnail_url = None
for url in (content_url, thumbnail_url):
if not url:
@ -1843,9 +1847,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
"creator": room_creator,
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
lock=False,
)
async def store_partial_state_room(
@ -1966,9 +1967,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
lock=False,
)
async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
@ -2057,7 +2055,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
report_id: ID of reported event in database
Returns:
event_report: json list of information from event report
JSON dict of information from an event report or None if the
report does not exist.
"""
def _get_event_report_txn(
@ -2130,8 +2129,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
Returns:
event_reports: json list of event reports
count: total number of event reports matching the filter criteria
Tuple of:
json list of event reports
total number of event reports matching the filter criteria
"""
def _get_event_reports_paginate_txn(

View file

@ -44,6 +44,4 @@ class RoomBatchStore(SQLBaseStore):
table="event_to_state_groups",
keyvalues={"event_id": event_id},
values={"state_group": state_group_id, "event_id": event_id},
# Unique constraint on event_id so we don't have to lock
lock=False,
)

View file

@ -185,9 +185,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
- who should be in the user_directory.
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
progress
batch_size: Maximum number of state events to process per cycle.
Returns:
number of events processed.
@ -482,7 +481,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
table="user_directory",
keyvalues={"user_id": user_id},
values={"display_name": display_name, "avatar_url": avatar_url},
lock=False, # We're only inserter
)
if isinstance(self.database_engine, PostgresEngine):
@ -512,7 +510,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
table="user_directory_search",
keyvalues={"user_id": user_id},
values={"value": value},
lock=False, # We're only inserter
)
else:
# This should be unreachable.
@ -708,10 +705,10 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns the rooms that a user is in.
Args:
user_id(str): Must be a local user
user_id: Must be a local user
Returns:
list: user_id
List of room IDs
"""
rows = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",

View file

@ -93,13 +93,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
@ -110,31 +103,91 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
WITH RECURSIVE sgs(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
SELECT prev_state_group FROM state_group_edges e, sgs s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id
FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
) %s
ORDER BY type, state_key, state_group DESC
%s
"""
overall_select_query_args: List[Union[int, str]] = []
# As an optimization, create a select clause per-condition. This lets the
# query planner be much smarter about which rows to pull out in the first
# place, and we end up with something that takes 10x less time to get a
# result.
use_condition_optimization = (
not state_filter.include_others and not state_filter.is_full()
)
state_filter_condition_combos: List[Tuple[str, Optional[str]]] = []
# We don't need to calculate this list if we're not using the condition
# optimization
if use_condition_optimization:
for etype, state_keys in state_filter.types.items():
if state_keys is None:
state_filter_condition_combos.append((etype, None))
else:
for state_key in state_keys:
state_filter_condition_combos.append((etype, state_key))
# And here is the optimization itself. We don't want to do the optimization
# if there are too many individual conditions. 10 is an arbitrary number
# with no testing behind it but we do know that we specifically made this
# optimization for when we grab the necessary state out for
# `filter_events_for_client` which just uses 2 conditions
# (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`).
if use_condition_optimization and len(state_filter_condition_combos) < 10:
select_clause_list: List[str] = []
for etype, skey in state_filter_condition_combos:
if skey is None:
where_clause = "(type = ?)"
overall_select_query_args.extend([etype])
else:
where_clause = "(type = ? AND state_key = ?)"
overall_select_query_args.extend([etype, skey])
select_clause_list.append(
f"""
(
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id
FROM state_groups_state
INNER JOIN sgs USING (state_group)
WHERE {where_clause}
ORDER BY type, state_key, state_group DESC
)
"""
)
overall_select_clause = " UNION ".join(select_clause_list)
else:
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
overall_select_query_args.extend(where_args)
overall_select_clause = f"""
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id
FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM sgs
) {where_clause}
ORDER BY type, state_key, state_group DESC
"""
for group in groups:
args: List[Union[int, str]] = [group]
args.extend(where_args)
args.extend(overall_select_query_args)
txn.execute(sql % (where_clause,), args)
txn.execute(sql % (overall_select_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
@ -142,6 +195,12 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
max_entries_returned = state_filter.max_entries_returned()
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:

View file

@ -0,0 +1,53 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Prior to this schema delta, we tracked the set of unconverted rows in
-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows
-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag
-- would be set.
--
-- After this schema delta, the `converted_to_destinations` is still populated like
-- before, but the set of unconverted rows is determined by the `stream_id` in the new
-- `device_lists_changes_converted_stream_position` table.
--
-- If rolled back, Synapse will re-send all device list changes that happened since the
-- schema delta.
CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
-- The (stream id, room id) of the last row in `device_lists_changes_in_room` that
-- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger
-- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been
-- converted.
stream_id BIGINT NOT NULL,
-- `room_id` may be an empty string, which compares less than all valid room IDs.
room_id TEXT NOT NULL,
CHECK (Lock='X')
);
INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES (
(
SELECT COALESCE(
-- The last converted stream id is the smallest unconverted stream id minus
-- one.
MIN(stream_id) - 1,
-- If there is no unconverted stream id, the last converted stream id is the
-- largest stream id.
-- If the table is empty (so there is no largest stream id either), pick 1,
-- since stream ids start at 2.
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room)
) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations
),
''
);

View file

@ -0,0 +1,20 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Adds an index on `device_lists_changes_in_room (room_id, stream_id)`, which
-- speeds up `/sync` queries.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7313, 'device_lists_changes_in_room_by_room_index', '{}');

View file

@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
step: int = 1,
is_writer: bool = True,
) -> None:
assert step != 0
self._lock = threading.Lock()
self._step: int = step
self._current: int = _load_current_id(db_conn, table, column, step)
self._is_writer = is_writer
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def advance(self, instance_name: str, new_id: int) -> None:
# `StreamIdGenerator` should only be used when there is a single writer,
# so replication should never happen.
raise Exception("Replication is not supported by StreamIdGenerator")
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
raise Exception("Replication is not supported by writer StreamIdGenerator")
self._current = (max if self._step > 0 else min)(self._current, new_id)
def get_next(self) -> AsyncContextManager[int]:
with self._lock:
@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
if not self._is_writer:
return self._current
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step