Mirror of https://git.anonymousland.org/anonymousland/synapse.git (synced 2025-08-15 19:30:12 -04:00)

Commit 9f4fa40b64: Merge remote-tracking branch 'upstream/release-v1.48'

175 changed files with 6413 additions and 1993 deletions
@@ -82,7 +82,7 @@ class BackgroundUpdater:
     process and autotuning the batch size.
     """
 
-    MINIMUM_BACKGROUND_BATCH_SIZE = 100
+    MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
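Lowering MINIMUM_BACKGROUND_BATCH_SIZE from 100 to 1 lets the updater shrink its batches much further when the database is struggling. A minimal sketch of the duration-based autotuning this class performs (the helper name and standalone layout are illustrative, not code from this commit):

    MINIMUM_BACKGROUND_BATCH_SIZE = 1
    DEFAULT_BACKGROUND_BATCH_SIZE = 100
    BACKGROUND_UPDATE_DURATION_MS = 100  # target wall-clock time per batch

    def next_batch_size(items_updated: int, duration_ms: float) -> int:
        """Scale the next batch so it should take about the target duration."""
        if duration_ms <= 0 or items_updated <= 0:
            return DEFAULT_BACKGROUND_BATCH_SIZE
        scaled = int(items_updated * BACKGROUND_UPDATE_DURATION_MS / duration_ms)
        # With the minimum at 1, a slow database can push the batch size almost
        # arbitrarily low instead of being pinned at 100 items per run.
        return max(scaled, MINIMUM_BACKGROUND_BATCH_SIZE)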
@@ -122,6 +122,8 @@ class BackgroundUpdater:
 
     def start_doing_background_updates(self) -> None:
         if self.enabled:
+            # if we start a new background update, not all updates are done.
+            self._all_done = False
             run_as_background_process("background_updates", self.run_background_updates)
 
     async def run_background_updates(self, sleep: bool = True) -> None:
@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
 
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
 
 
 R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
+    def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
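The switch from Callable[..., None] to Callable[..., object] is purely about typing: mypy only lets a Callable[..., None] parameter accept functions declared to return None, whereas Callable[..., object] accepts any callable and simply ignores its return value. A small, self-contained illustration (the callback here is hypothetical):

    from typing import Any, Callable, List, Tuple

    after_callbacks: List[Tuple[Callable[..., object], Tuple[Any, ...]]] = []

    def call_after(callback: Callable[..., object], *args: Any) -> None:
        # Whatever the callback returns is discarded when it is eventually run.
        after_callbacks.append((callback, args))

    def invalidate_cache(key: str) -> bool:
        # Returns whether an entry was evicted; a Callable[..., None] parameter
        # would make mypy reject passing this function to call_after().
        return True

    call_after(invalidate_cache, "!room:example.org")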
@@ -247,7 +247,7 @@ class LoggingTransaction:
         self.after_callbacks.append((callback, args, kwargs))
 
     def call_on_exception(
-        self, callback: Callable[..., None], *args: Any, **kwargs: Any
+        self, callback: Callable[..., object], *args: Any, **kwargs: Any
     ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -123,9 +123,9 @@ class DataStore(
     RelationsStore,
     CensorEventsStore,
     UIAuthStore,
+    EventForwardExtremitiesStore,
     CacheInvalidationWorkerStore,
     ServerMetricsStore,
-    EventForwardExtremitiesStore,
     LockStore,
     SessionStore,
 ):
@ -154,6 +154,7 @@ class DataStore(
|
|||
db_conn, "local_group_updates", "stream_id"
|
||||
)
|
||||
|
||||
self._cache_id_gen: Optional[MultiWriterIdGenerator]
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We set the `writers` to an empty list here as we don't care about
|
||||
# missing updates over restarts, as we'll not have anything in our
|
||||
|
|
|
@@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service: ApplicationService, type: str, pos: Optional[int]
+        self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if type not in ("read_receipt", "presence"):
+        if stream_type not in ("read_receipt", "presence"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
-                % (type,)
+                % (stream_type,)
             )
 
         def set_type_stream_id_for_appservice_txn(txn):
-            stream_id_type = "%s_stream_id" % type
+            stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
                 % stream_id_type,
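Renaming the parameter from the builtin-shadowing `type` to `stream_type` also makes the safety argument easier to read: because the value is interpolated into a column name, it must be validated against a fixed allowlist before it reaches the SQL string, while the stream position itself stays a bound parameter. The same pattern in isolation (the helper is illustrative; the table and column names mirror the diff):

    ALLOWED_STREAM_TYPES = ("read_receipt", "presence")

    def build_update_sql(stream_type: str) -> str:
        if stream_type not in ALLOWED_STREAM_TYPES:
            raise ValueError(
                "Expected type to be a valid application stream id type, got %s"
                % (stream_type,)
            )
        # Safe to format: stream_type is one of two known identifiers.
        column = "%s_stream_id" % stream_type
        return "UPDATE application_services_state SET %s = ? WHERE as_id=?" % column

    print(build_update_sql("presence"))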
@ -13,12 +13,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.events.utils import prune_event_dict
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.util import json_encoder
|
||||
|
@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
|
||||
|
||||
@wrap_as_background_process("_censor_redactions")
|
||||
async def _censor_redactions(self):
|
||||
async def _censor_redactions(self) -> None:
|
||||
"""Censors all redactions older than the configured period that haven't
|
||||
been censored yet.
|
||||
|
||||
|
@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
and original_event.internal_metadata.is_redacted()
|
||||
):
|
||||
# Redaction was allowed
|
||||
pruned_json = json_encoder.encode(
|
||||
pruned_json: Optional[str] = json_encoder.encode(
|
||||
prune_event_dict(
|
||||
original_event.room_version, original_event.get_dict()
|
||||
)
|
||||
|
@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
|
||||
updates.append((redaction_id, event_id, pruned_json))
|
||||
|
||||
def _update_censor_txn(txn):
|
||||
def _update_censor_txn(txn: LoggingTransaction) -> None:
|
||||
for redaction_id, event_id, pruned_json in updates:
|
||||
if pruned_json:
|
||||
self._censor_event_txn(txn, event_id, pruned_json)
|
||||
|
@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
|
||||
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
|
||||
|
||||
def _censor_event_txn(self, txn, event_id, pruned_json):
|
||||
def _censor_event_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, pruned_json: str
|
||||
) -> None:
|
||||
"""Censor an event by replacing its JSON in the event_json table with the
|
||||
provided pruned JSON.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The database transaction.
|
||||
event_id (str): The ID of the event to censor.
|
||||
pruned_json (str): The pruned JSON
|
||||
txn: The database transaction.
|
||||
event_id: The ID of the event to censor.
|
||||
pruned_json: The pruned JSON
|
||||
"""
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
# Try to retrieve the event's content from the database or the event cache.
|
||||
event = await self.get_event(event_id)
|
||||
|
||||
def delete_expired_event_txn(txn):
|
||||
def delete_expired_event_txn(txn: LoggingTransaction) -> None:
|
||||
# Delete the expiry timestamp associated with this event from the database.
|
||||
self._delete_event_expiry_txn(txn, event_id)
|
||||
|
||||
|
@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
"delete_expired_event", delete_expired_event_txn
|
||||
)
|
||||
|
||||
def _delete_event_expiry_txn(self, txn, event_id):
|
||||
def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
|
||||
"""Delete the expiry timestamp associated with an event ID without deleting the
|
||||
actual event.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to use to perform the deletion.
|
||||
event_id (str): The event ID to delete the associated expiry timestamp of.
|
||||
txn: The transaction to use to perform the deletion.
|
||||
event_id: The event ID to delete the associated expiry timestamp of.
|
||||
"""
|
||||
return self.db_pool.simple_delete_txn(
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
|
||||
)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
|
|||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.replication.tcp.streams import ToDeviceStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
# Map of (user_id, device_id) to the last stream_id that has been
|
||||
# deleted up to. This is so that we can no op deletions.
|
||||
self._last_device_delete_cache = ExpiringCache(
|
||||
self._last_device_delete_cache: ExpiringCache[
|
||||
Tuple[str, Optional[str]], int
|
||||
] = ExpiringCache(
|
||||
cache_name="last_device_delete_cache",
|
||||
clock=self._clock,
|
||||
max_len=10000,
|
||||
|
@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
self._instance_name in hs.config.worker.writers.to_device
|
||||
)
|
||||
|
||||
self._device_inbox_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
stream_name="to_device",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("device_inbox", "instance_name", "stream_id")],
|
||||
sequence_name="device_inbox_sequence",
|
||||
writers=hs.config.worker.writers.to_device,
|
||||
self._device_inbox_id_gen: AbstractStreamIdGenerator = (
|
||||
MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
stream_name="to_device",
|
||||
instance_name=self._instance_name,
|
||||
tables=[("device_inbox", "instance_name", "stream_id")],
|
||||
sequence_name="device_inbox_sequence",
|
||||
writers=hs.config.worker.writers.to_device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._can_write_to_device = True
|
||||
|
@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == ToDeviceStream.NAME:
|
||||
# If replication is happening than postgres must be being used.
|
||||
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
|
||||
self._device_inbox_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
if row.entity.startswith("@"):
|
||||
|
@ -134,7 +154,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
limit: The maximum number of messages to retrieve.
|
||||
|
||||
Returns:
|
||||
A list of messages for the device and where in the stream the messages got to.
|
||||
A tuple containing:
|
||||
* A list of messages for the device.
|
||||
* The max stream token of these messages. There may be more to retrieve
|
||||
if the given limit was reached.
|
||||
"""
|
||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
||||
user_id, last_stream_id
|
||||
|
@ -153,12 +176,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
txn.execute(
|
||||
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
|
||||
)
|
||||
|
||||
messages = []
|
||||
stream_pos = current_stream_id
|
||||
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(db_to_json(row[1]))
|
||||
|
||||
# If the limit was not reached we know that there's no more data for this
|
||||
# user/device pair up to current_stream_id.
|
||||
if len(messages) < limit:
|
||||
stream_pos = current_stream_id
|
||||
|
||||
return messages, stream_pos
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -210,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
log_kv({"message": f"deleted {count} messages for device", "count": count})
|
||||
|
||||
# Update the cache, ensuring that we only ever increase the value
|
||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
(user_id, device_id), 0
|
||||
)
|
||||
self._last_device_delete_cache[(user_id, device_id)] = max(
|
||||
last_deleted_stream_id, up_to_stream_id
|
||||
updated_last_deleted_stream_id, up_to_stream_id
|
||||
)
|
||||
|
||||
return count
|
||||
|
@ -260,13 +290,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
|
||||
|
||||
messages = []
|
||||
stream_pos = current_stream_id
|
||||
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(db_to_json(row[1]))
|
||||
|
||||
# If the limit was not reached we know that there's no more data for this
|
||||
# user/device pair up to current_stream_id.
|
||||
if len(messages) < limit:
|
||||
log_kv({"message": "Set stream position to current position"})
|
||||
stream_pos = current_stream_id
|
||||
|
||||
return messages, stream_pos
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -372,8 +409,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
"""Used to send messages from this server.
|
||||
|
||||
Args:
|
||||
local_messages_by_user_and_device:
|
||||
Dictionary of user_id to device_id to message.
|
||||
local_messages_by_user_then_device:
|
||||
Dictionary of recipient user_id to recipient device_id to message.
|
||||
remote_messages_by_destination:
|
||||
Dictionary of destination server_name to the EDU JSON to send.
|
||||
|
||||
|
@ -415,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
now_ms = self._clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
)
|
||||
|
@ -466,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
now_ms = self._clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_from_remote_to_device_inbox",
|
||||
add_messages_txn,
|
||||
|
@ -562,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
|
||||
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
|
||||
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -577,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.REMOVE_DELETED_DEVICES,
|
||||
self._remove_deleted_devices_from_device_inbox,
|
||||
# Used to be a background update that deletes all device_inboxes for deleted
|
||||
# devices.
|
||||
self.db_pool.updates.register_noop_background_update(
|
||||
self.REMOVE_DELETED_DEVICES
|
||||
)
|
||||
# Used to be a background update that deletes all device_inboxes for hidden
|
||||
# devices.
|
||||
self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.REMOVE_HIDDEN_DEVICES,
|
||||
self._remove_hidden_devices_from_device_inbox,
|
||||
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
|
||||
self._remove_dead_devices_from_device_inbox,
|
||||
)
|
||||
|
||||
async def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||
|
@ -599,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
return 1
|
||||
|
||||
async def _remove_deleted_devices_from_device_inbox(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
async def _remove_dead_devices_from_device_inbox(
|
||||
self,
|
||||
progress: JsonDict,
|
||||
batch_size: int,
|
||||
) -> int:
|
||||
"""A background update that deletes all device_inboxes for deleted devices.
|
||||
|
||||
This should only need to be run once (when users upgrade to v1.47.0)
|
||||
"""A background update to remove devices that were either deleted or hidden from
|
||||
the device_inbox table.
|
||||
|
||||
Args:
|
||||
progress: JsonDict used to store progress of this background update
|
||||
batch_size: the maximum number of rows to retrieve in a single select query
|
||||
progress: The update's progress dict.
|
||||
batch_size: The batch size for this update.
|
||||
|
||||
Returns:
|
||||
The number of deleted rows
|
||||
The number of rows deleted.
|
||||
"""
|
||||
|
||||
def _remove_deleted_devices_from_device_inbox_txn(
|
||||
def _remove_dead_devices_from_device_inbox_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> int:
|
||||
"""stream_id is not unique
|
||||
we need to use an inclusive `stream_id >= ?` clause,
|
||||
since we might not have deleted all dead device messages for the stream_id
|
||||
returned from the previous query
|
||||
) -> Tuple[int, bool]:
|
||||
|
||||
Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
|
||||
to avoid problems of deleting a large number of rows all at once
|
||||
due to a single device having lots of device messages.
|
||||
"""
|
||||
if "max_stream_id" in progress:
|
||||
max_stream_id = progress["max_stream_id"]
|
||||
else:
|
||||
txn.execute("SELECT max(stream_id) FROM device_inbox")
|
||||
# There's a type mismatch here between how we want to type the row and
|
||||
# what fetchone says it returns, but we silence it because we know that
|
||||
# res can't be None.
|
||||
res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
|
||||
if res[0] is None:
|
||||
# this can only happen if the `device_inbox` table is empty, in which
|
||||
# case we have no work to do.
|
||||
return 0, True
|
||||
else:
|
||||
max_stream_id = res[0]
|
||||
|
||||
last_stream_id = progress.get("stream_id", 0)
|
||||
start = progress.get("stream_id", 0)
|
||||
stop = start + batch_size
|
||||
|
||||
# delete rows in `device_inbox` which do *not* correspond to a known,
|
||||
# unhidden device.
|
||||
sql = """
|
||||
SELECT device_id, user_id, stream_id
|
||||
FROM device_inbox
|
||||
DELETE FROM device_inbox
|
||||
WHERE
|
||||
stream_id >= ?
|
||||
AND (device_id, user_id) NOT IN (
|
||||
SELECT device_id, user_id FROM devices
|
||||
stream_id >= ? AND stream_id < ?
|
||||
AND NOT EXISTS (
|
||||
SELECT * FROM devices d
|
||||
WHERE
|
||||
d.device_id=device_inbox.device_id
|
||||
AND d.user_id=device_inbox.user_id
|
||||
AND NOT hidden
|
||||
)
|
||||
ORDER BY stream_id
|
||||
LIMIT ?
|
||||
"""
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_stream_id, batch_size))
|
||||
rows = txn.fetchall()
|
||||
txn.execute(sql, (start, stop))
|
||||
|
||||
num_deleted = 0
|
||||
for row in rows:
|
||||
num_deleted += self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
"device_inbox",
|
||||
{"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
|
||||
)
|
||||
|
||||
if rows:
|
||||
# send more than stream_id to progress
|
||||
# otherwise it can happen in large deployments that
|
||||
# no change of status is visible in the log file
|
||||
# it may be that the stream_id does not change in several runs
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
self.REMOVE_DELETED_DEVICES,
|
||||
{
|
||||
"device_id": rows[-1][0],
|
||||
"user_id": rows[-1][1],
|
||||
"stream_id": rows[-1][2],
|
||||
},
|
||||
)
|
||||
|
||||
return num_deleted
|
||||
|
||||
number_deleted = await self.db_pool.runInteraction(
|
||||
"_remove_deleted_devices_from_device_inbox",
|
||||
_remove_deleted_devices_from_device_inbox_txn,
|
||||
)
|
||||
|
||||
# The task is finished when no more lines are deleted.
|
||||
if not number_deleted:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.REMOVE_DELETED_DEVICES
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
|
||||
{
|
||||
"stream_id": stop,
|
||||
"max_stream_id": max_stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
return number_deleted
|
||||
return stop > max_stream_id
|
||||
|
||||
async def _remove_hidden_devices_from_device_inbox(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""A background update that deletes all device_inboxes for hidden devices.
|
||||
|
||||
This should only need to be run once (when users upgrade to v1.47.0)
|
||||
|
||||
Args:
|
||||
progress: JsonDict used to store progress of this background update
|
||||
batch_size: the maximum number of rows to retrieve in a single select query
|
||||
|
||||
Returns:
|
||||
The number of deleted rows
|
||||
"""
|
||||
|
||||
def _remove_hidden_devices_from_device_inbox_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> int:
|
||||
"""stream_id is not unique
|
||||
we need to use an inclusive `stream_id >= ?` clause,
|
||||
since we might not have deleted all hidden device messages for the stream_id
|
||||
returned from the previous query
|
||||
|
||||
Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
|
||||
to avoid problems of deleting a large number of rows all at once
|
||||
due to a single device having lots of device messages.
|
||||
"""
|
||||
|
||||
last_stream_id = progress.get("stream_id", 0)
|
||||
|
||||
sql = """
|
||||
SELECT device_id, user_id, stream_id
|
||||
FROM device_inbox
|
||||
WHERE
|
||||
stream_id >= ?
|
||||
AND (device_id, user_id) IN (
|
||||
SELECT device_id, user_id FROM devices WHERE hidden = ?
|
||||
)
|
||||
ORDER BY stream_id
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_stream_id, True, batch_size))
|
||||
rows = txn.fetchall()
|
||||
|
||||
num_deleted = 0
|
||||
for row in rows:
|
||||
num_deleted += self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
"device_inbox",
|
||||
{"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
|
||||
)
|
||||
|
||||
if rows:
|
||||
# We don't just save the `stream_id` in progress as
|
||||
# otherwise it can happen in large deployments that
|
||||
# no change of status is visible in the log file, as
|
||||
# it may be that the stream_id does not change in several runs
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
self.REMOVE_HIDDEN_DEVICES,
|
||||
{
|
||||
"device_id": rows[-1][0],
|
||||
"user_id": rows[-1][1],
|
||||
"stream_id": rows[-1][2],
|
||||
},
|
||||
)
|
||||
|
||||
return num_deleted
|
||||
|
||||
number_deleted = await self.db_pool.runInteraction(
|
||||
"_remove_hidden_devices_from_device_inbox",
|
||||
_remove_hidden_devices_from_device_inbox_txn,
|
||||
finished = await self.db_pool.runInteraction(
|
||||
"_remove_devices_from_device_inbox_txn",
|
||||
_remove_dead_devices_from_device_inbox_txn,
|
||||
)
|
||||
|
||||
# The task is finished when no more lines are deleted.
|
||||
if not number_deleted:
|
||||
if finished:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.REMOVE_HIDDEN_DEVICES
|
||||
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
|
||||
)
|
||||
|
||||
return number_deleted
|
||||
return batch_size
|
||||
|
||||
|
||||
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
||||
|
|
|
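The replacement background update above walks device_inbox in fixed [start, stop) stream_id windows and records `stop` as its progress, so it always advances even when a window deletes nothing, and it reports itself finished once the window passes the max stream_id captured at the start. The windowing logic in isolation (hypothetical helper, not from the commit):

    from typing import Dict, Tuple

    def next_window(
        progress: Dict[str, int], batch_size: int, max_stream_id: int
    ) -> Tuple[int, int, bool]:
        """Return the [start, stop) window for one batch and whether we are done."""
        start = progress.get("stream_id", 0)
        stop = start + batch_size
        return start, stop, stop > max_stream_id

    # With batch_size=100 and max_stream_id=250 the update runs three windows:
    print(next_window({}, 100, 250))                  # (0, 100, False)
    print(next_window({"stream_id": 100}, 100, 250))  # (100, 200, False)
    print(next_window({"stream_id": 200}, 100, 250))  # (200, 300, True)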
@ -13,17 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
from synapse.types import RoomAlias
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
|
||||
|
||||
|
||||
class DirectoryWorkerStore(SQLBaseStore):
|
||||
class DirectoryWorkerStore(CacheInvalidationWorkerStore):
|
||||
async def get_association_from_room_alias(
|
||||
self, room_alias: RoomAlias
|
||||
) -> Optional[RoomAliasMapping]:
|
||||
|
@ -91,7 +92,7 @@ class DirectoryWorkerStore(SQLBaseStore):
|
|||
creator: Optional user_id of creator.
|
||||
"""
|
||||
|
||||
def alias_txn(txn):
|
||||
def alias_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"room_aliases",
|
||||
|
@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class DirectoryStore(DirectoryWorkerStore):
|
||||
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
|
||||
async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]:
|
||||
room_id = await self.db_pool.runInteraction(
|
||||
"delete_room_alias", self._delete_room_alias_txn, room_alias
|
||||
)
|
||||
|
||||
return room_id
|
||||
|
||||
def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
|
||||
def _delete_room_alias_txn(
|
||||
self, txn: LoggingTransaction, room_alias: RoomAlias
|
||||
) -> Optional[str]:
|
||||
txn.execute(
|
||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||
(room_alias.to_string(),),
|
||||
|
@@ -173,9 +176,9 @@ class DirectoryStore(DirectoryWorkerStore):
             If None, the creator will be left unchanged.
         """
 
-        def _update_aliases_for_room_txn(txn):
+        def _update_aliases_for_room_txn(txn: LoggingTransaction) -> None:
             update_creator_sql = ""
-            sql_params = (new_room_id, old_room_id)
+            sql_params: Tuple[str, ...] = (new_room_id, old_room_id)
             if creator:
                 update_creator_sql = ", creator = ?"
                 sql_params = (new_room_id, creator, old_room_id)
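The new `Tuple[str, ...]` annotation is what lets the later reassignment type-check: without it, mypy infers the fixed-length type Tuple[str, str] from the first assignment and rejects the three-element tuple used when a creator is supplied. In isolation (values are placeholders):

    from typing import Tuple

    new_room_id, old_room_id, creator = "!new:example.org", "!old:example.org", "@alice:example.org"

    # Annotated as a variable-length homogeneous tuple, so both shapes are allowed;
    # without the annotation mypy infers Tuple[str, str] and flags the reassignment.
    sql_params: Tuple[str, ...] = (new_room_id, old_room_id)
    if creator:
        sql_params = (new_room_id, creator, old_room_id)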
@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
fallback_keys: the keys to set. This is a map from key ID (which is
|
||||
of the form "algorithm:id") to key data.
|
||||
"""
|
||||
await self.db_pool.runInteraction(
|
||||
"set_e2e_fallback_keys_txn",
|
||||
self._set_e2e_fallback_keys_txn,
|
||||
user_id,
|
||||
device_id,
|
||||
fallback_keys,
|
||||
)
|
||||
|
||||
await self.invalidate_cache_and_stream(
|
||||
"get_e2e_unused_fallback_key_types", (user_id, device_id)
|
||||
)
|
||||
|
||||
def _set_e2e_fallback_keys_txn(
|
||||
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
|
||||
) -> None:
|
||||
# fallback_keys will usually only have one item in it, so using a for
|
||||
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||
# FIXME: make sure that only one key per algorithm is uploaded
|
||||
for key_id, fallback_key in fallback_keys.items():
|
||||
algorithm, key_id = key_id.split(":", 1)
|
||||
await self.db_pool.simple_upsert(
|
||||
"e2e_fallback_keys_json",
|
||||
old_key_json = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
},
|
||||
values={
|
||||
"key_id": key_id,
|
||||
"key_json": json_encoder.encode(fallback_key),
|
||||
"used": False,
|
||||
},
|
||||
desc="set_e2e_fallback_key",
|
||||
retcol="key_json",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
await self.invalidate_cache_and_stream(
|
||||
"get_e2e_unused_fallback_key_types", (user_id, device_id)
|
||||
)
|
||||
new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
|
||||
|
||||
# If the uploaded key is the same as the current fallback key,
|
||||
# don't do anything. This prevents marking the key as unused if it
|
||||
# was already used.
|
||||
if old_key_json != new_key_json:
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
},
|
||||
values={
|
||||
"key_id": key_id,
|
||||
"key_json": json_encoder.encode(fallback_key),
|
||||
"used": False,
|
||||
},
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def get_e2e_unused_fallback_key_types(
|
||||
|
|
|
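The rewritten fallback-key transaction first reads the currently stored key and only performs the upsert (which resets the `used` flag) when the canonical-JSON form of the upload differs, so a client re-uploading the same fallback key no longer un-marks it as used. The comparison is plain canonical-JSON equality, roughly:

    from canonicaljson import encode_canonical_json

    old_key_json = '{"key":"abc123"}'   # value already stored in e2e_fallback_keys_json
    fallback_key = {"key": "abc123"}    # key re-uploaded by the client

    new_key_json = encode_canonical_json(fallback_key).decode("utf-8")

    if old_key_json != new_key_json:
        print("store the new key and mark it unused")
    else:
        print("identical upload: keep the existing row and its used flag")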
@ -1,6 +1,6 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018-2019 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1641,8 +1641,8 @@ class PersistEventsStore:
|
|||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database."""
|
||||
|
||||
def str_or_none(val: Any) -> Optional[str]:
|
||||
return val if isinstance(val, str) else None
|
||||
def non_null_str_or_none(val: Any) -> Optional[str]:
|
||||
return val if isinstance(val, str) and "\u0000" not in val else None
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
|
@ -1654,8 +1654,10 @@ class PersistEventsStore:
|
|||
"sender": event.user_id,
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
"display_name": str_or_none(event.content.get("displayname")),
|
||||
"avatar_url": str_or_none(event.content.get("avatar_url")),
|
||||
"display_name": non_null_str_or_none(
|
||||
event.content.get("displayname")
|
||||
),
|
||||
"avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
|
||||
}
|
||||
for event in events
|
||||
],
|
||||
|
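Replacing str_or_none with non_null_str_or_none drops display names and avatar URLs that contain a NUL character (\u0000), which PostgreSQL cannot store in a text column. The helper is small enough to test on its own:

    from typing import Any, Optional

    def non_null_str_or_none(val: Any) -> Optional[str]:
        # Reject non-strings and any string containing a NUL byte.
        return val if isinstance(val, str) and "\u0000" not in val else None

    assert non_null_str_or_none("Alice") == "Alice"
    assert non_null_str_or_none("Al\u0000ice") is None
    assert non_null_str_or_none(42) is None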
@ -1694,34 +1696,33 @@ class PersistEventsStore:
|
|||
},
|
||||
)
|
||||
|
||||
def _handle_event_relations(self, txn, event):
|
||||
"""Handles inserting relation data during peristence of events
|
||||
def _handle_event_relations(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
"""Handles inserting relation data during persistence of events
|
||||
|
||||
Args:
|
||||
txn
|
||||
event (EventBase)
|
||||
txn: The current database transaction.
|
||||
event: The event which might have relations.
|
||||
"""
|
||||
relation = event.content.get("m.relates_to")
|
||||
if not relation:
|
||||
# No relations
|
||||
return
|
||||
|
||||
# Relations must have a type and parent event ID.
|
||||
rel_type = relation.get("rel_type")
|
||||
if rel_type not in (
|
||||
RelationTypes.ANNOTATION,
|
||||
RelationTypes.REFERENCE,
|
||||
RelationTypes.REPLACE,
|
||||
RelationTypes.THREAD,
|
||||
):
|
||||
# Unknown relation type
|
||||
if not isinstance(rel_type, str):
|
||||
return
|
||||
|
||||
parent_id = relation.get("event_id")
|
||||
if not parent_id:
|
||||
# Invalid relation
|
||||
if not isinstance(parent_id, str):
|
||||
return
|
||||
|
||||
aggregation_key = relation.get("key")
|
||||
# Annotations have a key field.
|
||||
aggregation_key = None
|
||||
if rel_type == RelationTypes.ANNOTATION:
|
||||
aggregation_key = relation.get("key")
|
||||
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
|
|
|
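_handle_event_relations now accepts any string rel_type instead of a fixed set, but it still requires both rel_type and the parent event_id to be strings, and it only reads the aggregation key for annotations. A condensed sketch of that validation path (standalone function, with the relation-type constant inlined):

    from typing import Any, Dict, Optional, Tuple

    ANNOTATION = "m.annotation"  # stands in for RelationTypes.ANNOTATION

    def parse_relation(content: Dict[str, Any]) -> Optional[Tuple[str, str, Optional[str]]]:
        relation = content.get("m.relates_to")
        if not isinstance(relation, dict):
            return None
        rel_type = relation.get("rel_type")
        parent_id = relation.get("event_id")
        if not isinstance(rel_type, str) or not isinstance(parent_id, str):
            return None
        aggregation_key = relation.get("key") if rel_type == ANNOTATION else None
        return rel_type, parent_id, aggregation_key

    print(parse_relation({"m.relates_to": {"rel_type": "m.thread", "event_id": "$parent"}}))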
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
self._purged_chain_cover_index,
|
||||
)
|
||||
|
||||
# The event_thread_relation background update was replaced with the
|
||||
# event_arbitrary_relations one, which handles any relation to avoid
|
||||
# needed to potentially crawl the entire events table in the future.
|
||||
self.db_pool.updates.register_noop_background_update("event_thread_relation")
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"event_thread_relation", self._event_thread_relation
|
||||
"event_arbitrary_relations",
|
||||
self._event_arbitrary_relations,
|
||||
)
|
||||
|
||||
################################################################################
|
||||
|
@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
return result
|
||||
|
||||
async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
|
||||
"""Background update handler which will store thread relations for existing events."""
|
||||
async def _event_arbitrary_relations(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Background update handler which will store previously unknown relations for existing events."""
|
||||
last_event_id = progress.get("last_event_id", "")
|
||||
|
||||
def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
|
||||
def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
|
||||
# Fetch events and then filter based on whether the event has a
|
||||
# relation or not.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT event_id, json FROM event_json
|
||||
LEFT JOIN event_relations USING (event_id)
|
||||
WHERE event_id > ? AND event_relations.event_id IS NULL
|
||||
WHERE event_id > ?
|
||||
ORDER BY event_id LIMIT ?
|
||||
""",
|
||||
(last_event_id, batch_size),
|
||||
)
|
||||
|
||||
results = list(txn)
|
||||
missing_thread_relations = []
|
||||
# (event_id, parent_id, rel_type) for each relation
|
||||
relations_to_insert: List[Tuple[str, str, str]] = []
|
||||
for (event_id, event_json_raw) in results:
|
||||
try:
|
||||
event_json = db_to_json(event_json_raw)
|
||||
|
@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
)
|
||||
continue
|
||||
|
||||
# If there's no relation (or it is not a thread), skip!
|
||||
# If there's no relation, skip!
|
||||
relates_to = event_json["content"].get("m.relates_to")
|
||||
if not relates_to or not isinstance(relates_to, dict):
|
||||
continue
|
||||
if relates_to.get("rel_type") != RelationTypes.THREAD:
|
||||
|
||||
# If the relation type or parent event ID is not a string, skip it.
|
||||
#
|
||||
# Do not consider relation types that have existed for a long time,
|
||||
# since they will already be listed in the `event_relations` table.
|
||||
rel_type = relates_to.get("rel_type")
|
||||
if not isinstance(rel_type, str) or rel_type in (
|
||||
RelationTypes.ANNOTATION,
|
||||
RelationTypes.REFERENCE,
|
||||
RelationTypes.REPLACE,
|
||||
):
|
||||
continue
|
||||
|
||||
# Get the parent ID.
|
||||
parent_id = relates_to.get("event_id")
|
||||
if not isinstance(parent_id, str):
|
||||
continue
|
||||
|
||||
missing_thread_relations.append((event_id, parent_id))
|
||||
relations_to_insert.append((event_id, parent_id, rel_type))
|
||||
|
||||
# Insert the missing data.
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn=txn,
|
||||
table="event_relations",
|
||||
values=[
|
||||
{
|
||||
"event_id": event_id,
|
||||
"relates_to_Id": parent_id,
|
||||
"relation_type": RelationTypes.THREAD,
|
||||
}
|
||||
for event_id, parent_id in missing_thread_relations
|
||||
],
|
||||
)
|
||||
# Insert the missing data, note that we upsert here in case the event
|
||||
# has already been processed.
|
||||
if relations_to_insert:
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn=txn,
|
||||
table="event_relations",
|
||||
key_names=("event_id",),
|
||||
key_values=[(r[0],) for r in relations_to_insert],
|
||||
value_names=("relates_to_id", "relation_type"),
|
||||
value_values=[r[1:] for r in relations_to_insert],
|
||||
)
|
||||
|
||||
# Iterate the parent IDs and invalidate caches.
|
||||
for parent_id in {r[1] for r in relations_to_insert}:
|
||||
cache_tuple = (parent_id,)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_relations_for_event, cache_tuple
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aggregation_groups_for_event, cache_tuple
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_thread_summary, cache_tuple
|
||||
)
|
||||
|
||||
if results:
|
||||
latest_event_id = results[-1][0]
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "event_thread_relation", {"last_event_id": latest_event_id}
|
||||
txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
|
||||
)
|
||||
|
||||
return len(results)
|
||||
|
||||
num_rows = await self.db_pool.runInteraction(
|
||||
desc="event_thread_relation", func=_event_thread_relation_txn
|
||||
desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
|
||||
)
|
||||
|
||||
if not num_rows:
|
||||
await self.db_pool.updates._end_background_update("event_thread_relation")
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"event_arbitrary_relations"
|
||||
)
|
||||
|
||||
return num_rows
|
||||
|
||||
|
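The renamed update follows the usual resumable background-update shape: process a batch keyed on the last seen event_id, persist that ID as progress, and end the update once a batch comes back empty. Stripped of the database plumbing, the loop looks roughly like this (the fetch function is a stand-in for the runInteraction call in the diff):

    import asyncio
    from typing import Awaitable, Callable, List

    async def run_resumable_update(
        fetch_batch: Callable[[str, int], Awaitable[List[str]]], batch_size: int = 100
    ) -> int:
        last_event_id = ""
        batches = 0
        while True:
            batch = await fetch_batch(last_event_id, batch_size)
            if not batch:
                return batches         # analogous to _end_background_update(...)
            last_event_id = batch[-1]  # analogous to _background_update_progress_txn(...)
            batches += 1

    async def fake_fetch(after: str, limit: int) -> List[str]:
        events = ["$a", "$b", "$c", "$d", "$e"]
        return [e for e in events if e > after][:limit]

    print(asyncio.run(run_resumable_update(fake_fetch, batch_size=2)))  # prints 3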
|
|
@ -13,15 +13,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventForwardExtremitiesStore(SQLBaseStore):
|
||||
class EventForwardExtremitiesStore(
|
||||
EventFederationWorkerStore,
|
||||
CacheInvalidationWorkerStore,
|
||||
):
|
||||
async def delete_forward_extremities_for_room(self, room_id: str) -> int:
|
||||
"""Delete any extra forward extremities for a room.
|
||||
|
||||
|
@ -31,7 +36,7 @@ class EventForwardExtremitiesStore(SQLBaseStore):
|
|||
Returns count deleted.
|
||||
"""
|
||||
|
||||
def delete_forward_extremities_for_room_txn(txn):
|
||||
def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:
|
||||
# First we need to get the event_id to not delete
|
||||
sql = """
|
||||
SELECT event_id FROM event_forward_extremities
|
||||
|
@ -82,10 +87,14 @@ class EventForwardExtremitiesStore(SQLBaseStore):
|
|||
delete_forward_extremities_for_room_txn,
|
||||
)
|
||||
|
||||
async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
|
||||
async def get_forward_extremities_for_room(
|
||||
self, room_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get list of forward extremities for a room."""
|
||||
|
||||
def get_forward_extremities_for_room_txn(txn):
|
||||
def get_forward_extremities_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, Any]]:
|
||||
sql = """
|
||||
SELECT event_id, state_group, depth, received_ts
|
||||
FROM event_forward_extremities
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json
|
|||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
|
@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore):
|
|||
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
# INSERT a new one
|
||||
def _do_txn(txn):
|
||||
def _do_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT filter_id FROM user_filters "
|
||||
"WHERE user_id = ? AND filter_json = ?"
|
||||
|
@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore):
|
|||
|
||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
|
||||
txn.execute(sql, (user_localpart,))
|
||||
max_id = txn.fetchone()[0]
|
||||
max_id = txn.fetchone()[0] # type: ignore[index]
|
||||
if max_id is None:
|
||||
filter_id = 0
|
||||
else:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Type
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
|
@ -62,7 +62,9 @@ class LockStore(SQLBaseStore):
|
|||
|
||||
# A map from `(lock_name, lock_key)` to the token of any locks that we
|
||||
# think we currently hold.
|
||||
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
|
||||
self._live_tokens: WeakValueDictionary[
|
||||
Tuple[str, str], Lock
|
||||
] = WeakValueDictionary()
|
||||
|
||||
# When we shut down we want to remove the locks. Technically this can
|
||||
# lead to a race, as we may drop the lock while we are still processing.
|
||||
|
|
|
@ -13,10 +13,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -46,7 +61,12 @@ class MediaSortOrder(Enum):
|
|||
|
||||
|
||||
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
|
@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
|||
self._drop_media_index_without_method,
|
||||
)
|
||||
|
||||
async def _drop_media_index_without_method(self, progress, batch_size):
|
||||
async def _drop_media_index_without_method(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""background update handler which removes the old constraints.
|
||||
|
||||
Note that this is only run on postgres.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
|
||||
)
|
||||
|
@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
|||
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
"""Persistence for attachments and avatars"""
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
self.server_name = hs.hostname
|
||||
|
||||
|
@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
plus the total count of all the user's media
|
||||
"""
|
||||
|
||||
def get_local_media_by_user_paginate_txn(txn):
|
||||
def get_local_media_by_user_paginate_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
|
||||
# Set ordering
|
||||
order_by_column = MediaSortOrder(order_by).value
|
||||
|
@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
else:
|
||||
order = "ASC"
|
||||
|
||||
args = [user_id]
|
||||
args: List[Union[str, int]] = [user_id]
|
||||
sql = """
|
||||
SELECT COUNT(*) as total_media
|
||||
FROM local_media_repository
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
count = txn.fetchone()[0] # type: ignore[index]
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
|
@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
)
|
||||
sql += sql_keep
|
||||
|
||||
def _get_local_media_before_txn(txn):
|
||||
def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
|
||||
txn.execute(sql, (before_ts, before_ts, size_gt))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
|
@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def store_local_media(
|
||||
self,
|
||||
media_id,
|
||||
media_type,
|
||||
time_now_ms,
|
||||
upload_name,
|
||||
media_length,
|
||||
user_id,
|
||||
url_cache=None,
|
||||
media_id: str,
|
||||
media_type: str,
|
||||
time_now_ms: int,
|
||||
upload_name: Optional[str],
|
||||
media_length: int,
|
||||
user_id: UserID,
|
||||
url_cache: Optional[str] = None,
|
||||
) -> None:
|
||||
await self.db_pool.simple_insert(
|
||||
"local_media_repository",
|
||||
|
@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
None if the URL isn't cached.
|
||||
"""
|
||||
|
||||
def get_url_cache_txn(txn):
|
||||
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||
# get the most recently cached result (relative to the given ts)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
|
@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def store_url_cache(
|
||||
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
||||
):
|
||||
) -> None:
|
||||
await self.db_pool.simple_insert(
|
||||
"local_media_repository_url_cache",
|
||||
{
|
||||
|
@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def store_local_thumbnail(
|
||||
self,
|
||||
media_id,
|
||||
thumbnail_width,
|
||||
thumbnail_height,
|
||||
thumbnail_type,
|
||||
thumbnail_method,
|
||||
thumbnail_length,
|
||||
):
|
||||
media_id: str,
|
||||
thumbnail_width: int,
|
||||
thumbnail_height: int,
|
||||
thumbnail_type: str,
|
||||
thumbnail_method: str,
|
||||
thumbnail_length: int,
|
||||
) -> None:
|
||||
await self.db_pool.simple_upsert(
|
||||
table="local_media_repository_thumbnails",
|
||||
keyvalues={
|
||||
|
@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def store_cached_remote_media(
|
||||
self,
|
||||
origin,
|
||||
media_id,
|
||||
media_type,
|
||||
media_length,
|
||||
time_now_ms,
|
||||
upload_name,
|
||||
filesystem_id,
|
||||
):
|
||||
origin: str,
|
||||
media_id: str,
|
||||
media_type: str,
|
||||
media_length: int,
|
||||
time_now_ms: int,
|
||||
upload_name: Optional[str],
|
||||
filesystem_id: str,
|
||||
) -> None:
|
||||
await self.db_pool.simple_insert(
|
||||
"remote_media_cache",
|
||||
{
|
||||
|
@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
local_media: Iterable[str],
|
||||
remote_media: Iterable[Tuple[str, str]],
|
||||
time_ms: int,
|
||||
):
|
||||
) -> None:
|
||||
"""Updates the last access time of the given media
|
||||
|
||||
Args:
|
||||
|
@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
time_ms: Current time in milliseconds
|
||||
"""
|
||||
|
||||
def update_cache_txn(txn):
|
||||
def update_cache_txn(txn: LoggingTransaction) -> None:
|
||||
sql = (
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||
" WHERE media_origin = ? AND media_id = ?"
|
||||
|
@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_cached_last_access_time", update_cache_txn
|
||||
)
|
||||
|
||||
|
@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
async def store_remote_media_thumbnail(
|
||||
self,
|
||||
origin,
|
||||
media_id,
|
||||
filesystem_id,
|
||||
thumbnail_width,
|
||||
thumbnail_height,
|
||||
thumbnail_type,
|
||||
thumbnail_method,
|
||||
thumbnail_length,
|
||||
):
|
||||
origin: str,
|
||||
media_id: str,
|
||||
filesystem_id: str,
|
||||
thumbnail_width: int,
|
||||
thumbnail_height: int,
|
||||
thumbnail_type: str,
|
||||
thumbnail_method: str,
|
||||
thumbnail_length: int,
|
||||
) -> None:
|
||||
await self.db_pool.simple_upsert(
|
||||
table="remote_media_cache_thumbnails",
|
||||
keyvalues={
|
||||
|
@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="store_remote_media_thumbnail",
|
||||
)
|
||||
|
||||
async def get_remote_media_before(self, before_ts):
|
||||
async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
|
||||
sql = (
|
||||
"SELECT media_origin, media_id, filesystem_id"
|
||||
" FROM remote_media_cache"
|
||||
|
@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
" LIMIT 500"
|
||||
)
|
||||
|
||||
def _get_expired_url_cache_txn(txn):
|
||||
def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
|
||||
txn.execute(sql, (now_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
|
@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"get_expired_url_cache", _get_expired_url_cache_txn
|
||||
)
|
||||
|
||||
async def delete_url_cache(self, media_ids):
|
||||
async def delete_url_cache(self, media_ids: Collection[str]) -> None:
|
||||
if len(media_ids) == 0:
|
||||
return
|
||||
|
||||
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
|
||||
|
||||
def _delete_url_cache_txn(txn):
|
||||
def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
|
||||
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"delete_url_cache", _delete_url_cache_txn
|
||||
)
|
||||
await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
|
||||
|
||||
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
|
||||
sql = (
|
||||
|
@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
" LIMIT 500"
|
||||
)
|
||||
|
||||
def _get_url_cache_media_before_txn(txn):
|
||||
def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
|
||||
txn.execute(sql, (before_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
|
@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"get_url_cache_media_before", _get_url_cache_media_before_txn
|
||||
)
|
||||
|
||||
async def delete_url_cache_media(self, media_ids):
|
||||
async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
|
||||
if len(media_ids) == 0:
|
||||
return
|
||||
|
||||
def _delete_url_cache_media_txn(txn):
|
||||
def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
|
||||
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
|
||||
|
||||
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
|
||||
|
@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_url_cache_media", _delete_url_cache_media_txn
|
||||
)
|
||||
|
|
|
@ -1,6 +1,21 @@
|
|||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
|
||||
|
||||
class OpenIdStore(SQLBaseStore):
|
||||
|
@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore):
|
|||
async def get_user_id_for_open_id_token(
|
||||
self, token: str, ts_now_ms: int
|
||||
) -> Optional[str]:
|
||||
def get_user_id_for_token_txn(txn):
|
||||
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||
sql = (
|
||||
"SELECT user_id FROM open_id_tokens"
|
||||
" WHERE token = ? AND ? <= ts_valid_until_ms"
|
||||
|
|
|
@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
|
||||
|
||||
|
@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
desc="update_remote_profile_cache",
|
||||
)
|
||||
|
||||
async def maybe_delete_remote_profile_cache(self, user_id):
|
||||
async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
|
||||
"""Check if we still care about the remote user's profile, and if we
|
||||
don't then remove their profile from the cache
|
||||
"""
|
||||
|
@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
desc="delete_remote_profile_cache",
|
||||
)
|
||||
|
||||
async def is_subscribed_remote_profile_for_user(self, user_id):
|
||||
async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
|
||||
"""Check whether we are interested in a remote user's profile."""
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
res: Optional[str] = await self.db_pool.simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
|
@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
|
||||
if res:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_remote_profile_cache_entries_that_expire(
|
||||
self, last_checked: int
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Get all users who haven't been checked since `last_checked`"""
|
||||
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, str]]:
|
||||
sql = """
|
||||
SELECT user_id, displayname, avatar_url
|
||||
FROM remote_profile_cache
|
||||
|
|
|
@ -84,26 +84,26 @@ class TokenLookupResult:
|
|||
return self.user_id
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class RefreshTokenLookupResult:
|
||||
"""Result of looking up a refresh token."""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
user_id: str
|
||||
"""The user this token belongs to."""
|
||||
|
||||
device_id = attr.ib(type=str)
|
||||
device_id: str
|
||||
"""The device associated with this refresh token."""
|
||||
|
||||
token_id = attr.ib(type=int)
|
||||
token_id: int
|
||||
"""The ID of this refresh token."""
|
||||
|
||||
next_token_id = attr.ib(type=Optional[int])
|
||||
next_token_id: Optional[int]
|
||||
"""The ID of the refresh token which replaced this one."""
|
||||
|
||||
has_next_refresh_token_been_refreshed = attr.ib(type=bool)
|
||||
has_next_refresh_token_been_refreshed: bool
|
||||
"""True if the next refresh token was used for another refresh."""
|
||||
|
||||
has_next_access_token_been_used = attr.ib(type=bool)
|
||||
has_next_access_token_been_used: bool
|
||||
"""True if the next access token was already used at least once."""
|
||||
|
||||
|
||||
|
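The RefreshTokenLookupResult hunk above replaces `attr.ib(type=...)` fields with plain annotations under `auto_attribs=True`. Both forms define the same attrs class, but the annotated form is also visible to mypy as ordinary attribute types. A small side-by-side sketch with invented class names:

from typing import Optional
import attr

@attr.s(frozen=True, slots=True)
class OldStyle:
    user_id = attr.ib(type=str)
    next_token_id = attr.ib(type=Optional[int])

@attr.s(auto_attribs=True, frozen=True, slots=True)
class NewStyle:
    user_id: str
    next_token_id: Optional[int]

# Both generate the same __init__/__eq__/__repr__ machinery.
assert OldStyle("@a:example.org", None) == OldStyle("@a:example.org", None)
assert NewStyle("@a:example.org", None).next_token_id is None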
@ -476,7 +476,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
shadow_banned: true iff the user is to be shadow-banned, false otherwise.
|
||||
"""
|
||||
|
||||
def set_shadow_banned_txn(txn):
|
||||
def set_shadow_banned_txn(txn: LoggingTransaction) -> None:
|
||||
user_id = user.to_string()
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -1198,8 +1198,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
expiration_ts = now_ms + self._account_validity_period
|
||||
|
||||
if use_delta:
|
||||
assert self._account_validity_startup_job_max_delta is not None
|
||||
expiration_ts = random.randrange(
|
||||
expiration_ts - self._account_validity_startup_job_max_delta,
|
||||
int(expiration_ts - self._account_validity_startup_job_max_delta),
|
||||
expiration_ts,
|
||||
)
|
||||
|
||||
|
@ -1728,11 +1729,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"user_threepids_grandfather", self._bg_user_threepids_grandfather
|
||||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||
self.db_pool.updates.register_noop_background_update(
|
||||
"user_threepids_grandfather"
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
|
@ -1805,35 +1806,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
|
||||
return nb_processed
|
||||
|
||||
async def _bg_user_threepids_grandfather(self, progress, batch_size):
|
||||
"""We now track which identity servers a user binds their 3PID to, so
|
||||
we need to handle the case of existing bindings where we didn't track
|
||||
this.
|
||||
|
||||
We do this by grandfathering in existing user threepids assuming that
|
||||
they used one of the server configured trusted identity servers.
|
||||
"""
|
||||
id_servers = set(self.config.registration.trusted_third_party_id_servers)
|
||||
|
||||
def _bg_user_threepids_grandfather_txn(txn):
|
||||
sql = """
|
||||
INSERT INTO user_threepid_id_server
|
||||
(user_id, medium, address, id_server)
|
||||
SELECT user_id, medium, address, ?
|
||||
FROM user_threepids
|
||||
"""
|
||||
|
||||
txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
|
||||
|
||||
if id_servers:
|
||||
await self.db_pool.runInteraction(
|
||||
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
|
||||
)
|
||||
|
||||
await self.db_pool.updates._end_background_update("user_threepids_grandfather")
|
||||
|
||||
return 1
|
||||
|
||||
async def set_user_deactivated_status(
|
||||
self, user_id: str, deactivated: bool
|
||||
) -> None:
|
||||
|
|
|
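The registration hunks above swap the `user_threepids_grandfather` handler for `register_noop_background_update`, which simply reports the update as complete so stale rows in the background_updates table do not keep a removed migration alive. A toy registry, not Synapse's actual implementation, showing the shape of the two registrations (the handler signature mirrors the removed `_bg_user_threepids_grandfather(progress, batch_size)` coroutine):

from typing import Awaitable, Callable, Dict

Handler = Callable[[dict, int], Awaitable[int]]
handlers: Dict[str, Handler] = {}

def register_background_update_handler(name: str, handler: Handler) -> None:
    handlers[name] = handler

def register_noop_background_update(name: str) -> None:
    async def noop(progress: dict, batch_size: int) -> int:
        # Nothing to do: report the update finished in a single pass.
        return 1
    handlers[name] = noop

register_noop_background_update("user_threepids_grandfather")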
@ -20,7 +20,7 @@ import attr
|
|||
from synapse.api.constants import RelationTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
|
||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||
from synapse.storage.relations import (
|
||||
AggregationPaginationToken,
|
||||
|
@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
"get_recent_references_for_event", _get_recent_references_for_event_txn
|
||||
)
|
||||
|
||||
async def event_includes_relation(self, event_id: str) -> bool:
|
||||
"""Check if the given event relates to another event.
|
||||
|
||||
An event has a relation if it has a valid m.relates_to with a rel_type
|
||||
and event_id in the content:
|
||||
|
||||
{
|
||||
"content": {
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "$other_event_id"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
event_id: The event to check.
|
||||
|
||||
Returns:
|
||||
True if the event includes a valid relation.
|
||||
"""
|
||||
|
||||
result = await self.db_pool.simple_select_one_onecol(
|
||||
table="event_relations",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
desc="event_includes_relation",
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def event_is_target_of_relation(self, parent_id: str) -> bool:
|
||||
"""Check if the given event is the target of another event's relation.
|
||||
|
||||
An event is the target of an event relation if it has a valid
|
||||
m.relates_to with a rel_type and event_id pointing to parent_id in the
|
||||
content:
|
||||
|
||||
{
|
||||
"content": {
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "$parent_id"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
parent_id: The event to check.
|
||||
|
||||
Returns:
|
||||
True if the event is the target of another event's relation.
|
||||
"""
|
||||
|
||||
result = await self.db_pool.simple_select_one_onecol(
|
||||
table="event_relations",
|
||||
keyvalues={"relates_to_id": parent_id},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
desc="event_is_target_of_relation",
|
||||
)
|
||||
return result is not None
|
||||
|
||||
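Both helpers above key off rows in `event_relations`, which only exist when an event's content carries a well-formed `m.relates_to` block of the shape shown in the docstrings. A small content-shape check, purely illustrative and independent of how Synapse actually parses relations:

from typing import Any, Dict

def has_valid_relation(content: Dict[str, Any]) -> bool:
    # Requires both rel_type and event_id, matching the docstring example.
    relates_to = content.get("m.relates_to")
    if not isinstance(relates_to, dict):
        return False
    return bool(relates_to.get("rel_type")) and bool(relates_to.get("event_id"))

assert has_valid_relation(
    {"m.relates_to": {"rel_type": "m.replace", "event_id": "$other_event_id"}}
)
assert not has_valid_relation({"body": "no relation here"})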
@cached(tree=True)
|
||||
async def get_aggregation_groups_for_event(
|
||||
self,
|
||||
|
@ -334,6 +397,62 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return count, latest_event
|
||||
|
||||
async def events_have_relations(
|
||||
self,
|
||||
parent_ids: List[str],
|
||||
relation_senders: Optional[List[str]],
|
||||
relation_types: Optional[List[str]],
|
||||
) -> List[str]:
|
||||
"""Check which events have a relationship from the given senders of the
|
||||
given types.
|
||||
|
||||
Args:
|
||||
parent_ids: The events being annotated
|
||||
relation_senders: The relation senders to check.
|
||||
relation_types: The relation types to check.
|
||||
|
||||
Returns:
|
||||
The list of parent event IDs that have at least one relation sent by one of the given senders and of one of the given types.
|
||||
"""
|
||||
# If no restrictions are given then the event has the required relations.
|
||||
if not relation_senders and not relation_types:
|
||||
return parent_ids
|
||||
|
||||
sql = """
|
||||
SELECT relates_to_id FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE
|
||||
%s;
|
||||
"""
|
||||
|
||||
def _get_if_events_have_relations(txn) -> List[str]:
|
||||
clauses: List[str] = []
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relates_to_id", parent_ids
|
||||
)
|
||||
clauses.append(clause)
|
||||
|
||||
if relation_senders:
|
||||
clause, temp_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "sender", relation_senders
|
||||
)
|
||||
clauses.append(clause)
|
||||
args.extend(temp_args)
|
||||
if relation_types:
|
||||
clause, temp_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relation_type", relation_types
|
||||
)
|
||||
clauses.append(clause)
|
||||
args.extend(temp_args)
|
||||
|
||||
txn.execute(sql % " AND ".join(clauses), args)
|
||||
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_if_events_have_relations", _get_if_events_have_relations
|
||||
)
|
||||
|
||||
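`events_have_relations` builds its WHERE clause from one mandatory IN-list on `relates_to_id` plus optional IN-lists on `sender` and `relation_type`. A simplified stand-in for `make_in_list_sql_clause` that only does the portable `col IN (?, ?, ...)` expansion (the real helper can also take an engine-specific path), followed by the same accumulation pattern as the hunk:

from typing import Iterable, List, Tuple

def make_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
    vals = list(values)
    placeholders = ", ".join("?" for _ in vals)
    return f"{column} IN ({placeholders})", vals

clauses: List[str] = []
args: List[str] = []

clause, part = make_in_list_clause("relates_to_id", ["$a", "$b"])
clauses.append(clause)
args.extend(part)

clause, part = make_in_list_clause("sender", ["@alice:example.org"])
clauses.append(clause)
args.extend(part)

sql = (
    "SELECT relates_to_id FROM event_relations "
    "INNER JOIN events USING (event_id) WHERE %s" % " AND ".join(clauses)
)
# -> ... WHERE relates_to_id IN (?, ?) AND sender IN (?)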
async def has_user_annotated_event(
|
||||
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
||||
) -> bool:
|
||||
|
|
|
@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
desc="is_room_blocked",
|
||||
)
|
||||
|
||||
async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
|
||||
"""
|
||||
Returns the user_id of the user who has blocked the room.
|
||||
(user_id is non-nullable in the blocked_rooms table.)
|
||||
Returns None if the room is not blocked.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="room_is_blocked_by",
|
||||
)
|
||||
|
||||
async def get_rooms_paginate(
|
||||
self,
|
||||
start: int,
|
||||
|
@ -1751,7 +1765,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||
"""Marks the room as blocked. Can be called multiple times.
|
||||
"""Marks the room as blocked.
|
||||
|
||||
Can be called multiple times (though we'll only track the last user to
|
||||
block this room).
|
||||
|
||||
Can be called on a room unknown to this homeserver.
|
||||
|
||||
Args:
|
||||
room_id: Room to block
|
||||
|
@ -1770,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
async def unblock_room(self, room_id: str) -> None:
|
||||
"""Remove the room from blocking list.
|
||||
|
||||
Args:
|
||||
room_id: Room to unblock
|
||||
"""
|
||||
await self.db_pool.simple_delete(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
desc="unblock_room",
|
||||
)
|
||||
await self.db_pool.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
|
|
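`block_room`, `room_is_blocked_by` and `unblock_room` all revolve around the blocked_rooms table keyed on room_id, with the most recent blocker's user_id recorded ("we'll only track the last user to block this room"). A compact sqlite3 sketch of those semantics, using an assumed two-column schema rather than Synapse's real DDL and ignoring cache invalidation:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE blocked_rooms (room_id TEXT PRIMARY KEY, user_id TEXT NOT NULL)"
)

def block_room(room_id: str, user_id: str) -> None:
    # Upsert: blocking twice just overwrites the recorded blocker.
    conn.execute(
        "INSERT INTO blocked_rooms VALUES (?, ?) "
        "ON CONFLICT(room_id) DO UPDATE SET user_id = excluded.user_id",
        (room_id, user_id),
    )

def room_is_blocked_by(room_id: str):
    row = conn.execute(
        "SELECT user_id FROM blocked_rooms WHERE room_id = ?", (room_id,)
    ).fetchone()
    return row[0] if row else None

def unblock_room(room_id: str) -> None:
    conn.execute("DELETE FROM blocked_rooms WHERE room_id = ?", (room_id,))

block_room("!r:example.org", "@admin:example.org")
block_room("!r:example.org", "@other_admin:example.org")
assert room_is_blocked_by("!r:example.org") == "@other_admin:example.org"
unblock_room("!r:example.org")
assert room_is_blocked_by("!r:example.org") is None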
@ -39,13 +39,11 @@ class RoomBatchStore(SQLBaseStore):
|
|||
|
||||
async def store_state_group_id_for_event_id(
|
||||
self, event_id: str, state_group_id: int
|
||||
) -> Optional[str]:
|
||||
{
|
||||
await self.db_pool.simple_upsert(
|
||||
table="event_to_state_groups",
|
||||
keyvalues={"event_id": event_id},
|
||||
values={"state_group": state_group_id, "event_id": event_id},
|
||||
# Unique constraint on event_id so we don't have to lock
|
||||
lock=False,
|
||||
)
|
||||
}
|
||||
) -> None:
|
||||
await self.db_pool.simple_upsert(
|
||||
table="event_to_state_groups",
|
||||
keyvalues={"event_id": event_id},
|
||||
values={"state_group": state_group_id, "event_id": event_id},
|
||||
# Unique constraint on event_id so we don't have to lock
|
||||
lock=False,
|
||||
)
|
||||
|
|
|
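The room_batch hunk above is a small bug fix: the old body wrapped the awaited upsert in `{ ... }`, which in Python is a set literal, so the code built and immediately discarded a one-element set, and the old `Optional[str]` return annotation never matched what was actually returned. A short demonstration of why those braces did nothing useful (the coroutine names are invented):

import asyncio

async def upsert() -> None:
    return None

async def old_style() -> None:
    # The braces build the set {None} from the awaited result and throw it
    # away; they are not a block delimiter.
    {await upsert()}

async def new_style() -> None:
    await upsert()

asyncio.run(old_style())
asyncio.run(new_style())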
@ -63,12 +63,12 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
|
||||
"""
|
||||
hashes = await self.get_event_reference_hashes(event_ids)
|
||||
hashes = {
|
||||
encoded_hashes = {
|
||||
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
||||
for e_id, h in hashes.items()
|
||||
}
|
||||
|
||||
return list(hashes.items())
|
||||
return list(encoded_hashes.items())
|
||||
|
||||
def _get_event_reference_hashes_txn(
|
||||
self, txn: Cursor, event_id: str
|
||||
|
|
|
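The signatures hunk renames the re-encoded mapping so it no longer shadows the raw `hashes` dict; each sha256 reference hash is turned into text with `encode_base64`. A standalone sketch of that transformation using hashlib and unpadded base64 (assuming the unpadded encoding used for Matrix hashes; the helper name and sample data below are invented):

import base64
import hashlib
from typing import Dict

def encode_unpadded_base64(data: bytes) -> str:
    return base64.b64encode(data).decode("ascii").rstrip("=")

# Raw reference hashes keyed by algorithm, as they might come off the database.
hashes: Dict[str, Dict[str, bytes]] = {
    "$event_a": {"sha256": hashlib.sha256(b"event a content").digest()},
}

encoded_hashes = {
    e_id: {k: encode_unpadded_base64(v) for k, v in h.items() if k == "sha256"}
    for e_id, h in hashes.items()
}
print(list(encoded_hashes.items()))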
@ -16,11 +16,17 @@ import logging
|
|||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateDeltasStore(SQLBaseStore):
|
||||
# This class must be mixed in with a child class which provides the following
|
||||
# attribute. TODO: can we get static analysis to enforce this?
|
||||
_curr_state_delta_stream_cache: StreamChangeCache
|
||||
|
||||
async def get_current_state_deltas(
|
||||
self, prev_stream_id: int, max_stream_id: int
|
||||
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
|
@ -60,7 +66,9 @@ class StateDeltasStore(SQLBaseStore):
|
|||
# max_stream_id.
|
||||
return max_stream_id, []
|
||||
|
||||
def get_current_state_deltas_txn(txn):
|
||||
def get_current_state_deltas_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
# First we calculate the max stream id that will give us less than
|
||||
# N results.
|
||||
# We arbitrarily limit to 100 stream_id entries to ensure we don't
|
||||
|
@ -106,7 +114,9 @@ class StateDeltasStore(SQLBaseStore):
|
|||
"get_current_state_deltas", get_current_state_deltas_txn
|
||||
)
|
||||
|
||||
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
|
||||
def _get_max_stream_id_in_current_state_deltas_txn(
|
||||
self, txn: LoggingTransaction
|
||||
) -> int:
|
||||
return self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="current_state_delta_stream",
|
||||
|
@ -114,7 +124,7 @@ class StateDeltasStore(SQLBaseStore):
|
|||
retcol="COALESCE(MAX(stream_id), -1)",
|
||||
)
|
||||
|
||||
async def get_max_stream_id_in_current_state_deltas(self):
|
||||
async def get_max_stream_id_in_current_state_deltas(self) -> int:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_max_stream_id_in_current_state_deltas",
|
||||
self._get_max_stream_id_in_current_state_deltas_txn,
|
||||
|
|
|
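The `_curr_state_delta_stream_cache: StreamChangeCache` line added above is a bare class-level annotation: it assigns nothing, but it documents, and lets mypy check, that any concrete class mixing StateDeltasStore in must provide that attribute. A generic sketch of the pattern, unrelated to Synapse's actual classes:

class Cache:
    def has_entity_changed(self, entity: str, stream_pos: int) -> bool:
        return True

class DeltasMixin:
    # Declared but not assigned: concrete subclasses are expected to set this,
    # and mypy will type-check its uses inside the mixin's methods.
    _cache: Cache

    def changed(self, entity: str, pos: int) -> bool:
        return self._cache.has_entity_changed(entity, pos)

class ConcreteStore(DeltasMixin):
    def __init__(self) -> None:
        self._cache = Cache()

assert ConcreteStore().changed("!room:example.org", 5)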
@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
|||
args = []
|
||||
|
||||
if event_filter.types:
|
||||
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
|
||||
)
|
||||
args.extend(event_filter.types)
|
||||
|
||||
for typ in event_filter.not_types:
|
||||
clauses.append("type != ?")
|
||||
clauses.append("event.type != ?")
|
||||
args.append(typ)
|
||||
|
||||
if event_filter.senders:
|
||||
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
|
||||
)
|
||||
args.extend(event_filter.senders)
|
||||
|
||||
for sender in event_filter.not_senders:
|
||||
clauses.append("sender != ?")
|
||||
clauses.append("event.sender != ?")
|
||||
args.append(sender)
|
||||
|
||||
if event_filter.rooms:
|
||||
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
|
||||
)
|
||||
args.extend(event_filter.rooms)
|
||||
|
||||
for room_id in event_filter.not_rooms:
|
||||
clauses.append("room_id != ?")
|
||||
clauses.append("event.room_id != ?")
|
||||
args.append(room_id)
|
||||
|
||||
if event_filter.contains_url:
|
||||
clauses.append("contains_url = ?")
|
||||
clauses.append("event.contains_url = ?")
|
||||
args.append(event_filter.contains_url)
|
||||
|
||||
# We're only applying the "labels" filter on the database query, because applying the
|
||||
|
@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
|||
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
||||
args.extend(event_filter.labels)
|
||||
|
||||
# Filter on relation_senders / relation types from the joined tables.
|
||||
if event_filter.relation_senders:
|
||||
clauses.append(
|
||||
"(%s)"
|
||||
% " OR ".join(
|
||||
"related_event.sender = ?" for _ in event_filter.relation_senders
|
||||
)
|
||||
)
|
||||
args.extend(event_filter.relation_senders)
|
||||
|
||||
if event_filter.relation_types:
|
||||
clauses.append(
|
||||
"(%s)"
|
||||
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
|
||||
)
|
||||
args.extend(event_filter.relation_types)
|
||||
|
||||
return " AND ".join(clauses), args
|
||||
|
||||
|
||||
|
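`filter_to_clause` now prefixes every events-table column with the `event.` alias and can emit additional clauses against the joined `related_event` sender and `relation_type` columns. A cut-down sketch of the clause-building shape, with a minimal filter object standing in for Synapse's Filter:

from typing import List, Tuple

class MiniFilter:
    def __init__(self, types: List[str], relation_types: List[str]) -> None:
        self.types = types
        self.relation_types = relation_types

def filter_to_clause(f: MiniFilter) -> Tuple[str, List[str]]:
    clauses: List[str] = []
    args: List[str] = []
    if f.types:
        clauses.append("(%s)" % " OR ".join("event.type = ?" for _ in f.types))
        args.extend(f.types)
    if f.relation_types:
        clauses.append(
            "(%s)" % " OR ".join("relation_type = ?" for _ in f.relation_types)
        )
        args.extend(f.relation_types)
    return " AND ".join(clauses), args

sql, args = filter_to_clause(MiniFilter(["m.room.message"], ["m.annotation"]))
# sql == "(event.type = ?) AND (relation_type = ?)"
assert args == ["m.room.message", "m.annotation"]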
@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
bounds = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("topological_ordering", "stream_ordering"),
|
||||
column_names=("event.topological_ordering", "event.stream_ordering"),
|
||||
from_token=from_bound,
|
||||
to_token=to_bound,
|
||||
engine=self.database_engine,
|
||||
|
@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
select_keywords = "SELECT"
|
||||
join_clause = ""
|
||||
# Using DISTINCT in this SELECT query is quite expensive, because it
|
||||
# requires the engine to sort on the entire (not limited) result set,
|
||||
# i.e. the entire events table. Only use it in scenarios that could result
|
||||
# in the same event ID occurring multiple times in the results.
|
||||
needs_distinct = False
|
||||
if event_filter and event_filter.labels:
|
||||
# If we're not filtering on a label, then joining on event_labels will
|
||||
# return as many row for a single event as the number of labels it has. To
|
||||
# avoid this, only join if we're filtering on at least one label.
|
||||
join_clause = """
|
||||
join_clause += """
|
||||
LEFT JOIN event_labels
|
||||
USING (event_id, room_id, topological_ordering)
|
||||
"""
|
||||
if len(event_filter.labels) > 1:
|
||||
# Using DISTINCT in this SELECT query is quite expensive, because it
|
||||
# requires the engine to sort on the entire (not limited) result set,
|
||||
# i.e. the entire events table. We only need to use it when we're
|
||||
# filtering on more than two labels, because that's the only scenario
|
||||
# in which we can possibly to get multiple times the same event ID in
|
||||
# the results.
|
||||
select_keywords += "DISTINCT"
|
||||
# Multiple labels could cause the same event to appear multiple times.
|
||||
needs_distinct = True
|
||||
|
||||
# If there is a filter on relation_senders and relation_types join to the
|
||||
# relations table.
|
||||
if event_filter and (
|
||||
event_filter.relation_senders or event_filter.relation_types
|
||||
):
|
||||
# Filtering by relations could cause the same event to appear multiple
|
||||
# times (since there's no limit on the number of relations to an event).
|
||||
needs_distinct = True
|
||||
join_clause += """
|
||||
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
|
||||
"""
|
||||
if event_filter.relation_senders:
|
||||
join_clause += """
|
||||
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
|
||||
"""
|
||||
|
||||
if needs_distinct:
|
||||
select_keywords += " DISTINCT"
|
||||
|
||||
sql = """
|
||||
%(select_keywords)s
|
||||
event_id, instance_name,
|
||||
topological_ordering, stream_ordering
|
||||
FROM events
|
||||
event.event_id, event.instance_name,
|
||||
event.topological_ordering, event.stream_ordering
|
||||
FROM events AS event
|
||||
%(join_clause)s
|
||||
WHERE outlier = ? AND room_id = ? AND %(bounds)s
|
||||
ORDER BY topological_ordering %(order)s,
|
||||
stream_ordering %(order)s LIMIT ?
|
||||
WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
|
||||
ORDER BY event.topological_ordering %(order)s,
|
||||
event.stream_ordering %(order)s LIMIT ?
|
||||
""" % {
|
||||
"select_keywords": select_keywords,
|
||||
"join_clause": join_clause,
|
||||
|
|
|
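The pagination hunk replaces the label-specific DISTINCT handling with a single `needs_distinct` flag, set whenever a join (on labels or on relations) could duplicate event rows, and the join clauses are now accumulated with `+=`. A skeletal sketch of how the SELECT is assembled, with booleans standing in for the filter checks:

def build_query(filter_labels: bool, filter_relations: bool) -> str:
    select_keywords = "SELECT"
    join_clause = ""
    needs_distinct = False

    if filter_labels:
        join_clause += (
            " LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
        )
        # The real code only needs DISTINCT when more than one label is filtered on.
        needs_distinct = True
    if filter_relations:
        join_clause += (
            " LEFT JOIN event_relations AS relation"
            " ON (event.event_id = relation.relates_to_id)"
        )
        needs_distinct = True  # an event can have any number of relations

    if needs_distinct:
        select_keywords += " DISTINCT"

    return (
        f"{select_keywords} event.event_id FROM events AS event{join_clause}"
        " WHERE event.outlier = ? AND event.room_id = ?"
    )

print(build_query(filter_labels=False, filter_relations=True))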
@ -1,5 +1,6 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,9 +15,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, cast
|
||||
|
||||
from synapse.storage._base import db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
|
@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
|
||||
async def get_all_updated_tags(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
|
||||
"""Get updates for tags replication stream.
|
||||
|
||||
Args:
|
||||
|
@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_updated_tags_txn(txn):
|
||||
def get_all_updated_tags_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, room_id"
|
||||
" FROM room_tags_revisions as r"
|
||||
|
@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
# mypy doesn't understand what the query is selecting.
|
||||
return cast(List[Tuple[int, str, str]], txn.fetchall())
|
||||
|
||||
tag_ids = await self.db_pool.runInteraction(
|
||||
"get_all_updated_tags", get_all_updated_tags_txn
|
||||
)
|
||||
|
||||
def get_tag_content(txn, tag_ids):
|
||||
def get_tag_content(
|
||||
txn: LoggingTransaction, tag_ids
|
||||
) -> List[Tuple[int, Tuple[str, str, str]]]:
|
||||
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
|
||||
results = []
|
||||
for stream_id, user_id, room_id in tag_ids:
|
||||
|
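`txn.fetchall()` comes back untyped, so the hunk above narrows the rows for mypy with `typing.cast` rather than changing any runtime behaviour. A small illustration of the same move on a plain sqlite3 cursor (table contents are invented):

import sqlite3
from typing import List, Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE room_tags_revisions (stream_id INTEGER, user_id TEXT, room_id TEXT)"
)
conn.execute(
    "INSERT INTO room_tags_revisions VALUES (1, '@a:example.org', '!r:example.org')"
)

cur = conn.execute("SELECT stream_id, user_id, room_id FROM room_tags_revisions")
# cast() is purely a static-typing assertion; it checks and converts nothing.
rows = cast(List[Tuple[int, str, str]], cur.fetchall())
assert rows == [(1, "@a:example.org", "!r:example.org")]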
@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
given version
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the tags for.
|
||||
stream_id(int): The earliest update to get for the user.
|
||||
user_id: The user to get the tags for.
|
||||
stream_id: The earliest update to get for the user.
|
||||
|
||||
Returns:
|
||||
A mapping from room_id strings to lists of tag strings for all the
|
||||
rooms that changed since the stream_id token.
|
||||
"""
|
||||
|
||||
def get_updated_tags_txn(txn):
|
||||
def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:
|
||||
sql = (
|
||||
"SELECT room_id from room_tags_revisions"
|
||||
" WHERE user_id = ? AND stream_id > ?"
|
||||
|
@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
def add_tag_txn(txn, next_id):
|
||||
def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="room_tags",
|
||||
|
@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
"""
|
||||
assert self._can_write_to_account_data
|
||||
|
||||
def remove_tag_txn(txn, next_id):
|
||||
def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
|
||||
sql = (
|
||||
"DELETE FROM room_tags "
|
||||
" WHERE user_id = ? AND room_id = ? AND tag = ?"
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
|
||||
class UserErasureWorkerStore(SQLBaseStore):
|
||||
class UserErasureWorkerStore(CacheInvalidationWorkerStore):
|
||||
@cached()
|
||||
async def is_user_erased(self, user_id: str) -> bool:
|
||||
"""
|
||||
|
@ -69,7 +70,7 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||
user_id: full user_id to be erased
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
# first check if they are already in the list
|
||||
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
|
||||
if txn.fetchone():
|
||||
|
@ -89,7 +90,7 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||
user_id: full user_id to be un-erased
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
# first check if they are already in the list
|
||||
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
|
||||
if not txn.fetchone():
|
||||
|
|
|
@ -45,10 +45,13 @@ Changes in SCHEMA_VERSION = 64:
|
|||
Changes in SCHEMA_VERSION = 65:
|
||||
- MSC2716: Remove unique event_id constraint from insertion_event_edges
|
||||
because an insertion event can have multiple edges.
|
||||
- Remove unused tables `user_stats_historical` and `room_stats_historical`.
|
||||
"""
|
||||
|
||||
|
||||
SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata.
|
||||
SCHEMA_COMPAT_VERSION = (
|
||||
61 # 61: Remove unused tables `user_stats_historical` and `room_stats_historical`
|
||||
)
|
||||
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|
||||
|
||||
This value is stored in the database, and checked on startup. If the value in the
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Remove unused tables room_stats_historical and user_stats_historical
|
||||
-- which have not been read or written since schema version 61.
|
||||
DROP TABLE IF EXISTS room_stats_historical;
|
||||
DROP TABLE IF EXISTS user_stats_historical;
|
|
@ -15,4 +15,4 @@
|
|||
|
||||
-- Check old events for thread relations.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(6502, 'event_thread_relation', '{}');
|
||||
(6507, 'event_arbitrary_relations', '{}');
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Background update to clear the inboxes of hidden and deleted devices.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(6508, 'remove_dead_devices_from_device_inbox', '{}');
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -11,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
|
@ -87,7 +89,25 @@ def _load_current_id(
|
|||
return (max if step > 0 else min)(current_id, step)
|
||||
|
||||
|
||||
class StreamIdGenerator:
|
||||
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_token(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Used to generate new stream ids when persisting events while keeping
|
||||
track of which transactions have been completed.
|
||||
|
||||
|
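`AbstractStreamIdGenerator` gives `StreamIdGenerator` and `MultiWriterIdGenerator` a shared interface, so calling code can be typed against the abstraction rather than one concrete generator. A minimal, Synapse-independent sketch of the pattern:

import abc

class AbstractIdGenerator(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def get_current_token(self) -> int:
        raise NotImplementedError()

class SingleWriterIdGenerator(AbstractIdGenerator):
    def __init__(self) -> None:
        self._current = 0

    def get_current_token(self) -> int:
        return self._current

def latest_position(gen: AbstractIdGenerator) -> int:
    # Typed against the ABC: any conforming generator can be passed in.
    return gen.get_current_token()

assert latest_position(SingleWriterIdGenerator()) == 0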
@ -209,7 +229,7 @@ class StreamIdGenerator:
|
|||
return self.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator:
|
||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||
"""An ID generator that tracks a stream that can have multiple writers.
|
||||
|
||||
Uses a Postgres sequence to coordinate ID assignment, but positions of other
|
||||
|
|