mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-15 15:00:14 -04:00
Merge remote-tracking branch 'upstream/release-v1.57'
This commit is contained in:
commit
b2fa6ec9f6
248 changed files with 14616 additions and 8934 deletions
|
@ -241,9 +241,17 @@ class LoggingTransaction:
|
|||
self.exception_callbacks = exception_callbacks
|
||||
|
||||
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
|
||||
"""Call the given callback on the main twisted thread after the
|
||||
transaction has finished. Used to invalidate the caches on the
|
||||
correct thread.
|
||||
"""Call the given callback on the main twisted thread after the transaction has
|
||||
finished.
|
||||
|
||||
Mostly used to invalidate the caches on the correct thread.
|
||||
|
||||
Note that transactions may be retried a few times if they encounter database
|
||||
errors such as serialization failures. Callbacks given to `call_after`
|
||||
will accumulate across transaction attempts and will _all_ be called once a
|
||||
transaction attempt succeeds, regardless of whether previous transaction
|
||||
attempts failed. Otherwise, if all transaction attempts fail, all
|
||||
`call_on_exception` callbacks will be run instead.
|
||||
"""
|
||||
# if self.after_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
|
@ -254,6 +262,15 @@ class LoggingTransaction:
|
|||
def call_on_exception(
|
||||
self, callback: Callable[..., object], *args: Any, **kwargs: Any
|
||||
):
|
||||
"""Call the given callback on the main twisted thread after the transaction has
|
||||
failed.
|
||||
|
||||
Note that transactions may be retried a few times if they encounter database
|
||||
errors such as serialization failures. Callbacks given to `call_on_exception`
|
||||
will accumulate across transaction attempts and will _all_ be called once the
|
||||
final transaction attempt fails. No `call_on_exception` callbacks will be run
|
||||
if any transaction attempt succeeds.
|
||||
"""
|
||||
# if self.exception_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
|
@ -1251,6 +1268,7 @@ class DatabasePool:
|
|||
value_names: Collection[str],
|
||||
value_values: Collection[Collection[Any]],
|
||||
desc: str,
|
||||
lock: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Upsert, many times.
|
||||
|
@ -1262,6 +1280,8 @@ class DatabasePool:
|
|||
value_names: The value column names
|
||||
value_values: A list of each row's value column values.
|
||||
Ignored if value_names is empty.
|
||||
lock: True to lock the table when doing the upsert. Unused if the database engine
|
||||
supports native upserts.
|
||||
"""
|
||||
|
||||
# We can autocommit if we are going to use native upserts
|
||||
|
@ -1269,7 +1289,7 @@ class DatabasePool:
|
|||
self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
|
||||
)
|
||||
|
||||
return await self.runInteraction(
|
||||
await self.runInteraction(
|
||||
desc,
|
||||
self.simple_upsert_many_txn,
|
||||
table,
|
||||
|
@ -1277,6 +1297,7 @@ class DatabasePool:
|
|||
key_values,
|
||||
value_names,
|
||||
value_values,
|
||||
lock=lock,
|
||||
db_autocommit=autocommit,
|
||||
)
|
||||
|
||||
|
@ -1288,6 +1309,7 @@ class DatabasePool:
|
|||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[Any]],
|
||||
lock: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Upsert, many times.
|
||||
|
@ -1299,6 +1321,8 @@ class DatabasePool:
|
|||
value_names: The value column names
|
||||
value_values: A list of each row's value column values.
|
||||
Ignored if value_names is empty.
|
||||
lock: True to lock the table when doing the upsert. Unused if the database engine
|
||||
supports native upserts.
|
||||
"""
|
||||
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
|
||||
return self.simple_upsert_many_txn_native_upsert(
|
||||
|
@ -1306,7 +1330,7 @@ class DatabasePool:
|
|||
)
|
||||
else:
|
||||
return self.simple_upsert_many_txn_emulated(
|
||||
txn, table, key_names, key_values, value_names, value_values
|
||||
txn, table, key_names, key_values, value_names, value_values, lock=lock
|
||||
)
|
||||
|
||||
def simple_upsert_many_txn_emulated(
|
||||
|
@ -1317,6 +1341,7 @@ class DatabasePool:
|
|||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[Any]],
|
||||
lock: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Upsert, many times, but without native UPSERT support or batching.
|
||||
|
@ -1328,17 +1353,24 @@ class DatabasePool:
|
|||
value_names: The value column names
|
||||
value_values: A list of each row's value column values.
|
||||
Ignored if value_names is empty.
|
||||
lock: True to lock the table when doing the upsert.
|
||||
"""
|
||||
# No value columns, therefore make a blank list so that the following
|
||||
# zip() works correctly.
|
||||
if not value_names:
|
||||
value_values = [() for x in range(len(key_values))]
|
||||
|
||||
if lock:
|
||||
# Lock the table just once, to prevent it being done once per row.
|
||||
# Note that, according to Postgres' documentation, once obtained,
|
||||
# the lock is held for the remainder of the current transaction.
|
||||
self.engine.lock_table(txn, "user_ips")
|
||||
|
||||
for keyv, valv in zip(key_values, value_values):
|
||||
_keys = {x: y for x, y in zip(key_names, keyv)}
|
||||
_vals = {x: y for x, y in zip(value_names, valv)}
|
||||
|
||||
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
|
||||
self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
|
||||
|
||||
def simple_upsert_many_txn_native_upsert(
|
||||
self,
|
||||
|
@ -1775,6 +1807,86 @@ class DatabasePool:
|
|||
|
||||
return txn.rowcount
|
||||
|
||||
async def simple_update_many(
|
||||
self,
|
||||
table: str,
|
||||
key_names: Collection[str],
|
||||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[Any]],
|
||||
desc: str,
|
||||
) -> None:
|
||||
"""
|
||||
Update, many times, using batching where possible.
|
||||
If the keys don't match anything, nothing will be updated.
|
||||
|
||||
Args:
|
||||
table: The table to update
|
||||
key_names: The key column names.
|
||||
key_values: A list of each row's key column values.
|
||||
value_names: The names of value columns to update.
|
||||
value_values: A list of each row's value column values.
|
||||
"""
|
||||
|
||||
await self.runInteraction(
|
||||
desc,
|
||||
self.simple_update_many_txn,
|
||||
table,
|
||||
key_names,
|
||||
key_values,
|
||||
value_names,
|
||||
value_values,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def simple_update_many_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
key_names: Collection[str],
|
||||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Collection[Iterable[Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Update, many times, using batching where possible.
|
||||
If the keys don't match anything, nothing will be updated.
|
||||
|
||||
Args:
|
||||
table: The table to update
|
||||
key_names: The key column names.
|
||||
key_values: A list of each row's key column values.
|
||||
value_names: The names of value columns to update.
|
||||
value_values: A list of each row's value column values.
|
||||
"""
|
||||
|
||||
if len(value_values) != len(key_values):
|
||||
raise ValueError(
|
||||
f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
|
||||
)
|
||||
|
||||
# List of tuples of (value values, then key values)
|
||||
# (This matches the order needed for the query)
|
||||
args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
|
||||
|
||||
for ks, vs in zip(key_values, value_values):
|
||||
args.append(tuple(vs) + tuple(ks))
|
||||
|
||||
# 'col1 = ?, col2 = ?, ...'
|
||||
set_clause = ", ".join(f"{n} = ?" for n in value_names)
|
||||
|
||||
if key_names:
|
||||
# 'WHERE col3 = ? AND col4 = ? AND col5 = ?'
|
||||
where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names))
|
||||
else:
|
||||
where_clause = ""
|
||||
|
||||
# UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
|
||||
sql = f"""
|
||||
UPDATE {table} SET {set_clause} {where_clause}
|
||||
"""
|
||||
|
||||
txn.execute_batch(sql, args)
|
||||
|
||||
async def simple_update_one(
|
||||
self,
|
||||
table: str,
|
||||
|
@ -2013,29 +2125,40 @@ class DatabasePool:
|
|||
max_value: int,
|
||||
limit: int = 100000,
|
||||
) -> Tuple[Dict[Any, int], int]:
|
||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||
# It doesn't really matter how many we get, the StreamChangeCache will
|
||||
# do the right thing to ensure it respects the max size of cache.
|
||||
sql = (
|
||||
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
|
||||
" WHERE %(stream)s > ? - %(limit)s"
|
||||
" GROUP BY %(entity)s"
|
||||
) % {
|
||||
"table": table,
|
||||
"entity": entity_column,
|
||||
"stream": stream_column,
|
||||
"limit": limit,
|
||||
}
|
||||
"""Gets roughly the last N changes in the given stream table as a
|
||||
map from entity to the stream ID of the most recent change.
|
||||
|
||||
Also returns the minimum stream ID.
|
||||
"""
|
||||
|
||||
# This may return many rows for the same entity, but the `limit` is only
|
||||
# a suggestion so we don't care that much.
|
||||
#
|
||||
# Note: Some stream tables can have multiple rows with the same stream
|
||||
# ID. Instead of handling this with complicated SQL, we instead simply
|
||||
# add one to the returned minimum stream ID to ensure correctness.
|
||||
sql = f"""
|
||||
SELECT {entity_column}, {stream_column}
|
||||
FROM {table}
|
||||
ORDER BY {stream_column} DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn = db_conn.cursor(txn_name="get_cache_dict")
|
||||
txn.execute(sql, (int(max_value),))
|
||||
txn.execute(sql, (limit,))
|
||||
|
||||
cache = {row[0]: int(row[1]) for row in txn}
|
||||
# The rows come out in reverse stream ID order, so we want to keep the
|
||||
# stream ID of the first row for each entity.
|
||||
cache: Dict[Any, int] = {}
|
||||
for row in txn:
|
||||
cache.setdefault(row[0], int(row[1]))
|
||||
|
||||
txn.close()
|
||||
|
||||
if cache:
|
||||
min_val = min(cache.values())
|
||||
# We add one here as we don't know if we have all rows for the
|
||||
# minimum stream ID.
|
||||
min_val = min(cache.values()) + 1
|
||||
else:
|
||||
min_val = max_value
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ from .account_data import AccountDataStore
|
|||
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
from .cache import CacheInvalidationWorkerStore
|
||||
from .censor_events import CensorEventsStore
|
||||
from .client_ips import ClientIpStore
|
||||
from .client_ips import ClientIpWorkerStore
|
||||
from .deviceinbox import DeviceInboxStore
|
||||
from .devices import DeviceStore
|
||||
from .directory import DirectoryStore
|
||||
|
@ -49,7 +49,7 @@ from .keys import KeyStore
|
|||
from .lock import LockStore
|
||||
from .media_repository import MediaRepositoryStore
|
||||
from .metrics import ServerMetricsStore
|
||||
from .monthly_active_users import MonthlyActiveUsersStore
|
||||
from .monthly_active_users import MonthlyActiveUsersWorkerStore
|
||||
from .openid import OpenIdStore
|
||||
from .presence import PresenceStore
|
||||
from .profile import ProfileStore
|
||||
|
@ -112,13 +112,13 @@ class DataStore(
|
|||
AccountDataStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
ClientIpStore,
|
||||
ClientIpWorkerStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
UserErasureStore,
|
||||
MonthlyActiveUsersStore,
|
||||
MonthlyActiveUsersWorkerStore,
|
||||
StatsStore,
|
||||
RelationsStore,
|
||||
CensorEventsStore,
|
||||
|
@ -146,6 +146,7 @@ class DataStore(
|
|||
extra_tables=[
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -182,17 +183,6 @@ class DataStore(
|
|||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
self._device_list_stream_cache = StreamChangeCache(
|
||||
"DeviceListStreamChangeCache", device_list_max
|
||||
)
|
||||
self._user_signature_stream_cache = StreamChangeCache(
|
||||
"UserSignatureStreamChangeCache", device_list_max
|
||||
)
|
||||
self._device_list_federation_stream_cache = StreamChangeCache(
|
||||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
|
||||
|
||||
from synapse.appservice import (
|
||||
ApplicationService,
|
||||
|
@ -26,10 +26,16 @@ from synapse.appservice import (
|
|||
from synapse.config.appservice import load_appservices
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import DeviceListUpdates, JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
|
||||
|
@ -72,9 +78,25 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
|
|||
)
|
||||
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
||||
|
||||
def get_max_as_txn_id(txn: Cursor) -> int:
|
||||
logger.warning("Falling back to slow query, you should port to postgres")
|
||||
txn.execute(
|
||||
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
|
||||
)
|
||||
return txn.fetchone()[0] # type: ignore
|
||||
|
||||
self._as_txn_seq_gen = build_sequence_generator(
|
||||
db_conn,
|
||||
database.engine,
|
||||
get_max_as_txn_id,
|
||||
"application_services_txn_id_seq",
|
||||
table="application_services_txns",
|
||||
id_column="txn_id",
|
||||
)
|
||||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
def get_app_services(self):
|
||||
def get_app_services(self) -> List[ApplicationService]:
|
||||
return self.services_cache
|
||||
|
||||
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||
|
@ -217,6 +239,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
to_device_messages: List[JsonDict],
|
||||
one_time_key_counts: TransactionOneTimeKeyCounts,
|
||||
unused_fallback_keys: TransactionUnusedFallbackKeys,
|
||||
device_list_summary: DeviceListUpdates,
|
||||
) -> AppServiceTransaction:
|
||||
"""Atomically creates a new transaction for this application service
|
||||
with the given list of events. Ephemeral events are NOT persisted to the
|
||||
|
@ -231,27 +254,14 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
appservice devices in the transaction.
|
||||
unused_fallback_keys: Lists of unused fallback keys for relevant
|
||||
appservice devices in the transaction.
|
||||
device_list_summary: The device list summary to include in the transaction.
|
||||
|
||||
Returns:
|
||||
A new transaction.
|
||||
"""
|
||||
|
||||
def _create_appservice_txn(txn):
|
||||
# work out new txn id (highest txn id for this service += 1)
|
||||
# The highest id may be the last one sent (in which case it is last_txn)
|
||||
# or it may be the highest in the txns list (which are waiting to be/are
|
||||
# being sent)
|
||||
last_txn_id = self._get_last_txn(txn, service.id)
|
||||
|
||||
txn.execute(
|
||||
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
||||
(service.id,),
|
||||
)
|
||||
highest_txn_id = txn.fetchone()[0]
|
||||
if highest_txn_id is None:
|
||||
highest_txn_id = 0
|
||||
|
||||
new_txn_id = max(highest_txn_id, last_txn_id) + 1
|
||||
def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
|
||||
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
|
||||
|
||||
# Insert new txn into txn table
|
||||
event_ids = json_encoder.encode([e.event_id for e in events])
|
||||
|
@ -268,6 +278,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
to_device_messages=to_device_messages,
|
||||
one_time_key_counts=one_time_key_counts,
|
||||
unused_fallback_keys=unused_fallback_keys,
|
||||
device_list_summary=device_list_summary,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -283,25 +294,8 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
txn_id: The transaction ID being completed.
|
||||
service: The application service which was sent this transaction.
|
||||
"""
|
||||
txn_id = int(txn_id)
|
||||
|
||||
def _complete_appservice_txn(txn):
|
||||
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
||||
# what was there before. If it isn't, we've got problems (e.g. the AS
|
||||
# has probably missed some events), so whine loudly but still continue,
|
||||
# since it shouldn't fail completion of the transaction.
|
||||
last_txn_id = self._get_last_txn(txn, service.id)
|
||||
if (last_txn_id + 1) != txn_id:
|
||||
logger.error(
|
||||
"appservice: Completing a transaction which has an ID > 1 from "
|
||||
"the last ID sent to this AS. We've either dropped events or "
|
||||
"sent it to the AS out of order. FIX ME. last_txn=%s "
|
||||
"completing_txn=%s service_id=%s",
|
||||
last_txn_id,
|
||||
txn_id,
|
||||
service.id,
|
||||
)
|
||||
|
||||
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
|
||||
# Set current txn_id for AS to 'txn_id'
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
|
@ -332,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
An AppServiceTransaction or None.
|
||||
"""
|
||||
|
||||
def _get_oldest_unsent_txn(txn):
|
||||
def _get_oldest_unsent_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# Monotonically increasing txn ids, so just select the smallest
|
||||
# one in the txns table (we delete them when they are sent)
|
||||
txn.execute(
|
||||
|
@ -359,8 +355,8 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
|
||||
# TODO: to-device messages, one-time key counts and unused fallback keys
|
||||
# are not yet populated for catch-up transactions.
|
||||
# TODO: to-device messages, one-time key counts, device list summaries and unused
|
||||
# fallback keys are not yet populated for catch-up transactions.
|
||||
# We likely want to populate those for reliability.
|
||||
return AppServiceTransaction(
|
||||
service=service,
|
||||
|
@ -370,21 +366,11 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
to_device_messages=[],
|
||||
one_time_key_counts={},
|
||||
unused_fallback_keys={},
|
||||
device_list_summary=DeviceListUpdates(),
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
(service_id,),
|
||||
)
|
||||
last_txn_id = txn.fetchone()
|
||||
if last_txn_id is None or last_txn_id[0] is None: # no row exists
|
||||
return 0
|
||||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
async def set_appservice_last_pos(self, pos: int) -> None:
|
||||
def set_appservice_last_pos_txn(txn):
|
||||
def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
|
@ -398,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
) -> Tuple[int, List[EventBase]]:
|
||||
"""Get all new events for an appservice"""
|
||||
|
||||
def get_new_events_for_appservice_txn(txn):
|
||||
def get_new_events_for_appservice_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[int, List[str]]:
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id"
|
||||
" FROM events AS e"
|
||||
|
@ -430,13 +418,13 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
async def get_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, type: str
|
||||
) -> int:
|
||||
if type not in ("read_receipt", "presence", "to_device"):
|
||||
if type not in ("read_receipt", "presence", "to_device", "device_list"):
|
||||
raise ValueError(
|
||||
"Expected type to be a valid application stream id type, got %s"
|
||||
% (type,)
|
||||
)
|
||||
|
||||
def get_type_stream_id_for_appservice_txn(txn):
|
||||
def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
|
||||
stream_id_type = "%s_stream_id" % type
|
||||
txn.execute(
|
||||
# We do NOT want to escape `stream_id_type`.
|
||||
|
@ -446,7 +434,8 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
)
|
||||
last_stream_id = txn.fetchone()
|
||||
if last_stream_id is None or last_stream_id[0] is None: # no row exists
|
||||
return 0
|
||||
# Stream tokens always start from 1, to avoid foot guns around `0` being falsey.
|
||||
return 1
|
||||
else:
|
||||
return int(last_stream_id[0])
|
||||
|
||||
|
@ -457,13 +446,13 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
async def set_appservice_stream_type_pos(
|
||||
self, service: ApplicationService, stream_type: str, pos: Optional[int]
|
||||
) -> None:
|
||||
if stream_type not in ("read_receipt", "presence", "to_device"):
|
||||
if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
|
||||
raise ValueError(
|
||||
"Expected type to be a valid application stream id type, got %s"
|
||||
% (stream_type,)
|
||||
)
|
||||
|
||||
def set_appservice_stream_type_pos_txn(txn):
|
||||
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
|
||||
stream_id_type = "%s_stream_id" % stream_type
|
||||
txn.execute(
|
||||
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
|
||||
|
|
|
@ -25,7 +25,9 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
make_tuple_comparison_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
|
||||
from synapse.storage.databases.main.monthly_active_users import (
|
||||
MonthlyActiveUsersWorkerStore,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
|
@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
|||
return updated
|
||||
|
||||
|
||||
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
||||
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
if hs.config.redis.redis_enabled:
|
||||
# If we're using Redis, we can shift this update process off to
|
||||
# the background worker
|
||||
self._update_on_this_worker = hs.config.worker.run_background_tasks
|
||||
else:
|
||||
# If we're NOT using Redis, this must be handled by the master
|
||||
self._update_on_this_worker = hs.get_instance_name() == "master"
|
||||
|
||||
self.user_ips_max_age = hs.config.server.user_ips_max_age
|
||||
|
||||
# (user_id, access_token, ip,) -> last_seen
|
||||
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
|
||||
cache_name="client_ip_last_seen", max_size=50000
|
||||
)
|
||||
|
||||
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
|
||||
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
||||
|
||||
if self._update_on_this_worker:
|
||||
# This is the designated worker that can write to the client IP
|
||||
# tables.
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
self._batch_row_update: Dict[
|
||||
Tuple[str, str, str], Tuple[str, Optional[str], int]
|
||||
] = {}
|
||||
|
||||
self._client_ip_looper = self._clock.looping_call(
|
||||
self._update_client_ips_batch, 5 * 1000
|
||||
)
|
||||
self.hs.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", self._update_client_ips_batch
|
||||
)
|
||||
|
||||
@wrap_as_background_process("prune_old_user_ips")
|
||||
async def _prune_old_user_ips(self) -> None:
|
||||
"""Removes entries in user IPs older than the configured period."""
|
||||
|
@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
|||
"_prune_old_user_ips", _prune_old_user_ips_txn
|
||||
)
|
||||
|
||||
async def get_last_client_ip_by_device(
|
||||
async def _get_last_client_ip_by_device_from_database(
|
||||
self, user_id: str, device_id: Optional[str]
|
||||
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
|
||||
"""For each device_id listed, give the user_ip it was last seen on.
|
||||
|
@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
|||
|
||||
return {(d["user_id"], d["device_id"]): d for d in res}
|
||||
|
||||
async def get_user_ip_and_agents(
|
||||
async def _get_user_ip_and_agents_from_database(
|
||||
self, user: UserID, since_ts: int = 0
|
||||
) -> List[LastConnectionInfo]:
|
||||
"""Fetch the IPs and user agents for a user since the given timestamp.
|
||||
|
@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
|||
for access_token, ip, user_agent, last_seen in rows
|
||||
]
|
||||
|
||||
|
||||
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
|
||||
# (user_id, access_token, ip,) -> last_seen
|
||||
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
|
||||
cache_name="client_ip_last_seen", max_size=50000
|
||||
)
|
||||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
self._batch_row_update: Dict[
|
||||
Tuple[str, str, str], Tuple[str, Optional[str], int]
|
||||
] = {}
|
||||
|
||||
self._client_ip_looper = self._clock.looping_call(
|
||||
self._update_client_ips_batch, 5 * 1000
|
||||
)
|
||||
self.hs.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", self._update_client_ips_batch
|
||||
)
|
||||
|
||||
async def insert_client_ip(
|
||||
self,
|
||||
user_id: str,
|
||||
|
@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
|
|||
last_seen = self.client_ip_last_seen.get(key)
|
||||
except KeyError:
|
||||
last_seen = None
|
||||
await self.populate_monthly_active_users(user_id)
|
||||
|
||||
# Rate-limited inserts
|
||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||
return
|
||||
|
||||
self.client_ip_last_seen.set(key, now)
|
||||
|
||||
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||
if self._update_on_this_worker:
|
||||
await self.populate_monthly_active_users(user_id)
|
||||
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||
else:
|
||||
# We are not the designated writer-worker, so stream over replication
|
||||
self.hs.get_replication_command_handler().send_user_ip(
|
||||
user_id, access_token, ip, user_agent, device_id, now
|
||||
)
|
||||
|
||||
@wrap_as_background_process("update_client_ips")
|
||||
async def _update_client_ips_batch(self) -> None:
|
||||
assert (
|
||||
self._update_on_this_worker
|
||||
), "This worker is not designated to update client IPs"
|
||||
|
||||
# If the DB pool has already terminated, don't try updating
|
||||
if not self.db_pool.is_running():
|
||||
|
@ -603,51 +616,57 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
|
|||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
)
|
||||
if to_update:
|
||||
await self.db_pool.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
)
|
||||
|
||||
def _update_client_ips_batch_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
|
||||
) -> None:
|
||||
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
|
||||
not self.database_engine.can_native_upsert
|
||||
):
|
||||
self.database_engine.lock_table(txn, "user_ips")
|
||||
assert (
|
||||
self._update_on_this_worker
|
||||
), "This worker is not designated to update client IPs"
|
||||
|
||||
# Keys and values for the `user_ips` upsert.
|
||||
user_ips_keys = []
|
||||
user_ips_values = []
|
||||
|
||||
# Keys and values for the `devices` update.
|
||||
devices_keys = []
|
||||
devices_values = []
|
||||
|
||||
for entry in to_update.items():
|
||||
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
|
||||
values={
|
||||
"user_agent": user_agent,
|
||||
"device_id": device_id,
|
||||
"last_seen": last_seen,
|
||||
},
|
||||
lock=False,
|
||||
)
|
||||
user_ips_keys.append((user_id, access_token, ip))
|
||||
user_ips_values.append((user_agent, device_id, last_seen))
|
||||
|
||||
# Technically an access token might not be associated with
|
||||
# a device so we need to check.
|
||||
if device_id:
|
||||
# this is always an update rather than an upsert: the row should
|
||||
# already exist, and if it doesn't, that may be because it has been
|
||||
# deleted, and we don't want to re-create it.
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
updatevalues={
|
||||
"user_agent": user_agent,
|
||||
"last_seen": last_seen,
|
||||
"ip": ip,
|
||||
},
|
||||
)
|
||||
devices_keys.append((user_id, device_id))
|
||||
devices_values.append((user_agent, last_seen, ip))
|
||||
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn,
|
||||
table="user_ips",
|
||||
key_names=("user_id", "access_token", "ip"),
|
||||
key_values=user_ips_keys,
|
||||
value_names=("user_agent", "device_id", "last_seen"),
|
||||
value_values=user_ips_values,
|
||||
)
|
||||
|
||||
if devices_values:
|
||||
self.db_pool.simple_update_many_txn(
|
||||
txn,
|
||||
table="devices",
|
||||
key_names=("user_id", "device_id"),
|
||||
key_values=devices_keys,
|
||||
value_names=("user_agent", "last_seen", "ip"),
|
||||
value_values=devices_values,
|
||||
)
|
||||
|
||||
async def get_last_client_ip_by_device(
|
||||
self, user_id: str, device_id: Optional[str]
|
||||
|
@ -662,7 +681,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
|
|||
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
|
||||
keys giving the column names from the devices table.
|
||||
"""
|
||||
ret = await super().get_last_client_ip_by_device(user_id, device_id)
|
||||
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
|
||||
|
||||
if not self._update_on_this_worker:
|
||||
# Only the writing-worker has additional in-memory data to enhance
|
||||
# the result
|
||||
return ret
|
||||
|
||||
# Update what is retrieved from the database with data which is pending
|
||||
# insertion, as if it has already been stored in the database.
|
||||
|
@ -707,9 +731,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
|
|||
Only the latest user agent for each access token and IP address combination
|
||||
is available.
|
||||
"""
|
||||
rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
|
||||
|
||||
if not self._update_on_this_worker:
|
||||
# Only the writing-worker has additional in-memory data to enhance
|
||||
# the result
|
||||
return rows_from_db
|
||||
|
||||
results: Dict[Tuple[str, str], LastConnectionInfo] = {
|
||||
(connection["access_token"], connection["ip"]): connection
|
||||
for connection in await super().get_user_ip_and_agents(user, since_ts)
|
||||
for connection in rows_from_db
|
||||
}
|
||||
|
||||
# Overlay data that is pending insertion on top of the results from the
|
||||
|
|
|
@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
|||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.stringutils import shortstr
|
||||
|
||||
|
@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"device_lists_stream",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=device_list_max,
|
||||
limit=10000,
|
||||
)
|
||||
self._device_list_stream_cache = StreamChangeCache(
|
||||
"DeviceListStreamChangeCache",
|
||||
min_device_list_id,
|
||||
prefilled_cache=device_list_prefill,
|
||||
)
|
||||
|
||||
(
|
||||
user_signature_stream_prefill,
|
||||
user_signature_stream_list_id,
|
||||
) = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"user_signature_stream",
|
||||
entity_column="from_user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=device_list_max,
|
||||
limit=1000,
|
||||
)
|
||||
self._user_signature_stream_cache = StreamChangeCache(
|
||||
"UserSignatureStreamChangeCache",
|
||||
user_signature_stream_list_id,
|
||||
prefilled_cache=user_signature_stream_prefill,
|
||||
)
|
||||
|
||||
(
|
||||
device_list_federation_prefill,
|
||||
device_list_federation_list_id,
|
||||
) = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"device_lists_outbound_pokes",
|
||||
entity_column="destination",
|
||||
stream_column="stream_id",
|
||||
max_value=device_list_max,
|
||||
limit=10000,
|
||||
)
|
||||
self._device_list_federation_stream_cache = StreamChangeCache(
|
||||
"DeviceListFederationStreamChangeCache",
|
||||
device_list_federation_list_id,
|
||||
prefilled_cache=device_list_federation_prefill,
|
||||
)
|
||||
|
||||
if hs.config.worker.run_background_tasks:
|
||||
self._clock.looping_call(
|
||||
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
|
||||
|
@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||
|
||||
async def get_users_whose_devices_changed(
|
||||
self, from_key: int, user_ids: Iterable[str]
|
||||
self,
|
||||
from_key: int,
|
||||
user_ids: Optional[Iterable[str]] = None,
|
||||
to_key: Optional[int] = None,
|
||||
) -> Set[str]:
|
||||
"""Get set of users whose devices have changed since `from_key` that
|
||||
are in the given list of user_ids.
|
||||
|
||||
Args:
|
||||
from_key: The device lists stream token
|
||||
user_ids: The user IDs to query for devices.
|
||||
from_key: The minimum device lists stream token to query device list changes for,
|
||||
exclusive.
|
||||
user_ids: If provided, only check if these users have changed their device lists.
|
||||
Otherwise changes from all users are returned.
|
||||
to_key: The maximum device lists stream token to query device list changes for,
|
||||
inclusive.
|
||||
|
||||
Returns:
|
||||
The set of user_ids whose devices have changed since `from_key`
|
||||
The set of user_ids whose devices have changed since `from_key` (exclusive)
|
||||
until `to_key` (inclusive).
|
||||
"""
|
||||
|
||||
# Get set of users who *may* have changed. Users not in the returned
|
||||
# list have definitely not changed.
|
||||
to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
if user_ids is None:
|
||||
# Get set of all users that have had device list changes since 'from_key'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
||||
from_key
|
||||
)
|
||||
else:
|
||||
# The same as above, but filter results to only those users in 'user_ids'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
|
||||
if not to_check:
|
||||
if not user_ids_to_check:
|
||||
return set()
|
||||
|
||||
def _get_users_whose_devices_changed_txn(txn):
|
||||
changes = set()
|
||||
|
||||
sql = """
|
||||
stream_id_where_clause = "stream_id > ?"
|
||||
sql_args = [from_key]
|
||||
|
||||
if to_key:
|
||||
stream_id_where_clause += " AND stream_id <= ?"
|
||||
sql_args.append(to_key)
|
||||
|
||||
sql = f"""
|
||||
SELECT DISTINCT user_id FROM device_lists_stream
|
||||
WHERE stream_id > ?
|
||||
WHERE {stream_id_where_clause}
|
||||
AND
|
||||
"""
|
||||
|
||||
for chunk in batch_iter(to_check, 100):
|
||||
# Query device changes with a batch of users at a time
|
||||
for chunk in batch_iter(user_ids_to_check, 100):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", chunk
|
||||
)
|
||||
txn.execute(sql + clause, (from_key,) + tuple(args))
|
||||
txn.execute(sql + clause, sql_args + args)
|
||||
changes.update(user_id for user_id, in txn)
|
||||
|
||||
return changes
|
||||
|
@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
|
||||
) AS e
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
|
@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
async def add_device_change_to_streams(
|
||||
self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
|
||||
self,
|
||||
user_id: str,
|
||||
device_ids: Collection[str],
|
||||
hosts: Optional[Collection[str]],
|
||||
room_ids: Collection[str],
|
||||
) -> Optional[int]:
|
||||
"""Persist that a user's devices have been updated, and which hosts
|
||||
(if any) should be poked.
|
||||
|
@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
user_id: The ID of the user whose device changed.
|
||||
device_ids: The IDs of any changed devices. If empty, this function will
|
||||
return None.
|
||||
hosts: The remote destinations that should be notified of the change.
|
||||
hosts: The remote destinations that should be notified of the change. If
|
||||
None then the set of hosts have *not* been calculated, and will be
|
||||
calculated later by a background task.
|
||||
room_ids: The rooms that the user is in
|
||||
|
||||
Returns:
|
||||
The maximum stream ID of device list updates that were added to the database, or
|
||||
|
@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
if not device_ids:
|
||||
return None
|
||||
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
self._add_device_change_to_stream_txn,
|
||||
context = get_active_span_text_map()
|
||||
|
||||
def add_device_changes_txn(
|
||||
txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
|
||||
):
|
||||
self._add_device_change_to_stream_txn(
|
||||
txn,
|
||||
user_id,
|
||||
device_ids,
|
||||
stream_ids,
|
||||
stream_ids_for_device_change,
|
||||
)
|
||||
|
||||
if not hosts:
|
||||
return stream_ids[-1]
|
||||
self._add_device_outbound_room_poke_txn(
|
||||
txn,
|
||||
user_id,
|
||||
device_ids,
|
||||
room_ids,
|
||||
stream_ids_for_device_change,
|
||||
context,
|
||||
hosts_have_been_calculated=hosts is not None,
|
||||
)
|
||||
|
||||
context = get_active_span_text_map()
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_outbound_poke_to_stream",
|
||||
self._add_device_outbound_poke_to_stream_txn,
|
||||
# If the set of hosts to send to has not been calculated yet (and so
|
||||
# `hosts` is None) or there are no `hosts` to send to, then skip
|
||||
# trying to persist them to the DB.
|
||||
if not hosts:
|
||||
return
|
||||
|
||||
self._add_device_outbound_poke_to_stream_txn(
|
||||
txn,
|
||||
user_id,
|
||||
device_ids,
|
||||
hosts,
|
||||
stream_ids,
|
||||
stream_ids_for_outbound_pokes,
|
||||
context,
|
||||
)
|
||||
|
||||
# `device_lists_stream` wants a stream ID per device update.
|
||||
num_stream_ids = len(device_ids)
|
||||
|
||||
if hosts:
|
||||
# `device_lists_outbound_pokes` wants a different stream ID for
|
||||
# each row, which is a row per host per device update.
|
||||
num_stream_ids += len(hosts) * len(device_ids)
|
||||
|
||||
async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
|
||||
stream_ids_for_device_change = stream_ids[: len(device_ids)]
|
||||
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
add_device_changes_txn,
|
||||
stream_ids_for_device_change,
|
||||
stream_ids_for_outbound_pokes,
|
||||
)
|
||||
|
||||
return stream_ids[-1]
|
||||
|
||||
def _add_device_change_to_stream_txn(
|
||||
|
@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
user_id: str,
|
||||
device_ids: Iterable[str],
|
||||
hosts: Collection[str],
|
||||
stream_ids: List[str],
|
||||
stream_ids: List[int],
|
||||
context: Dict[str, str],
|
||||
) -> None:
|
||||
for host in hosts:
|
||||
|
@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
now = self._clock.time_msec()
|
||||
next_stream_id = iter(stream_ids)
|
||||
stream_id_iterator = iter(stream_ids)
|
||||
|
||||
encoded_context = json_encoder.encode(context)
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_outbound_pokes",
|
||||
|
@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
values=[
|
||||
(
|
||||
destination,
|
||||
next(next_stream_id),
|
||||
next(stream_id_iterator),
|
||||
user_id,
|
||||
device_id,
|
||||
False,
|
||||
now,
|
||||
json_encoder.encode(context)
|
||||
if whitelisted_homeserver(destination)
|
||||
else "{}",
|
||||
encoded_context if whitelisted_homeserver(destination) else "{}",
|
||||
)
|
||||
for destination in hosts
|
||||
for device_id in device_ids
|
||||
],
|
||||
)
|
||||
|
||||
def _add_device_outbound_room_poke_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_ids: Iterable[str],
|
||||
room_ids: Collection[str],
|
||||
stream_ids: List[str],
|
||||
context: Dict[str, str],
|
||||
hosts_have_been_calculated: bool,
|
||||
) -> None:
|
||||
"""Record the user in the room has updated their device.
|
||||
|
||||
Args:
|
||||
hosts_have_been_calculated: True if `device_lists_outbound_pokes`
|
||||
has been updated already with the updates.
|
||||
"""
|
||||
|
||||
# We only need to convert to outbound pokes if they are our user.
|
||||
converted_to_destinations = (
|
||||
hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
|
||||
)
|
||||
|
||||
encoded_context = json_encoder.encode(context)
|
||||
|
||||
# The `device_lists_changes_in_room.stream_id` column matches the
|
||||
# corresponding `stream_id` of the update in the `device_lists_stream`
|
||||
# table, i.e. all rows persisted for the same device update will have
|
||||
# the same `stream_id` (but different room IDs).
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_changes_in_room",
|
||||
keys=(
|
||||
"user_id",
|
||||
"device_id",
|
||||
"room_id",
|
||||
"stream_id",
|
||||
"converted_to_destinations",
|
||||
"opentracing_context",
|
||||
),
|
||||
values=[
|
||||
(
|
||||
user_id,
|
||||
device_id,
|
||||
room_id,
|
||||
stream_id,
|
||||
converted_to_destinations,
|
||||
encoded_context,
|
||||
)
|
||||
for room_id in room_ids
|
||||
for device_id, stream_id in zip(device_ids, stream_ids)
|
||||
],
|
||||
)
|
||||
|
||||
async def get_uncoverted_outbound_room_pokes(
|
||||
self, limit: int = 10
|
||||
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
|
||||
"""Get device list changes by room that have not yet been handled and
|
||||
written to `device_lists_outbound_pokes`.
|
||||
|
||||
Returns:
|
||||
A list of user ID, device ID, room ID, stream ID and optional opentracing context.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT user_id, device_id, room_id, stream_id, opentracing_context
|
||||
FROM device_lists_changes_in_room
|
||||
WHERE NOT converted_to_destinations
|
||||
ORDER BY stream_id
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
def get_uncoverted_outbound_room_pokes_txn(txn):
|
||||
txn.execute(sql, (limit,))
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
|
||||
)
|
||||
|
||||
async def add_device_list_outbound_pokes(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
room_id: str,
|
||||
stream_id: int,
|
||||
hosts: Collection[str],
|
||||
context: Optional[Dict[str, str]],
|
||||
) -> None:
|
||||
"""Queue the device update to be sent to the given set of hosts,
|
||||
calculated from the room ID.
|
||||
|
||||
Marks the associated row in `device_lists_changes_in_room` as handled.
|
||||
"""
|
||||
|
||||
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
|
||||
if hosts:
|
||||
self._add_device_outbound_poke_to_stream_txn(
|
||||
txn,
|
||||
user_id=user_id,
|
||||
device_ids=[device_id],
|
||||
hosts=hosts,
|
||||
stream_ids=stream_ids,
|
||||
context=context,
|
||||
)
|
||||
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="device_lists_changes_in_room",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"stream_id": stream_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
updatevalues={"converted_to_destinations": True},
|
||||
)
|
||||
|
||||
if not hosts:
|
||||
# If there are no hosts then we don't try and generate stream IDs.
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_device_list_outbound_pokes",
|
||||
add_device_list_outbound_pokes_txn,
|
||||
[],
|
||||
)
|
||||
|
||||
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_device_list_outbound_pokes",
|
||||
add_device_list_outbound_pokes_txn,
|
||||
stream_ids,
|
||||
)
|
||||
|
|
|
@ -197,12 +197,10 @@ class PersistEventsStore:
|
|||
)
|
||||
persist_event_counter.inc(len(events_and_contexts))
|
||||
|
||||
if stream < 0:
|
||||
# backfilled events have negative stream orderings, so we don't
|
||||
# want to set the event_persisted_position to that.
|
||||
synapse.metrics.event_persisted_position.set(
|
||||
events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||
)
|
||||
if not use_negative_stream_ordering:
|
||||
# we don't want to set the event_persisted_position to a negative
|
||||
# stream_ordering.
|
||||
synapse.metrics.event_persisted_position.set(stream)
|
||||
|
||||
for event, context in events_and_contexts:
|
||||
if context.app_service:
|
||||
|
|
|
@ -22,7 +22,6 @@ from typing import (
|
|||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
|
@ -1339,10 +1338,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
return results
|
||||
|
||||
@cached(max_entries=100000, tree=True)
|
||||
async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
|
||||
# this only exists for the benefit of the @cachedList descriptor on
|
||||
# _have_seen_events_dict
|
||||
raise NotImplementedError()
|
||||
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
|
||||
res = await self._have_seen_events_dict(((room_id, event_id),))
|
||||
return res[(room_id, event_id)]
|
||||
|
||||
def _get_current_state_event_counts_txn(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
|
|
|
@ -15,7 +15,6 @@ import logging
|
|||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
|
@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
|
||||
|
||||
|
||||
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||
class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
|||
self._clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
||||
if hs.config.redis.redis_enabled:
|
||||
# If we're using Redis, we can shift this update process off to
|
||||
# the background worker
|
||||
self._update_on_this_worker = hs.config.worker.run_background_tasks
|
||||
else:
|
||||
# If we're NOT using Redis, this must be handled by the master
|
||||
self._update_on_this_worker = hs.get_instance_name() == "master"
|
||||
|
||||
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
|
||||
self._max_mau_value = hs.config.server.max_mau_value
|
||||
|
||||
self._mau_stats_only = hs.config.server.mau_stats_only
|
||||
|
||||
if self._update_on_this_worker:
|
||||
# Do not add more reserved users than the total allowable number
|
||||
self.db_pool.new_transaction(
|
||||
db_conn,
|
||||
"initialise_mau_threepids",
|
||||
[],
|
||||
[],
|
||||
self._initialise_reserved_users,
|
||||
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
||||
)
|
||||
|
||||
@cached(num_args=0)
|
||||
async def get_monthly_active_count(self) -> int:
|
||||
"""Generates current count of monthly active users
|
||||
|
@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
|||
"reap_monthly_active_users", _reap_users, reserved_users
|
||||
)
|
||||
|
||||
|
||||
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._mau_stats_only = hs.config.server.mau_stats_only
|
||||
|
||||
# Do not add more reserved users than the total allowable number
|
||||
self.db_pool.new_transaction(
|
||||
db_conn,
|
||||
"initialise_mau_threepids",
|
||||
[],
|
||||
[],
|
||||
self._initialise_reserved_users,
|
||||
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
||||
)
|
||||
|
||||
def _initialise_reserved_users(
|
||||
self, txn: LoggingTransaction, threepids: List[dict]
|
||||
) -> None:
|
||||
|
@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
|
|||
txn:
|
||||
threepids: List of threepid dicts to reserve
|
||||
"""
|
||||
assert (
|
||||
self._update_on_this_worker
|
||||
), "This worker is not designated to update MAUs"
|
||||
|
||||
# XXX what is this function trying to achieve? It upserts into
|
||||
# monthly_active_users for each *registered* reserved mau user, but why?
|
||||
|
@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
|
|||
Args:
|
||||
user_id: user to add/update
|
||||
"""
|
||||
assert (
|
||||
self._update_on_this_worker
|
||||
), "This worker is not designated to update MAUs"
|
||||
|
||||
# Support user never to be included in MAU stats. Note I can't easily call this
|
||||
# from upsert_monthly_active_user_txn because then I need a _txn form of
|
||||
# is_support_user which is complicated because I want to cache the result.
|
||||
|
@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
|
|||
txn (cursor):
|
||||
user_id (str): user to add/update
|
||||
"""
|
||||
assert (
|
||||
self._update_on_this_worker
|
||||
), "This worker is not designated to update MAUs"
|
||||
|
||||
# Am consciously deciding to lock the table on the basis that is ought
|
||||
# never be a big table and alternative approaches (batching multiple
|
||||
|
@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
|
|||
Args:
|
||||
user_id(str): the user_id to query
|
||||
"""
|
||||
assert (
|
||||
self._update_on_this_worker
|
||||
), "This worker is not designated to update MAUs"
|
||||
|
||||
if self._limit_usage_by_mau or self._mau_stats_only:
|
||||
# Trial users and guests should not be included as part of MAU group
|
||||
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
|
||||
|
|
|
@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
max_receipts_stream_id = self.get_max_receipt_stream_id()
|
||||
receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"receipts_linearized",
|
||||
entity_column="room_id",
|
||||
stream_column="stream_id",
|
||||
max_value=max_receipts_stream_id,
|
||||
limit=10000,
|
||||
)
|
||||
self._receipts_stream_cache = StreamChangeCache(
|
||||
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
|
||||
"ReceiptsRoomChangeCache",
|
||||
min_receipts_stream_id,
|
||||
prefilled_cache=receipts_stream_prefill,
|
||||
)
|
||||
|
||||
def get_max_receipt_stream_id(self) -> int:
|
||||
|
|
|
@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
|
|||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import IdGenerator
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import UserID, UserInfo
|
||||
from synapse.types import JsonDict, UserID, UserInfo
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -79,7 +79,7 @@ class TokenLookupResult:
|
|||
|
||||
# Make the token owner default to the user ID, which is the common case.
|
||||
@token_owner.default
|
||||
def _default_token_owner(self):
|
||||
def _default_token_owner(self) -> str:
|
||||
return self.user_id
|
||||
|
||||
|
||||
|
@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
the account.
|
||||
"""
|
||||
|
||||
def set_account_validity_for_user_txn(txn):
|
||||
def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_update_txn(
|
||||
txn=txn,
|
||||
table="account_validity",
|
||||
|
@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="get_renewal_token_for_user",
|
||||
)
|
||||
|
||||
async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
|
||||
async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
|
||||
"""Selects users whose account will expire in the [now, now + renew_at] time
|
||||
window (see configuration for account_validity for information on what renew_at
|
||||
refers to).
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, each with a user ID and expiration time (in milliseconds).
|
||||
A list of tuples, each with a user ID and expiration time (in milliseconds).
|
||||
"""
|
||||
|
||||
def select_users_txn(txn, now_ms, renew_at):
|
||||
def select_users_txn(
|
||||
txn: LoggingTransaction, now_ms: int, renew_at: int
|
||||
) -> List[Tuple[str, int]]:
|
||||
sql = (
|
||||
"SELECT user_id, expiration_ts_ms FROM account_validity"
|
||||
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
|
||||
)
|
||||
values = [False, now_ms, renew_at]
|
||||
txn.execute(sql, values)
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_expiring_soon",
|
||||
|
@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
admin: true iff the user is to be a server admin, false otherwise.
|
||||
"""
|
||||
|
||||
def set_server_admin_txn(txn):
|
||||
def set_server_admin_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
|
||||
)
|
||||
|
@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
user_type: type of the user or None for a user without a type.
|
||||
"""
|
||||
|
||||
def set_user_type_txn(txn):
|
||||
def set_user_type_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
|
||||
)
|
||||
|
@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
||||
def _query_for_auth(
|
||||
self, txn: LoggingTransaction, token: str
|
||||
) -> Optional[TokenLookupResult]:
|
||||
sql = """
|
||||
SELECT users.name as user_id,
|
||||
users.is_guest,
|
||||
|
@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"is_support_user", self.is_support_user_txn, user_id
|
||||
)
|
||||
|
||||
def is_real_user_txn(self, txn, user_id):
|
||||
def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
|
||||
res = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
|
@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
return res is None
|
||||
|
||||
def is_support_user_txn(self, txn, user_id):
|
||||
def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
|
||||
res = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
|
@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
A mapping of user_id -> password_hash.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> Dict[str, str]:
|
||||
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
|
||||
txn.execute(sql, (user_id,))
|
||||
return dict(txn)
|
||||
result = cast(List[Tuple[str, str]], txn.fetchall())
|
||||
return dict(result)
|
||||
|
||||
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
|
||||
|
@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
def _replace_user_external_id_txn(
|
||||
txn: LoggingTransaction,
|
||||
):
|
||||
) -> None:
|
||||
_remove_user_external_ids_txn(txn, user_id)
|
||||
|
||||
for auth_provider, external_id in record_external_ids:
|
||||
|
@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||
|
||||
async def count_all_users(self):
|
||||
async def count_all_users(self) -> int:
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
def _count_users(txn):
|
||||
def _count_users(txn: LoggingTransaction) -> int:
|
||||
txn.execute("SELECT COUNT(*) AS users FROM users")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if rows:
|
||||
|
@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
who registered on the homeserver in the past 24 hours
|
||||
"""
|
||||
|
||||
def _count_daily_user_type(txn):
|
||||
def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
|
||||
yesterday = int(self._clock.time()) - (60 * 60 * 24)
|
||||
|
||||
sql = """
|
||||
|
@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"count_daily_user_type", _count_daily_user_type
|
||||
)
|
||||
|
||||
async def count_nonbridged_users(self):
|
||||
def _count_users(txn):
|
||||
async def count_nonbridged_users(self) -> int:
|
||||
def _count_users(txn: LoggingTransaction) -> int:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM users
|
||||
WHERE appservice_id IS NULL
|
||||
"""
|
||||
)
|
||||
(count,) = txn.fetchone()
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||
|
||||
async def count_real_users(self):
|
||||
async def count_real_users(self) -> int:
|
||||
"""Counts all users without a special user_type registered on the homeserver."""
|
||||
|
||||
def _count_users(txn):
|
||||
def _count_users(txn: LoggingTransaction) -> int:
|
||||
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if rows:
|
||||
|
@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
return user_id
|
||||
|
||||
def get_user_id_by_threepid_txn(
|
||||
self, txn, medium: str, address: str
|
||||
self, txn: LoggingTransaction, medium: str, address: str
|
||||
) -> Optional[str]:
|
||||
"""Returns user id from threepid
|
||||
|
||||
|
@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||
)
|
||||
|
||||
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
|
||||
async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_list(
|
||||
"user_threepids",
|
||||
{"user_id": user_id},
|
||||
|
@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
async def add_user_bound_threepid(
|
||||
self, user_id: str, medium: str, address: str, id_server: str
|
||||
):
|
||||
) -> None:
|
||||
"""The server proxied a bind request to the given identity server on
|
||||
behalf of the given user. We need to remember this in case the user
|
||||
asks us to unbind the threepid.
|
||||
|
@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
assert address or sid
|
||||
|
||||
def get_threepid_validation_session_txn(txn):
|
||||
def get_threepid_validation_session_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
sql = """
|
||||
SELECT address, session_id, medium, client_secret,
|
||||
last_send_attempt, validated_at
|
||||
|
@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
session_id: The ID of the session to delete
|
||||
"""
|
||||
|
||||
def delete_threepid_session_txn(txn):
|
||||
def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="threepid_validation_token",
|
||||
|
@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
async def cull_expired_threepid_validation_tokens(self) -> None:
|
||||
"""Remove threepid validation tokens with expiry dates that have passed"""
|
||||
|
||||
def cull_expired_threepid_validation_tokens_txn(txn, ts):
|
||||
def cull_expired_threepid_validation_tokens_txn(
|
||||
txn: LoggingTransaction, ts: int
|
||||
) -> None:
|
||||
sql = """
|
||||
DELETE FROM threepid_validation_token WHERE
|
||||
expires < ?
|
||||
|
@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
|
||||
@wrap_as_background_process("account_validity_set_expiration_dates")
|
||||
async def _set_expiration_date_when_missing(self):
|
||||
async def _set_expiration_date_when_missing(self) -> None:
|
||||
"""
|
||||
Retrieves the list of registered users that don't have an expiration date, and
|
||||
adds an expiration date for each of them.
|
||||
"""
|
||||
|
||||
def select_users_with_no_expiration_date_txn(txn):
|
||||
def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
|
||||
"""Retrieves the list of registered users with no expiration date from the
|
||||
database, filtering out deactivated users.
|
||||
"""
|
||||
|
@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
select_users_with_no_expiration_date_txn,
|
||||
)
|
||||
|
||||
def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
|
||||
def set_expiration_date_for_user_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
|
||||
) -> None:
|
||||
"""Sets an expiration date to the account with the given user ID.
|
||||
|
||||
Args:
|
||||
|
@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
token: The registration token pending use
|
||||
"""
|
||||
|
||||
def _set_registration_token_pending_txn(txn):
|
||||
def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
|
||||
pending = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
|
@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
updatevalues={"pending": pending + 1},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"set_registration_token_pending", _set_registration_token_pending_txn
|
||||
)
|
||||
|
||||
|
@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
token: The registration token to be 'used'
|
||||
"""
|
||||
|
||||
def _use_registration_token_txn(txn):
|
||||
def _use_registration_token_txn(txn: LoggingTransaction) -> None:
|
||||
# Normally, res is Optional[Dict[str, Any]].
|
||||
# Override type because the return type is only optional if
|
||||
# allow_none is True, and we don't want mypy throwing errors
|
||||
|
@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"use_registration_token", _use_registration_token_txn
|
||||
)
|
||||
|
||||
|
@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
A list of dicts, each containing details of a token.
|
||||
"""
|
||||
|
||||
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
|
||||
def select_registration_tokens_txn(
|
||||
txn: LoggingTransaction, now: int, valid: Optional[bool]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if valid is None:
|
||||
# Return all tokens regardless of validity
|
||||
txn.execute("SELECT * FROM registration_tokens")
|
||||
|
@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
Whether the row was inserted or not.
|
||||
"""
|
||||
|
||||
def _create_registration_token_txn(txn):
|
||||
def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
|
@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
A dict with all info about the token, or None if token doesn't exist.
|
||||
"""
|
||||
|
||||
def _update_registration_token_txn(txn):
|
||||
def _update_registration_token_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
) -> Optional[RefreshTokenLookupResult]:
|
||||
"""Lookup a refresh token with hints about its validity."""
|
||||
|
||||
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
|
||||
def _lookup_refresh_token_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[RefreshTokenLookupResult]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
|
@ -1745,6 +1762,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"replace_refresh_token", _replace_refresh_token_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def is_guest(self, user_id: str) -> bool:
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcol="is_guest",
|
||||
allow_none=True,
|
||||
desc="is_guest",
|
||||
)
|
||||
|
||||
return res if res else False
|
||||
|
||||
|
||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||
def __init__(
|
||||
|
@ -1795,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
unique=False,
|
||||
)
|
||||
|
||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||
async def _background_update_set_deactivated_flag(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||
for each of them.
|
||||
"""
|
||||
|
||||
last_user = progress.get("user_id", "")
|
||||
|
||||
def _background_update_set_deactivated_flag_txn(txn):
|
||||
def _background_update_set_deactivated_flag_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[bool, int]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
|
@ -1874,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
deactivated,
|
||||
)
|
||||
|
||||
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
|
||||
def set_user_deactivated_status_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, deactivated: bool
|
||||
) -> None:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
|
@ -1887,18 +1922,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
@cached()
|
||||
async def is_guest(self, user_id: str) -> bool:
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcol="is_guest",
|
||||
allow_none=True,
|
||||
desc="is_guest",
|
||||
)
|
||||
|
||||
return res if res else False
|
||||
|
||||
|
||||
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
def __init__(
|
||||
|
@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
return next_id
|
||||
|
||||
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
||||
def _set_device_for_access_token_txn(
|
||||
self, txn: LoggingTransaction, token: str, device_id: str
|
||||
) -> str:
|
||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn, "access_tokens", {"token": token}, "device_id"
|
||||
)
|
||||
|
@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
def _register_user(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
password_hash: Optional[str],
|
||||
was_guest: bool,
|
||||
|
@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
admin: bool,
|
||||
user_type: Optional[str],
|
||||
shadow_banned: bool,
|
||||
):
|
||||
) -> None:
|
||||
user_id_obj = UserID.from_string(user_id)
|
||||
|
||||
now = int(self._clock.time())
|
||||
|
@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
pointless. Use flush_user separately.
|
||||
"""
|
||||
|
||||
def user_set_password_hash_txn(txn):
|
||||
def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn, "users", {"name": user_id}, {"password_hash": password_hash}
|
||||
)
|
||||
|
@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
StoreError(404) if user not found
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="users",
|
||||
|
@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
StoreError(404) if user not found
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="users",
|
||||
|
@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
A tuple of (token, token id, device id) for each of the deleted tokens
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
|
||||
keyvalues = {"user_id": user_id}
|
||||
if device_id is not None:
|
||||
keyvalues["device_id"] = device_id
|
||||
|
@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||
|
||||
async def delete_access_token(self, access_token: str) -> None:
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, table="access_tokens", keyvalues={"token": access_token}
|
||||
)
|
||||
|
@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
await self.db_pool.runInteraction("delete_access_token", f)
|
||||
|
||||
async def delete_refresh_token(self, refresh_token: str) -> None:
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
|
||||
)
|
||||
|
@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
"""
|
||||
|
||||
# Insert everything into a transaction in order to run atomically
|
||||
def validate_threepid_session_txn(txn):
|
||||
def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="threepid_validation_session",
|
||||
|
@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
longer be valid
|
||||
"""
|
||||
|
||||
def start_or_continue_validation_session_txn(txn):
|
||||
def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
|
||||
# Create or update a validation session
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
|
|
|
@ -17,6 +17,7 @@ from typing import (
|
|||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
|
@ -39,8 +40,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
||||
from synapse.types import RoomStreamToken, StreamToken
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamToken
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -49,6 +49,19 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class _RelatedEvent:
|
||||
"""
|
||||
Contains enough information about a related event in order to properly filter
|
||||
events from ignored users.
|
||||
"""
|
||||
|
||||
# The event ID of the related event.
|
||||
event_id: str
|
||||
# The sender of the related event.
|
||||
sender: str
|
||||
|
||||
|
||||
class RelationsWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -73,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
direction: str = "b",
|
||||
from_token: Optional[StreamToken] = None,
|
||||
to_token: Optional[StreamToken] = None,
|
||||
) -> PaginationChunk:
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
"""Get a list of relations for an event, ordered by topological ordering.
|
||||
|
||||
Args:
|
||||
|
@ -90,8 +103,10 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
Returns:
|
||||
List of event IDs that match relations requested. The rows are of
|
||||
the form `{"event_id": "..."}`.
|
||||
A tuple of:
|
||||
A list of related event IDs & their senders.
|
||||
|
||||
The next stream token, if one exists.
|
||||
"""
|
||||
# We don't use `event_id`, it's there so that we can cache based on
|
||||
# it. The `event_id` must match the `event.event_id`.
|
||||
|
@ -132,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
order = "ASC"
|
||||
|
||||
sql = """
|
||||
SELECT event_id, relation_type, topological_ordering, stream_ordering
|
||||
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE %s
|
||||
|
@ -146,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
def _get_recent_references_for_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> PaginationChunk:
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
txn.execute(sql, where_args + [limit + 1])
|
||||
|
||||
last_topo_id = None
|
||||
|
@ -156,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
# Do not include edits for redacted events as they leak event
|
||||
# content.
|
||||
if not is_redacted or row[1] != RelationTypes.REPLACE:
|
||||
events.append({"event_id": row[0]})
|
||||
last_topo_id = row[2]
|
||||
last_stream_id = row[3]
|
||||
events.append(_RelatedEvent(row[0], row[2]))
|
||||
last_topo_id = row[3]
|
||||
last_stream_id = row[4]
|
||||
|
||||
# If there are more events, generate the next pagination key.
|
||||
next_token = None
|
||||
|
@ -179,9 +194,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
groups_key=0,
|
||||
)
|
||||
|
||||
return PaginationChunk(
|
||||
chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
|
||||
)
|
||||
return events[:limit], next_token
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_recent_references_for_event", _get_recent_references_for_event_txn
|
||||
|
@ -252,15 +265,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
@cached(tree=True)
|
||||
async def get_aggregation_groups_for_event(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
from_token: Optional[AggregationPaginationToken] = None,
|
||||
to_token: Optional[AggregationPaginationToken] = None,
|
||||
) -> PaginationChunk:
|
||||
self, event_id: str, room_id: str, limit: int = 5
|
||||
) -> List[JsonDict]:
|
||||
"""Get a list of annotations on the event, grouped by event type and
|
||||
aggregation key, sorted by count.
|
||||
|
||||
|
@ -270,82 +276,96 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
limit: Only fetch the `limit` groups.
|
||||
direction: Whether to fetch the highest count first (`"b"`) or
|
||||
the lowest count first (`"f"`).
|
||||
from_token: Fetch rows from the given token, or from the start if None.
|
||||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
Returns:
|
||||
List of groups of annotations that match. Each row is a dict with
|
||||
`type`, `key` and `count` fields.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
|
||||
where_args: List[Union[str, int]] = [
|
||||
args = [
|
||||
event_id,
|
||||
room_id,
|
||||
RelationTypes.ANNOTATION,
|
||||
limit,
|
||||
]
|
||||
|
||||
sql = """
|
||||
SELECT type, aggregation_key, COUNT(DISTINCT sender)
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
|
||||
GROUP BY relation_type, type, aggregation_key
|
||||
ORDER BY COUNT(*) DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
def _get_aggregation_groups_for_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[JsonDict]:
|
||||
txn.execute(sql, args)
|
||||
|
||||
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||
)
|
||||
|
||||
async def get_aggregation_groups_for_users(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
limit: int,
|
||||
users: FrozenSet[str] = frozenset(),
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
"""Fetch the partial aggregations for an event for specific users.
|
||||
|
||||
This is used, in conjunction with get_aggregation_groups_for_event, to
|
||||
remove information from the results for ignored users.
|
||||
|
||||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
limit: Only fetch the `limit` groups.
|
||||
users: The users to fetch information for.
|
||||
|
||||
Returns:
|
||||
A map of (event type, aggregation key) to a count of users.
|
||||
"""
|
||||
|
||||
if not users:
|
||||
return {}
|
||||
|
||||
args: List[Union[str, int]] = [
|
||||
event_id,
|
||||
room_id,
|
||||
RelationTypes.ANNOTATION,
|
||||
]
|
||||
|
||||
if event_type:
|
||||
where_clause.append("type = ?")
|
||||
where_args.append(event_type)
|
||||
|
||||
having_clause = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("COUNT(*)", "MAX(stream_ordering)"),
|
||||
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
|
||||
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
|
||||
engine=self.database_engine,
|
||||
users_sql, users_args = make_in_list_sql_clause(
|
||||
self.database_engine, "sender", users
|
||||
)
|
||||
args.extend(users_args)
|
||||
|
||||
if direction == "b":
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
||||
if having_clause:
|
||||
having_clause = "HAVING " + having_clause
|
||||
else:
|
||||
having_clause = ""
|
||||
|
||||
sql = """
|
||||
SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
|
||||
sql = f"""
|
||||
SELECT type, aggregation_key, COUNT(DISTINCT sender)
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE {where_clause}
|
||||
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
|
||||
GROUP BY relation_type, type, aggregation_key
|
||||
{having_clause}
|
||||
ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
|
||||
ORDER BY COUNT(*) DESC
|
||||
LIMIT ?
|
||||
""".format(
|
||||
where_clause=" AND ".join(where_clause),
|
||||
order=order,
|
||||
having_clause=having_clause,
|
||||
)
|
||||
"""
|
||||
|
||||
def _get_aggregation_groups_for_event_txn(
|
||||
def _get_aggregation_groups_for_users_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> PaginationChunk:
|
||||
txn.execute(sql, where_args + [limit + 1])
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
txn.execute(sql, args + [limit])
|
||||
|
||||
next_batch = None
|
||||
events = []
|
||||
for row in txn:
|
||||
events.append({"type": row[0], "key": row[1], "count": row[2]})
|
||||
next_batch = AggregationPaginationToken(row[2], row[3])
|
||||
|
||||
if len(events) <= limit:
|
||||
next_batch = None
|
||||
|
||||
return PaginationChunk(
|
||||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
|
||||
)
|
||||
return {(row[0], row[1]): row[2] for row in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
|
@ -574,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return summaries
|
||||
|
||||
async def get_threaded_messages_per_user(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
users: FrozenSet[str] = frozenset(),
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
"""Get the number of threaded replies for a set of users.
|
||||
|
||||
This is used, in conjunction with get_thread_summaries, to calculate an
|
||||
accurate count of the replies to a thread by subtracting ignored users.
|
||||
|
||||
Args:
|
||||
event_ids: The events to check for threaded replies.
|
||||
users: The user to calculate the count of their replies.
|
||||
|
||||
Returns:
|
||||
A map of the (event_id, sender) to the count of their replies.
|
||||
"""
|
||||
if not users:
|
||||
return {}
|
||||
|
||||
# Fetch the number of threaded replies.
|
||||
sql = """
|
||||
SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
|
||||
INNER JOIN event_relations USING (event_id)
|
||||
INNER JOIN events AS parent ON
|
||||
parent.event_id = relates_to_id
|
||||
AND parent.room_id = child.room_id
|
||||
WHERE
|
||||
%s
|
||||
AND %s
|
||||
AND %s
|
||||
GROUP BY parent.event_id, child.sender
|
||||
"""
|
||||
|
||||
def _get_threaded_messages_per_user_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
users_sql, users_args = make_in_list_sql_clause(
|
||||
self.database_engine, "child.sender", users
|
||||
)
|
||||
events_clause, events_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relates_to_id", event_ids
|
||||
)
|
||||
|
||||
if self._msc3440_enabled:
|
||||
relations_clause = "(relation_type = ? OR relation_type = ?)"
|
||||
relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
|
||||
else:
|
||||
relations_clause = "relation_type = ?"
|
||||
relations_args = [RelationTypes.THREAD]
|
||||
|
||||
txn.execute(
|
||||
sql % (users_sql, events_clause, relations_clause),
|
||||
users_args + events_args + relations_args,
|
||||
)
|
||||
return {(row[0], row[1]): row[2] for row in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
@ -661,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
%s;
|
||||
"""
|
||||
|
||||
def _get_if_events_have_relations(txn) -> List[str]:
|
||||
def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
|
||||
clauses: List[str] = []
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relates_to_id", parent_ids
|
||||
|
|
|
@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return None
|
||||
|
||||
async def get_rooms_for_local_user_where_membership_is(
|
||||
self, user_id: str, membership_list: Collection[str]
|
||||
self,
|
||||
user_id: str,
|
||||
membership_list: Collection[str],
|
||||
excluded_rooms: Optional[List[str]] = None,
|
||||
) -> List[RoomsForUser]:
|
||||
"""Get all the rooms for this *local* user where the membership for this user
|
||||
matches one in the membership list.
|
||||
|
@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
user_id: The user ID.
|
||||
membership_list: A list of synapse.api.constants.Membership
|
||||
values which the user must be in.
|
||||
excluded_rooms: A list of rooms to ignore.
|
||||
|
||||
Returns:
|
||||
The RoomsForUser that the user matches the membership types.
|
||||
|
@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
membership_list,
|
||||
)
|
||||
|
||||
# Now we filter out forgotten rooms
|
||||
forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
|
||||
return [room for room in rooms if room.room_id not in forgotten_rooms]
|
||||
# Now we filter out forgotten and excluded rooms
|
||||
rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
|
||||
|
||||
if excluded_rooms is not None:
|
||||
rooms_to_exclude.update(set(excluded_rooms))
|
||||
|
||||
return [room for room in rooms if room.room_id not in rooms_to_exclude]
|
||||
|
||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||
self, txn, user_id: str, membership_list: List[str]
|
||||
self,
|
||||
txn,
|
||||
user_id: str,
|
||||
membership_list: List[str],
|
||||
) -> List[RoomsForUser]:
|
||||
# Paranoia check.
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
|
@ -877,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return frozenset(cache.hosts_to_joined_users)
|
||||
|
||||
# Since we'll mutate the cache we need to lock.
|
||||
with (await self._joined_host_linearizer.queue(room_id)):
|
||||
async with self._joined_host_linearizer.queue(room_id):
|
||||
if state_entry.state_group == cache.state_group:
|
||||
# Same state group, so nothing to do. We've already checked for
|
||||
# this above, but the cache may have changed while waiting on
|
||||
|
|
|
@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
|
|||
|
||||
class SignatureWorkerStore(EventsWorkerStore):
|
||||
@cached()
|
||||
def get_event_reference_hash(self, event_id):
|
||||
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
|
||||
# This is a dummy function to allow get_event_reference_hashes
|
||||
# to use its cache
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -12,9 +12,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections.abc
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||
|
@ -29,7 +30,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, StateMap
|
||||
from synapse.types import JsonDict, JsonMapping, StateMap
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
|
@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return room_version
|
||||
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
|
||||
"""Get the predecessor of an upgraded room if it exists.
|
||||
Otherwise return None.
|
||||
|
||||
|
@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
predecessor = create_event.content.get("predecessor", None)
|
||||
|
||||
# Ensure the key is a dictionary
|
||||
if not isinstance(predecessor, collections.abc.Mapping):
|
||||
if not isinstance(predecessor, (dict, frozendict)):
|
||||
return None
|
||||
|
||||
# The keys must be strings since the data is JSON.
|
||||
return predecessor
|
||||
|
||||
async def get_create_event_for_room(self, room_id: str) -> EventBase:
|
||||
|
@ -202,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
The current state of the room.
|
||||
"""
|
||||
|
||||
def _get_current_state_ids_txn(txn):
|
||||
def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
|
||||
txn.execute(
|
||||
"""SELECT type, state_key, event_id FROM current_state_events
|
||||
WHERE room_id = ?
|
||||
|
@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
list_name="event_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
|
||||
"""Returns mapping event_id -> state_group"""
|
||||
async def _get_state_group_for_events(
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, int]:
|
||||
"""Returns mapping event_id -> state_group.
|
||||
|
||||
Raises:
|
||||
RuntimeError if the state is unknown at any of the given events
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="event_id",
|
||||
|
@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
desc="_get_state_group_for_events",
|
||||
)
|
||||
|
||||
return {row["event_id"]: row["state_group"] for row in rows}
|
||||
res = {row["event_id"]: row["state_group"] for row in rows}
|
||||
for e in event_ids:
|
||||
if e not in res:
|
||||
raise RuntimeError("No state group for unknown or outlier event %s" % e)
|
||||
return res
|
||||
|
||||
async def get_referenced_state_groups(
|
||||
self, state_groups: Iterable[int]
|
||||
|
@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||
)
|
||||
|
||||
for user_id in potentially_left_users - joined_users:
|
||||
await self.mark_remote_user_device_list_as_unsubscribed(user_id)
|
||||
await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]
|
||||
|
||||
return batch_size
|
||||
|
||||
|
|
|
@ -36,7 +36,17 @@ what sort order was used:
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
@ -585,7 +595,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
return ret, key
|
||||
|
||||
async def get_membership_changes_for_user(
|
||||
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
|
||||
self,
|
||||
user_id: str,
|
||||
from_key: RoomStreamToken,
|
||||
to_key: RoomStreamToken,
|
||||
excluded_rooms: Optional[List[str]] = None,
|
||||
) -> List[EventBase]:
|
||||
"""Fetch membership events for a given user.
|
||||
|
||||
|
@ -610,23 +624,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
min_from_id = from_key.stream
|
||||
max_to_id = to_key.get_max_stream_pos()
|
||||
|
||||
args: List[Any] = [user_id, min_from_id, max_to_id]
|
||||
|
||||
ignore_room_clause = ""
|
||||
if excluded_rooms is not None and len(excluded_rooms) > 0:
|
||||
ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
|
||||
"?" for _ in excluded_rooms
|
||||
)
|
||||
args = args + excluded_rooms
|
||||
|
||||
sql = """
|
||||
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
|
||||
FROM events AS e, room_memberships AS m
|
||||
WHERE e.event_id = m.event_id
|
||||
AND m.user_id = ?
|
||||
AND e.stream_ordering > ? AND e.stream_ordering <= ?
|
||||
%s
|
||||
ORDER BY e.stream_ordering ASC
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id,
|
||||
min_from_id,
|
||||
max_to_id,
|
||||
),
|
||||
""" % (
|
||||
ignore_room_clause,
|
||||
)
|
||||
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = [
|
||||
_EventDictReturn(event_id, None, stream_ordering)
|
||||
for event_id, instance_name, topological_ordering, stream_ordering in txn
|
||||
|
@ -722,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
A tuple of (stream ordering, topological ordering, event_id)
|
||||
"""
|
||||
|
||||
def _f(txn):
|
||||
def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
|
||||
sql = (
|
||||
"SELECT stream_ordering, topological_ordering, event_id"
|
||||
" FROM events"
|
||||
|
@ -732,27 +752,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
" LIMIT 1"
|
||||
)
|
||||
txn.execute(sql, (room_id, stream_ordering))
|
||||
return txn.fetchone()
|
||||
return cast(Optional[Tuple[int, int, str]], txn.fetchone())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_room_event_before_stream_ordering", _f
|
||||
)
|
||||
|
||||
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
|
||||
"""Returns the current token for rooms stream.
|
||||
async def get_current_room_stream_token_for_room_id(
|
||||
self, room_id: Optional[str] = None
|
||||
) -> RoomStreamToken:
|
||||
"""Returns the current position of the rooms stream.
|
||||
|
||||
By default, it returns the current global stream token. Specifying a
|
||||
`room_id` causes it to return the current room specific topological
|
||||
token.
|
||||
By default, it returns a live token with the current global stream
|
||||
token. Specifying a `room_id` causes it to return a historic token with
|
||||
the room specific topological token.
|
||||
"""
|
||||
token = self.get_room_max_stream_ordering()
|
||||
stream_ordering = self.get_room_max_stream_ordering()
|
||||
if room_id is None:
|
||||
return "s%d" % (token,)
|
||||
return RoomStreamToken(None, stream_ordering)
|
||||
else:
|
||||
topo = await self.db_pool.runInteraction(
|
||||
"_get_max_topological_txn", self._get_max_topological_txn, room_id
|
||||
)
|
||||
return "t%d-%d" % (topo, token)
|
||||
return RoomStreamToken(topo, stream_ordering)
|
||||
|
||||
def get_stream_id_for_event_txn(
|
||||
self,
|
||||
|
@ -827,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
@staticmethod
|
||||
def _set_before_and_after(
|
||||
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
|
||||
):
|
||||
) -> None:
|
||||
"""Inserts ordering information to events' internal metadata from
|
||||
the DB rows.
|
||||
|
||||
|
@ -973,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
the `current_id`).
|
||||
"""
|
||||
|
||||
def get_all_new_events_stream_txn(txn):
|
||||
def get_all_new_events_stream_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[int, List[str]]:
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id"
|
||||
" FROM events AS e"
|
||||
|
@ -1319,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
async def get_id_for_instance(self, instance_name: str) -> int:
|
||||
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
|
||||
|
||||
def _get_id_for_instance_txn(txn):
|
||||
def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
|
||||
instance_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="instance_map",
|
||||
|
|
|
@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
)
|
||||
|
||||
def get_tag_content(
|
||||
txn: LoggingTransaction, tag_ids
|
||||
txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
|
||||
) -> List[Tuple[int, Tuple[str, str, str]]]:
|
||||
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
|
||||
results = []
|
||||
|
@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def _update_revision_txn(
|
||||
self, txn, user_id: str, room_id: str, next_id: int
|
||||
self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
|
||||
) -> None:
|
||||
"""Update the latest revision of the tags for the given user and room.
|
||||
|
||||
|
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class PaginationChunk:
|
||||
"""Returned by relation pagination APIs.
|
||||
|
||||
Attributes:
|
||||
chunk: The rows returned by pagination
|
||||
next_batch: Token to fetch next set of results with, if
|
||||
None then there are no more results.
|
||||
prev_batch: Token to fetch previous set of results with, if
|
||||
None then there are no previous results.
|
||||
"""
|
||||
|
||||
chunk: List[JsonDict]
|
||||
next_batch: Optional[Any] = None
|
||||
prev_batch: Optional[Any] = None
|
||||
|
||||
async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
|
||||
d = {"chunk": self.chunk}
|
||||
|
||||
if self.next_batch:
|
||||
d["next_batch"] = await self.next_batch.to_string(store)
|
||||
|
||||
if self.prev_batch:
|
||||
d["prev_batch"] = await self.prev_batch.to_string(store)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class AggregationPaginationToken:
|
||||
"""Pagination token for relation aggregation pagination API.
|
||||
|
||||
As the results are order by count and then MAX(stream_ordering) of the
|
||||
aggregation groups, we can just use them as our pagination token.
|
||||
|
||||
Attributes:
|
||||
count: The count of relations in the boundary group.
|
||||
stream: The MAX stream ordering in the boundary group.
|
||||
"""
|
||||
|
||||
count: int
|
||||
stream: int
|
||||
|
||||
@staticmethod
|
||||
def from_string(string: str) -> "AggregationPaginationToken":
|
||||
try:
|
||||
c, s = string.split("-")
|
||||
return AggregationPaginationToken(int(c), int(s))
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Invalid aggregation pagination token")
|
||||
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
return "%d-%d" % (self.count, self.stream)
|
||||
|
||||
def as_tuple(self) -> Tuple[Any, ...]:
|
||||
return attr.astuple(self)
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 68 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 69 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
@ -58,6 +58,10 @@ Changes in SCHEMA_VERSION = 68:
|
|||
- event_reference_hashes is no longer read.
|
||||
- `events` has `state_key` and `rejection_reason` columns, which are populated for
|
||||
new events.
|
||||
|
||||
Changes in SCHEMA_VERSION = 69:
|
||||
- We now write to `device_lists_changes_in_room` table.
|
||||
- Use sequence to generate future `application_services_txns.txn_id`s
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Add a column to track what device list changes stream id that this application
|
||||
-- service has been caught up to.
|
||||
|
||||
-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible
|
||||
-- state is useful for determining if we've ever sent traffic for a stream type
|
||||
-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for
|
||||
-- one way this can be used.
|
||||
ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;
|
44
synapse/storage/schema/main/delta/69/01as_txn_seq.py
Normal file
44
synapse/storage/schema/main/delta/69/01as_txn_seq.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2022 Beeper
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
Adds a postgres SEQUENCE for generating application service transaction IDs.
|
||||
"""
|
||||
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
if isinstance(database_engine, PostgresEngine):
|
||||
# If we already have some AS TXNs we want to start from the current
|
||||
# maximum value. There are two potential places this is stored - the
|
||||
# actual TXNs themselves *and* the AS state table. At time of migration
|
||||
# it is possible the TXNs table is empty so we must include the AS state
|
||||
# last_txn as a potential option, and pick the maximum.
|
||||
|
||||
cur.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns")
|
||||
row = cur.fetchone()
|
||||
txn_max = row[0]
|
||||
|
||||
cur.execute("SELECT COALESCE(max(last_txn), 0) FROM application_services_state")
|
||||
row = cur.fetchone()
|
||||
last_txn_max = row[0]
|
||||
|
||||
start_val = max(last_txn_max, txn_max) + 1
|
||||
|
||||
cur.execute(
|
||||
"CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
|
||||
(start_val,),
|
||||
)
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE device_lists_changes_in_room (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
|
||||
-- This initially matches `device_lists_stream.stream_id`. Note that we
|
||||
-- delete older values from `device_lists_stream`, so we can't use a foreign
|
||||
-- constraint here.
|
||||
--
|
||||
-- The table will contain rows with the same `stream_id` but different
|
||||
-- `room_id`, as for each device update we store a row per room the user is
|
||||
-- joined to. Therefore `(stream_id, room_id)` gives a unique index.
|
||||
stream_id BIGINT NOT NULL,
|
||||
|
||||
-- We have a background process which goes through this table and converts
|
||||
-- entries into rows in `device_lists_outbound_pokes`. Once we have processed
|
||||
-- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
|
||||
converted_to_destinations BOOLEAN NOT NULL,
|
||||
opentracing_context TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
|
||||
CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;
|
|
@ -571,6 +571,10 @@ class StateGroupStorage:
|
|||
|
||||
Returns:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
@ -659,6 +663,10 @@ class StateGroupStorage:
|
|||
|
||||
Returns:
|
||||
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
|
@ -696,6 +704,10 @@ class StateGroupStorage:
|
|||
|
||||
Returns:
|
||||
A dict from event_id -> (type, state_key) -> event_id
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
|
@ -723,6 +735,10 @@ class StateGroupStorage:
|
|||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for the event (ie it is an
|
||||
outlier or is unknown)
|
||||
"""
|
||||
state_map = await self.get_state_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
|
@ -741,6 +757,10 @@ class StateGroupStorage:
|
|||
|
||||
Returns:
|
||||
A dict from (type, state_key) -> state_event_id
|
||||
|
||||
Raises:
|
||||
RuntimeError if we don't have a state group for the event (ie it is an
|
||||
outlier or is unknown)
|
||||
"""
|
||||
state_map = await self.get_state_ids_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
|
|
|
@ -45,6 +45,7 @@ class Cursor(Protocol):
|
|||
Sequence[
|
||||
# Note that this is an approximate typing based on sqlite3 and other
|
||||
# drivers, and may not be entirely accurate.
|
||||
# FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
|
||||
Tuple[
|
||||
str,
|
||||
Optional[Any],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue