Merge remote-tracking branch 'upstream/release-v1.48'

This commit is contained in:
Tulir Asokan 2021-11-25 18:33:37 +02:00
commit 9f4fa40b64
175 changed files with 6413 additions and 1993 deletions

View file

@ -82,7 +82,7 @@ class BackgroundUpdater:
process and autotuning the batch size.
"""
MINIMUM_BACKGROUND_BATCH_SIZE = 100
MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
@ -122,6 +122,8 @@ class BackgroundUpdater:
def start_doing_background_updates(self) -> None:
if self.enabled:
# if we start a new background update, not all updates are done.
self._all_done = False
run_as_background_process("background_updates", self.run_background_updates)
async def run_background_updates(self, sleep: bool = True) -> None:

View file

@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
@ -235,7 +235,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@ -247,7 +247,7 @@ class LoggingTransaction:
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(
self, callback: Callable[..., None], *args: Any, **kwargs: Any
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that

View file

@ -123,9 +123,9 @@ class DataStore(
RelationsStore,
CensorEventsStore,
UIAuthStore,
EventForwardExtremitiesStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
EventForwardExtremitiesStore,
LockStore,
SessionStore,
):
@ -154,6 +154,7 @@ class DataStore(
db_conn, "local_group_updates", "stream_id"
)
self._cache_id_gen: Optional[MultiWriterIdGenerator]
if isinstance(self.database_engine, PostgresEngine):
# We set the `writers` to an empty list here as we don't care about
# missing updates over restarts, as we'll not have anything in our

View file

@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: Optional[int]
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if type not in ("read_receipt", "presence"):
if stream_type not in ("read_receipt", "presence"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
% (stream_type,)
)
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
% stream_id_type,

View file

@ -13,12 +13,12 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self):
async def _censor_redactions(self) -> None:
"""Censors all redactions older than the configured period that haven't
been censored yet.
@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = json_encoder.encode(
pruned_json: Optional[str] = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
updates.append((redaction_id, event_id, pruned_json))
def _update_censor_txn(txn):
def _update_censor_txn(txn: LoggingTransaction) -> None:
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
def _censor_event_txn(
self, txn: LoggingTransaction, event_id: str, pruned_json: str
) -> None:
"""Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON.
Args:
txn (LoggingTransaction): The database transaction.
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
txn: The database transaction.
event_id: The ID of the event to censor.
pruned_json: The pruned JSON
"""
self.db_pool.simple_update_one_txn(
txn,
@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# Try to retrieve the event's content from the database or the event cache.
event = await self.get_event(event_id)
def delete_expired_event_txn(txn):
def delete_expired_event_txn(txn: LoggingTransaction) -> None:
# Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id)
@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"delete_expired_event", delete_expired_event_txn
)
def _delete_event_expiry_txn(self, txn, event_id):
def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
"""Delete the expiry timestamp associated with an event ID without deleting the
actual event.
Args:
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
txn: The transaction to use to perform the deletion.
event_id: The event ID to delete the associated expiry timestamp of.
"""
return self.db_pool.simple_delete_txn(
self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)

View file

@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
self._last_device_delete_cache: ExpiringCache[
Tuple[str, Optional[str]], int
] = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._instance_name in hs.config.worker.writers.to_device
)
self._device_inbox_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
self._device_inbox_id_gen: AbstractStreamIdGenerator = (
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
)
else:
self._can_write_to_device = True
@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
@ -134,7 +154,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
limit: The maximum number of messages to retrieve.
Returns:
A list of messages for the device and where in the stream the messages got to.
A tuple containing:
* A list of messages for the device.
* The max stream token of these messages. There may be more to retrieve
if the given limit was reached.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
@ -153,12 +176,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
)
messages = []
stream_pos = current_stream_id
for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
# If the limit was not reached we know that there's no more data for this
# user/device pair up to current_stream_id.
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
return await self.db_pool.runInteraction(
@ -210,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": f"deleted {count} messages for device", "count": count})
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
updated_last_deleted_stream_id, up_to_stream_id
)
return count
@ -260,13 +290,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
messages = []
stream_pos = current_stream_id
for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
# If the limit was not reached we know that there's no more data for this
# user/device pair up to current_stream_id.
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
return messages, stream_pos
return await self.db_pool.runInteraction(
@ -372,8 +409,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"""Used to send messages from this server.
Args:
local_messages_by_user_and_device:
Dictionary of user_id to device_id to message.
local_messages_by_user_then_device:
Dictionary of recipient user_id to recipient device_id to message.
remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.
@ -415,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
@ -466,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
@ -562,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@ -577,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
self.db_pool.updates.register_background_update_handler(
self.REMOVE_DELETED_DEVICES,
self._remove_deleted_devices_from_device_inbox,
# Used to be a background update that deletes all device_inboxes for deleted
# devices.
self.db_pool.updates.register_noop_background_update(
self.REMOVE_DELETED_DEVICES
)
# Used to be a background update that deletes all device_inboxes for hidden
# devices.
self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
self.db_pool.updates.register_background_update_handler(
self.REMOVE_HIDDEN_DEVICES,
self._remove_hidden_devices_from_device_inbox,
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
self._remove_dead_devices_from_device_inbox,
)
async def _background_drop_index_device_inbox(self, progress, batch_size):
@ -599,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
async def _remove_deleted_devices_from_device_inbox(
self, progress: JsonDict, batch_size: int
async def _remove_dead_devices_from_device_inbox(
self,
progress: JsonDict,
batch_size: int,
) -> int:
"""A background update that deletes all device_inboxes for deleted devices.
This should only need to be run once (when users upgrade to v1.47.0)
"""A background update to remove devices that were either deleted or hidden from
the device_inbox table.
Args:
progress: JsonDict used to store progress of this background update
batch_size: the maximum number of rows to retrieve in a single select query
progress: The update's progress dict.
batch_size: The batch size for this update.
Returns:
The number of deleted rows
The number of rows deleted.
"""
def _remove_deleted_devices_from_device_inbox_txn(
def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
) -> int:
"""stream_id is not unique
we need to use an inclusive `stream_id >= ?` clause,
since we might not have deleted all dead device messages for the stream_id
returned from the previous query
) -> Tuple[int, bool]:
Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
to avoid problems of deleting a large number of rows all at once
due to a single device having lots of device messages.
"""
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
txn.execute("SELECT max(stream_id) FROM device_inbox")
# There's a type mismatch here between how we want to type the row and
# what fetchone says it returns, but we silence it because we know that
# res can't be None.
res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
if res[0] is None:
# this can only happen if the `device_inbox` table is empty, in which
# case we have no work to do.
return 0, True
else:
max_stream_id = res[0]
last_stream_id = progress.get("stream_id", 0)
start = progress.get("stream_id", 0)
stop = start + batch_size
# delete rows in `device_inbox` which do *not* correspond to a known,
# unhidden device.
sql = """
SELECT device_id, user_id, stream_id
FROM device_inbox
DELETE FROM device_inbox
WHERE
stream_id >= ?
AND (device_id, user_id) NOT IN (
SELECT device_id, user_id FROM devices
stream_id >= ? AND stream_id < ?
AND NOT EXISTS (
SELECT * FROM devices d
WHERE
d.device_id=device_inbox.device_id
AND d.user_id=device_inbox.user_id
AND NOT hidden
)
ORDER BY stream_id
LIMIT ?
"""
"""
txn.execute(sql, (last_stream_id, batch_size))
rows = txn.fetchall()
txn.execute(sql, (start, stop))
num_deleted = 0
for row in rows:
num_deleted += self.db_pool.simple_delete_txn(
txn,
"device_inbox",
{"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
)
if rows:
# send more than stream_id to progress
# otherwise it can happen in large deployments that
# no change of status is visible in the log file
# it may be that the stream_id does not change in several runs
self.db_pool.updates._background_update_progress_txn(
txn,
self.REMOVE_DELETED_DEVICES,
{
"device_id": rows[-1][0],
"user_id": rows[-1][1],
"stream_id": rows[-1][2],
},
)
return num_deleted
number_deleted = await self.db_pool.runInteraction(
"_remove_deleted_devices_from_device_inbox",
_remove_deleted_devices_from_device_inbox_txn,
)
# The task is finished when no more lines are deleted.
if not number_deleted:
await self.db_pool.updates._end_background_update(
self.REMOVE_DELETED_DEVICES
self.db_pool.updates._background_update_progress_txn(
txn,
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
{
"stream_id": stop,
"max_stream_id": max_stream_id,
},
)
return number_deleted
return stop > max_stream_id
async def _remove_hidden_devices_from_device_inbox(
self, progress: JsonDict, batch_size: int
) -> int:
"""A background update that deletes all device_inboxes for hidden devices.
This should only need to be run once (when users upgrade to v1.47.0)
Args:
progress: JsonDict used to store progress of this background update
batch_size: the maximum number of rows to retrieve in a single select query
Returns:
The number of deleted rows
"""
def _remove_hidden_devices_from_device_inbox_txn(
txn: LoggingTransaction,
) -> int:
"""stream_id is not unique
we need to use an inclusive `stream_id >= ?` clause,
since we might not have deleted all hidden device messages for the stream_id
returned from the previous query
Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
to avoid problems of deleting a large number of rows all at once
due to a single device having lots of device messages.
"""
last_stream_id = progress.get("stream_id", 0)
sql = """
SELECT device_id, user_id, stream_id
FROM device_inbox
WHERE
stream_id >= ?
AND (device_id, user_id) IN (
SELECT device_id, user_id FROM devices WHERE hidden = ?
)
ORDER BY stream_id
LIMIT ?
"""
txn.execute(sql, (last_stream_id, True, batch_size))
rows = txn.fetchall()
num_deleted = 0
for row in rows:
num_deleted += self.db_pool.simple_delete_txn(
txn,
"device_inbox",
{"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
)
if rows:
# We don't just save the `stream_id` in progress as
# otherwise it can happen in large deployments that
# no change of status is visible in the log file, as
# it may be that the stream_id does not change in several runs
self.db_pool.updates._background_update_progress_txn(
txn,
self.REMOVE_HIDDEN_DEVICES,
{
"device_id": rows[-1][0],
"user_id": rows[-1][1],
"stream_id": rows[-1][2],
},
)
return num_deleted
number_deleted = await self.db_pool.runInteraction(
"_remove_hidden_devices_from_device_inbox",
_remove_hidden_devices_from_device_inbox_txn,
finished = await self.db_pool.runInteraction(
"_remove_devices_from_device_inbox_txn",
_remove_dead_devices_from_device_inbox_txn,
)
# The task is finished when no more lines are deleted.
if not number_deleted:
if finished:
await self.db_pool.updates._end_background_update(
self.REMOVE_HIDDEN_DEVICES
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
)
return number_deleted
return batch_size
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):

View file

@ -13,17 +13,18 @@
# limitations under the License.
from collections import namedtuple
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
class DirectoryWorkerStore(CacheInvalidationWorkerStore):
async def get_association_from_room_alias(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
@ -91,7 +92,7 @@ class DirectoryWorkerStore(SQLBaseStore):
creator: Optional user_id of creator.
"""
def alias_txn(txn):
def alias_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
"room_aliases",
@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]:
room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
def _delete_room_alias_txn(
self, txn: LoggingTransaction, room_alias: RoomAlias
) -> Optional[str]:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
@ -173,9 +176,9 @@ class DirectoryStore(DirectoryWorkerStore):
If None, the creator will be left unchanged.
"""
def _update_aliases_for_room_txn(txn):
def _update_aliases_for_room_txn(txn: LoggingTransaction) -> None:
update_creator_sql = ""
sql_params = (new_room_id, old_room_id)
sql_params: Tuple[str, ...] = (new_room_id, old_room_id)
if creator:
update_creator_sql = ", creator = ?"
sql_params = (new_room_id, creator, old_room_id)

View file

@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
await self.db_pool.runInteraction(
"set_e2e_fallback_keys_txn",
self._set_e2e_fallback_keys_txn,
user_id,
device_id,
fallback_keys,
)
await self.invalidate_cache_and_stream(
"get_e2e_unused_fallback_key_types", (user_id, device_id)
)
def _set_e2e_fallback_keys_txn(
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
await self.db_pool.simple_upsert(
"e2e_fallback_keys_json",
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
desc="set_e2e_fallback_key",
retcol="key_json",
allow_none=True,
)
await self.invalidate_cache_and_stream(
"get_e2e_unused_fallback_key_types", (user_id, device_id)
)
new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
# If the uploaded key is the same as the current fallback key,
# don't do anything. This prevents marking the key as unused if it
# was already used.
if old_key_json != new_key_json:
self.db_pool.simple_upsert_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
)
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(

View file

@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1641,8 +1641,8 @@ class PersistEventsStore:
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database."""
def str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) else None
def non_null_str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) and "\u0000" not in val else None
self.db_pool.simple_insert_many_txn(
txn,
@ -1654,8 +1654,10 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": str_or_none(event.content.get("displayname")),
"avatar_url": str_or_none(event.content.get("avatar_url")),
"display_name": non_null_str_or_none(
event.content.get("displayname")
),
"avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
}
for event in events
],
@ -1694,34 +1696,33 @@ class PersistEventsStore:
},
)
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
def _handle_event_relations(
self, txn: LoggingTransaction, event: EventBase
) -> None:
"""Handles inserting relation data during persistence of events
Args:
txn
event (EventBase)
txn: The current database transaction.
event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if rel_type not in (
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
RelationTypes.THREAD,
):
# Unknown relation type
if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
if not parent_id:
# Invalid relation
if not isinstance(parent_id, str):
return
aggregation_key = relation.get("key")
# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn(
txn,

View file

@ -1,4 +1,4 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
# The event_thread_relation background update was replaced with the
# event_arbitrary_relations one, which handles any relation to avoid
# needed to potentially crawl the entire events table in the future.
self.db_pool.updates.register_noop_background_update("event_thread_relation")
self.db_pool.updates.register_background_update_handler(
"event_thread_relation", self._event_thread_relation
"event_arbitrary_relations",
self._event_arbitrary_relations,
)
################################################################################
@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store thread relations for existing events."""
async def _event_arbitrary_relations(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "")
def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
# Fetch events and then filter based on whether the event has a
# relation or not.
txn.execute(
"""
SELECT event_id, json FROM event_json
LEFT JOIN event_relations USING (event_id)
WHERE event_id > ? AND event_relations.event_id IS NULL
WHERE event_id > ?
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
missing_thread_relations = []
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
continue
# If there's no relation (or it is not a thread), skip!
# If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
if relates_to.get("rel_type") != RelationTypes.THREAD:
# If the relation type or parent event ID is not a string, skip it.
#
# Do not consider relation types that have existed for a long time,
# since they will already be listed in the `event_relations` table.
rel_type = relates_to.get("rel_type")
if not isinstance(rel_type, str) or rel_type in (
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
):
continue
# Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
missing_thread_relations.append((event_id, parent_id))
relations_to_insert.append((event_id, parent_id, rel_type))
# Insert the missing data.
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_relations",
values=[
{
"event_id": event_id,
"relates_to_Id": parent_id,
"relation_type": RelationTypes.THREAD,
}
for event_id, parent_id in missing_thread_relations
],
)
# Insert the missing data, note that we upsert here in case the event
# has already been processed.
if relations_to_insert:
self.db_pool.simple_upsert_many_txn(
txn=txn,
table="event_relations",
key_names=("event_id",),
key_values=[(r[0],) for r in relations_to_insert],
value_names=("relates_to_id", "relation_type"),
value_values=[r[1:] for r in relations_to_insert],
)
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream(
txn, self.get_relations_for_event, cache_tuple
)
self._invalidate_cache_and_stream(
txn, self.get_aggregation_groups_for_event, cache_tuple
)
self._invalidate_cache_and_stream(
txn, self.get_thread_summary, cache_tuple
)
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
txn, "event_thread_relation", {"last_event_id": latest_event_id}
txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
desc="event_thread_relation", func=_event_thread_relation_txn
desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
)
if not num_rows:
await self.db_pool.updates._end_background_update("event_thread_relation")
await self.db_pool.updates._end_background_update(
"event_arbitrary_relations"
)
return num_rows

View file

@ -13,15 +13,20 @@
# limitations under the License.
import logging
from typing import Dict, List
from typing import Any, Dict, List
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
logger = logging.getLogger(__name__)
class EventForwardExtremitiesStore(SQLBaseStore):
class EventForwardExtremitiesStore(
EventFederationWorkerStore,
CacheInvalidationWorkerStore,
):
async def delete_forward_extremities_for_room(self, room_id: str) -> int:
"""Delete any extra forward extremities for a room.
@ -31,7 +36,7 @@ class EventForwardExtremitiesStore(SQLBaseStore):
Returns count deleted.
"""
def delete_forward_extremities_for_room_txn(txn):
def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:
# First we need to get the event_id to not delete
sql = """
SELECT event_id FROM event_forward_extremities
@ -82,10 +87,14 @@ class EventForwardExtremitiesStore(SQLBaseStore):
delete_forward_extremities_for_room_txn,
)
async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
def get_forward_extremities_for_room_txn(txn):
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore):
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
def _do_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore):
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
max_id = txn.fetchone()[0] # type: ignore[index]
if max_id is None:
filter_id = 0
else:

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@ -62,7 +62,9 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
self._live_tokens: WeakValueDictionary[
Tuple[str, str], Lock
] = WeakValueDictionary()
# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.

View file

@ -13,10 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -46,7 +61,12 @@ class MediaSortOrder(Enum):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
self._drop_media_index_without_method,
)
async def _drop_media_index_without_method(self, progress, batch_size):
async def _drop_media_index_without_method(
self, progress: JsonDict, batch_size: int
) -> int:
"""background update handler which removes the old constraints.
Note that this is only run on postgres.
"""
def f(txn):
def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
)
@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
plus the total count of all the user's media
"""
def get_local_media_by_user_paginate_txn(txn):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
# Set ordering
order_by_column = MediaSortOrder(order_by).value
@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
else:
order = "ASC"
args = [user_id]
args: List[Union[str, int]] = [user_id]
sql = """
SELECT COUNT(*) as total_media
FROM local_media_repository
WHERE user_id = ?
"""
txn.execute(sql, args)
count = txn.fetchone()[0]
count = txn.fetchone()[0] # type: ignore[index]
sql = """
SELECT
@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
sql += sql_keep
def _get_local_media_before_txn(txn):
def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]
@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_local_media(
self,
media_id,
media_type,
time_now_ms,
upload_name,
media_length,
user_id,
url_cache=None,
media_id: str,
media_type: str,
time_now_ms: int,
upload_name: Optional[str],
media_length: int,
user_id: UserID,
url_cache: Optional[str] = None,
) -> None:
await self.db_pool.simple_insert(
"local_media_repository",
@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
None if the URL isn't cached.
"""
def get_url_cache_txn(txn):
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
) -> None:
await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
media_id: str,
thumbnail_width: int,
thumbnail_height: int,
thumbnail_type: str,
thumbnail_method: str,
thumbnail_length: int,
) -> None:
await self.db_pool.simple_upsert(
table="local_media_repository_thumbnails",
keyvalues={
@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_cached_remote_media(
self,
origin,
media_id,
media_type,
media_length,
time_now_ms,
upload_name,
filesystem_id,
):
origin: str,
media_id: str,
media_type: str,
media_length: int,
time_now_ms: int,
upload_name: Optional[str],
filesystem_id: str,
) -> None:
await self.db_pool.simple_insert(
"remote_media_cache",
{
@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
local_media: Iterable[str],
remote_media: Iterable[Tuple[str, str]],
time_ms: int,
):
) -> None:
"""Updates the last access time of the given media
Args:
@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
time_ms: Current time in milliseconds
"""
def update_cache_txn(txn):
def update_cache_txn(txn: LoggingTransaction) -> None:
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?"
@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def store_remote_media_thumbnail(
self,
origin,
media_id,
filesystem_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
origin: str,
media_id: str,
filesystem_id: str,
thumbnail_width: int,
thumbnail_height: int,
thumbnail_type: str,
thumbnail_method: str,
thumbnail_length: int,
) -> None:
await self.db_pool.simple_upsert(
table="remote_media_cache_thumbnails",
keyvalues={
@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
async def get_remote_media_before(self, before_ts):
async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" LIMIT 500"
)
def _get_expired_url_cache_txn(txn):
def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_expired_url_cache", _get_expired_url_cache_txn
)
async def delete_url_cache(self, media_ids):
async def delete_url_cache(self, media_ids: Collection[str]) -> None:
if len(media_ids) == 0:
return
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
)
await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" LIMIT 500"
)
def _get_url_cache_media_before_txn(txn):
def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
async def delete_url_cache_media(self, media_ids):
async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
if len(media_ids) == 0:
return
def _delete_url_cache_media_txn(txn):
def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)

View file

@ -1,6 +1,21 @@
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
class OpenIdStore(SQLBaseStore):
@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore):
async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int
) -> Optional[str]:
def get_user_id_for_token_txn(txn):
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
sql = (
"SELECT user_id FROM open_id_tokens"
" WHERE token = ? AND ? <= ts_valid_until_ms"

View file

@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="update_remote_profile_cache",
)
async def maybe_delete_remote_profile_cache(self, user_id):
async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore):
desc="delete_remote_profile_cache",
)
async def is_subscribed_remote_profile_for_user(self, user_id):
async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
"""Check whether we are interested in a remote user's profile."""
res = await self.db_pool.simple_select_one_onecol(
res: Optional[str] = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore):
if res:
return True
return False
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
def _get_remote_profile_cache_entries_that_expire_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
sql = """
SELECT user_id, displayname, avatar_url
FROM remote_profile_cache

View file

@ -84,26 +84,26 @@ class TokenLookupResult:
return self.user_id
@attr.s(frozen=True, slots=True)
@attr.s(auto_attribs=True, frozen=True, slots=True)
class RefreshTokenLookupResult:
"""Result of looking up a refresh token."""
user_id = attr.ib(type=str)
user_id: str
"""The user this token belongs to."""
device_id = attr.ib(type=str)
device_id: str
"""The device associated with this refresh token."""
token_id = attr.ib(type=int)
token_id: int
"""The ID of this refresh token."""
next_token_id = attr.ib(type=Optional[int])
next_token_id: Optional[int]
"""The ID of the refresh token which replaced this one."""
has_next_refresh_token_been_refreshed = attr.ib(type=bool)
has_next_refresh_token_been_refreshed: bool
"""True if the next refresh token was used for another refresh."""
has_next_access_token_been_used = attr.ib(type=bool)
has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
@ -476,7 +476,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
shadow_banned: true iff the user is to be shadow-banned, false otherwise.
"""
def set_shadow_banned_txn(txn):
def set_shadow_banned_txn(txn: LoggingTransaction) -> None:
user_id = user.to_string()
self.db_pool.simple_update_one_txn(
txn,
@ -1198,8 +1198,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
assert self._account_validity_startup_job_max_delta is not None
expiration_ts = random.randrange(
expiration_ts - self._account_validity_startup_job_max_delta,
int(expiration_ts - self._account_validity_startup_job_max_delta),
expiration_ts,
)
@ -1728,11 +1729,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
)
self.db_pool.updates.register_background_update_handler(
"user_threepids_grandfather", self._bg_user_threepids_grandfather
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
self.db_pool.updates.register_background_update_handler(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
self.db_pool.updates.register_noop_background_update(
"user_threepids_grandfather"
)
self.db_pool.updates.register_background_index_update(
@ -1805,35 +1806,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return nb_processed
async def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.registration.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:

View file

@ -20,7 +20,7 @@ import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
async def event_includes_relation(self, event_id: str) -> bool:
"""Check if the given event relates to another event.
An event has a relation if it has a valid m.relates_to with a rel_type
and event_id in the content:
{
"content": {
"m.relates_to": {
"rel_type": "m.replace",
"event_id": "$other_event_id"
}
}
}
Args:
event_id: The event to check.
Returns:
True if the event includes a valid relation.
"""
result = await self.db_pool.simple_select_one_onecol(
table="event_relations",
keyvalues={"event_id": event_id},
retcol="event_id",
allow_none=True,
desc="event_includes_relation",
)
return result is not None
async def event_is_target_of_relation(self, parent_id: str) -> bool:
"""Check if the given event is the target of another event's relation.
An event is the target of an event relation if it has a valid
m.relates_to with a rel_type and event_id pointing to parent_id in the
content:
{
"content": {
"m.relates_to": {
"rel_type": "m.replace",
"event_id": "$parent_id"
}
}
}
Args:
parent_id: The event to check.
Returns:
True if the event is the target of another event's relation.
"""
result = await self.db_pool.simple_select_one_onecol(
table="event_relations",
keyvalues={"relates_to_id": parent_id},
retcol="event_id",
allow_none=True,
desc="event_is_target_of_relation",
)
return result is not None
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
@ -334,6 +397,62 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
async def events_have_relations(
self,
parent_ids: List[str],
relation_senders: Optional[List[str]],
relation_types: Optional[List[str]],
) -> List[str]:
"""Check which events have a relationship from the given senders of the
given types.
Args:
parent_ids: The events being annotated
relation_senders: The relation senders to check.
relation_types: The relation types to check.
Returns:
True if the event has at least one relationship from one of the given senders of the given type.
"""
# If no restrictions are given then the event has the required relations.
if not relation_senders and not relation_types:
return parent_ids
sql = """
SELECT relates_to_id FROM event_relations
INNER JOIN events USING (event_id)
WHERE
%s;
"""
def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
)
clauses.append(clause)
if relation_senders:
clause, temp_args = make_in_list_sql_clause(
txn.database_engine, "sender", relation_senders
)
clauses.append(clause)
args.extend(temp_args)
if relation_types:
clause, temp_args = make_in_list_sql_clause(
txn.database_engine, "relation_type", relation_types
)
clauses.append(clause)
args.extend(temp_args)
txn.execute(sql % " AND ".join(clauses), args)
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
"get_if_events_have_relations", _get_if_events_have_relations
)
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:

View file

@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
"""
Function to retrieve user who has blocked the room.
user_id is non-nullable
It returns None if the room is not blocked.
"""
return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
allow_none=True,
desc="room_is_blocked_by",
)
async def get_rooms_paginate(
self,
start: int,
@ -1751,7 +1765,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked. Can be called multiple times.
"""Marks the room as blocked.
Can be called multiple times (though we'll only track the last user to
block this room).
Can be called on a room unknown to this homeserver.
Args:
room_id: Room to block
@ -1770,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
async def unblock_room(self, room_id: str) -> None:
"""Remove the room from blocking list.
Args:
room_id: Room to unblock
"""
await self.db_pool.simple_delete(
table="blocked_rooms",
keyvalues={"room_id": room_id},
desc="unblock_room",
)
await self.db_pool.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
(room_id,),
)

View file

@ -39,13 +39,11 @@ class RoomBatchStore(SQLBaseStore):
async def store_state_group_id_for_event_id(
self, event_id: str, state_group_id: int
) -> Optional[str]:
{
await self.db_pool.simple_upsert(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
values={"state_group": state_group_id, "event_id": event_id},
# Unique constraint on event_id so we don't have to lock
lock=False,
)
}
) -> None:
await self.db_pool.simple_upsert(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
values={"state_group": state_group_id, "event_id": event_id},
# Unique constraint on event_id so we don't have to lock
lock=False,
)

View file

@ -63,12 +63,12 @@ class SignatureWorkerStore(SQLBaseStore):
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
"""
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
encoded_hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
}
return list(hashes.items())
return list(encoded_hashes.items())
def _get_event_reference_hashes_txn(
self, txn: Cursor, event_id: str

View file

@ -16,11 +16,17 @@ import logging
from typing import Any, Dict, List, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
# This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this?
_curr_state_delta_stream_cache: StreamChangeCache
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
@ -60,7 +66,9 @@ class StateDeltasStore(SQLBaseStore):
# max_stream_id.
return max_stream_id, []
def get_current_state_deltas_txn(txn):
def get_current_state_deltas_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[Dict[str, Any]]]:
# First we calculate the max stream id that will give us less than
# N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't
@ -106,7 +114,9 @@ class StateDeltasStore(SQLBaseStore):
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
def _get_max_stream_id_in_current_state_deltas_txn(
self, txn: LoggingTransaction
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
@ -114,7 +124,7 @@ class StateDeltasStore(SQLBaseStore):
retcol="COALESCE(MAX(stream_id), -1)",
)
async def get_max_stream_id_in_current_state_deltas(self):
async def get_max_stream_id_in_current_state_deltas(self) -> int:
return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,

View file

@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args = []
if event_filter.types:
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
clauses.append(
"(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
)
args.extend(event_filter.types)
for typ in event_filter.not_types:
clauses.append("type != ?")
clauses.append("event.type != ?")
args.append(typ)
if event_filter.senders:
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
clauses.append(
"(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
)
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
clauses.append("sender != ?")
clauses.append("event.sender != ?")
args.append(sender)
if event_filter.rooms:
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
clauses.append(
"(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
)
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
clauses.append("room_id != ?")
clauses.append("event.room_id != ?")
args.append(room_id)
if event_filter.contains_url:
clauses.append("contains_url = ?")
clauses.append("event.contains_url = ?")
args.append(event_filter.contains_url)
# We're only applying the "labels" filter on the database query, because applying the
@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
args.extend(event_filter.labels)
# Filter on relation_senders / relation types from the joined tables.
if event_filter.relation_senders:
clauses.append(
"(%s)"
% " OR ".join(
"related_event.sender = ?" for _ in event_filter.relation_senders
)
)
args.extend(event_filter.relation_senders)
if event_filter.relation_types:
clauses.append(
"(%s)"
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
)
args.extend(event_filter.relation_types)
return " AND ".join(clauses), args
@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
column_names=("event.topological_ordering", "event.stream_ordering"),
from_token=from_bound,
to_token=to_bound,
engine=self.database_engine,
@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords = "SELECT"
join_clause = ""
# Using DISTINCT in this SELECT query is quite expensive, because it
# requires the engine to sort on the entire (not limited) result set,
# i.e. the entire events table. Only use it in scenarios that could result
# in the same event ID occurring multiple times in the results.
needs_distinct = False
if event_filter and event_filter.labels:
# If we're not filtering on a label, then joining on event_labels will
# return as many row for a single event as the number of labels it has. To
# avoid this, only join if we're filtering on at least one label.
join_clause = """
join_clause += """
LEFT JOIN event_labels
USING (event_id, room_id, topological_ordering)
"""
if len(event_filter.labels) > 1:
# Using DISTINCT in this SELECT query is quite expensive, because it
# requires the engine to sort on the entire (not limited) result set,
# i.e. the entire events table. We only need to use it when we're
# filtering on more than two labels, because that's the only scenario
# in which we can possibly to get multiple times the same event ID in
# the results.
select_keywords += "DISTINCT"
# Multiple labels could cause the same event to appear multiple times.
needs_distinct = True
# If there is a filter on relation_senders and relation_types join to the
# relations table.
if event_filter and (
event_filter.relation_senders or event_filter.relation_types
):
# Filtering by relations could cause the same event to appear multiple
# times (since there's no limit on the number of relations to an event).
needs_distinct = True
join_clause += """
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
"""
if event_filter.relation_senders:
join_clause += """
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
if needs_distinct:
select_keywords += " DISTINCT"
sql = """
%(select_keywords)s
event_id, instance_name,
topological_ordering, stream_ordering
FROM events
event.event_id, event.instance_name,
event.topological_ordering, event.stream_ordering
FROM events AS event
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
ORDER BY topological_ordering %(order)s,
stream_ordering %(order)s LIMIT ?
WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
ORDER BY event.topological_ordering %(order)s,
event.stream_ordering %(order)s LIMIT ?
""" % {
"select_keywords": select_keywords,
"join_clause": join_clause,

View file

@ -1,5 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,9 +15,10 @@
# limitations under the License.
import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, cast
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
"""Get updates for tags replication stream.
Args:
@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
if last_id == current_id:
return [], current_id, False
def get_all_updated_tags_txn(txn):
def get_all_updated_tags_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
sql = (
"SELECT stream_id, user_id, room_id"
" FROM room_tags_revisions as r"
@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
# mypy doesn't understand what the query is selecting.
return cast(List[Tuple[int, str, str]], txn.fetchall())
tag_ids = await self.db_pool.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
def get_tag_content(txn, tag_ids):
def get_tag_content(
txn: LoggingTransaction, tag_ids
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore):
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
user_id: The user to get the tags for.
stream_id: The earliest update to get for the user.
Returns:
A mapping from room_id strings to lists of tag strings for all the
rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:
sql = (
"SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?"
@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="room_tags",
@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"""
assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id):
def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
sql = (
"DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?"

View file

@ -14,11 +14,12 @@
from typing import Dict, Iterable
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore):
class UserErasureWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def is_user_erased(self, user_id: str) -> bool:
"""
@ -69,7 +70,7 @@ class UserErasureStore(UserErasureWorkerStore):
user_id: full user_id to be erased
"""
def f(txn):
def f(txn: LoggingTransaction) -> None:
# first check if they are already in the list
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if txn.fetchone():
@ -89,7 +90,7 @@ class UserErasureStore(UserErasureWorkerStore):
user_id: full user_id to be un-erased
"""
def f(txn):
def f(txn: LoggingTransaction) -> None:
# first check if they are already in the list
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if not txn.fetchone():

View file

@ -45,10 +45,13 @@ Changes in SCHEMA_VERSION = 64:
Changes in SCHEMA_VERSION = 65:
- MSC2716: Remove unique event_id constraint from insertion_event_edges
because an insertion event can have multiple edges.
- Remove unused tables `user_stats_historical` and `room_stats_historical`.
"""
SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata.
SCHEMA_COMPAT_VERSION = (
61 # 61: Remove unused tables `user_stats_historical` and `room_stats_historical`
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
This value is stored in the database, and checked on startup. If the value in the

View file

@ -0,0 +1,19 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Remove unused tables room_stats_historical and user_stats_historical
-- which have not been read or written since schema version 61.
DROP TABLE IF EXISTS room_stats_historical;
DROP TABLE IF EXISTS user_stats_historical;

View file

@ -15,4 +15,4 @@
-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(6502, 'event_thread_relation', '{}');
(6507, 'event_arbitrary_relations', '{}');

View file

@ -0,0 +1,18 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Background update to clear the inboxes of hidden and deleted devices.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(6508, 'remove_dead_devices_from_device_inbox', '{}');

View file

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import heapq
import logging
import threading
@ -87,7 +89,25 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator:
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
raise NotImplementedError()
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed.
@ -209,7 +229,7 @@ class StreamIdGenerator:
return self.get_current_token()
class MultiWriterIdGenerator:
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""An ID generator that tracks a stream that can have multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other