Merge remote-tracking branch 'upstream/release-v1.57'

This commit is contained in:
Tulir Asokan 2022-04-21 13:53:47 +03:00
commit b2fa6ec9f6
248 changed files with 14616 additions and 8934 deletions

View file

@ -33,7 +33,7 @@ from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
from .client_ips import ClientIpStore
from .client_ips import ClientIpWorkerStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
@ -49,7 +49,7 @@ from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore
from .monthly_active_users import MonthlyActiveUsersWorkerStore
from .openid import OpenIdStore
from .presence import PresenceStore
from .profile import ProfileStore
@ -112,13 +112,13 @@ class DataStore(
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
ClientIpWorkerStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
@ -182,17 +183,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max
)
self._user_signature_stream_cache = StreamChangeCache(
"UserSignatureStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from synapse.appservice import (
ApplicationService,
@ -26,10 +26,16 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@ -72,9 +78,25 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
def get_max_as_txn_id(txn: Cursor) -> int:
logger.warning("Falling back to slow query, you should port to postgres")
txn.execute(
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
)
return txn.fetchone()[0] # type: ignore
self._as_txn_seq_gen = build_sequence_generator(
db_conn,
database.engine,
get_max_as_txn_id,
"application_services_txn_id_seq",
table="application_services_txns",
id_column="txn_id",
)
super().__init__(database, db_conn, hs)
def get_app_services(self):
def get_app_services(self) -> List[ApplicationService]:
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@ -217,6 +239,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@ -231,27 +254,14 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
"""
def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
# being sent)
last_txn_id = self._get_last_txn(txn, service.id)
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,),
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
highest_txn_id = 0
new_txn_id = max(highest_txn_id, last_txn_id) + 1
def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
# Insert new txn into txn table
event_ids = json_encoder.encode([e.event_id for e in events])
@ -268,6 +278,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@ -283,25 +294,8 @@ class ApplicationServiceTransactionWorkerStore(
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
def _complete_appservice_txn(txn):
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
# what was there before. If it isn't, we've got problems (e.g. the AS
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s "
"completing_txn=%s service_id=%s",
last_txn_id,
txn_id,
service.id,
)
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
@ -332,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
An AppServiceTransaction or None.
"""
def _get_oldest_unsent_txn(txn):
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
@ -359,8 +355,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts and unused fallback keys
# are not yet populated for catch-up transactions.
# TODO: to-device messages, one-time key counts, device list summaries and unused
# fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@ -370,21 +366,11 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
)
last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
return 0
else:
return int(last_txn_id[0]) # select 'last_txn' col
async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
@ -398,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
def get_new_events_for_appservice_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@ -430,13 +418,13 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
if type not in ("read_receipt", "presence", "to_device"):
if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
)
def get_type_stream_id_for_appservice_txn(txn):
def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type
txn.execute(
# We do NOT want to escape `stream_id_type`.
@ -446,7 +434,8 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
return 0
# Stream tokens always start from 1, to avoid foot guns around `0` being falsey.
return 1
else:
return int(last_stream_id[0])
@ -457,13 +446,13 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if stream_type not in ("read_receipt", "presence", "to_device"):
if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
def set_appservice_stream_type_pos_txn(txn):
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"

View file

@ -25,7 +25,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
if hs.config.redis.redis_enabled:
# If we're using Redis, we can shift this update process off to
# the background worker
self._update_on_this_worker = hs.config.worker.run_background_tasks
else:
# If we're NOT using Redis, this must be handled by the master
self._update_on_this_worker = hs.get_instance_name() == "master"
self.user_ips_max_age = hs.config.server.user_ips_max_age
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000
)
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
if self._update_on_this_worker:
# This is the designated worker that can write to the client IP
# tables.
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update: Dict[
Tuple[str, str, str], Tuple[str, Optional[str], int]
] = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
async def get_last_client_ip_by_device(
async def _get_last_client_ip_by_device_from_database(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
async def get_user_ip_and_agents(
async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
for access_token, ip, user_agent, last_seen in rows
]
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update: Dict[
Tuple[str, str, str], Tuple[str, Optional[str], int]
] = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
async def insert_client_ip(
self,
user_id: str,
@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
await self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.set(key, now)
self._batch_row_update[key] = (user_agent, device_id, now)
if self._update_on_this_worker:
await self.populate_monthly_active_users(user_id)
self._batch_row_update[key] = (user_agent, device_id, now)
else:
# We are not the designated writer-worker, so stream over replication
self.hs.get_replication_command_handler().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now
)
@wrap_as_background_process("update_client_ips")
async def _update_client_ips_batch(self) -> None:
assert (
self._update_on_this_worker
), "This worker is not designated to update client IPs"
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@ -603,51 +616,57 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
to_update = self._batch_row_update
self._batch_row_update = {}
await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
if to_update:
await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(
self,
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
assert (
self._update_on_this_worker
), "This worker is not designated to update client IPs"
# Keys and values for the `user_ips` upsert.
user_ips_keys = []
user_ips_values = []
# Keys and values for the `devices` update.
devices_keys = []
devices_values = []
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self.db_pool.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
values={
"user_agent": user_agent,
"device_id": device_id,
"last_seen": last_seen,
},
lock=False,
)
user_ips_keys.append((user_id, access_token, ip))
user_ips_values.append((user_agent, device_id, last_seen))
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db_pool.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
)
devices_keys.append((user_id, device_id))
devices_values.append((user_agent, last_seen, ip))
self.db_pool.simple_upsert_many_txn(
txn,
table="user_ips",
key_names=("user_id", "access_token", "ip"),
key_values=user_ips_keys,
value_names=("user_agent", "device_id", "last_seen"),
value_values=user_ips_values,
)
if devices_values:
self.db_pool.simple_update_many_txn(
txn,
table="devices",
key_names=("user_id", "device_id"),
key_values=devices_keys,
value_names=("user_agent", "last_seen", "ip"),
value_values=devices_values,
)
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
@ -662,7 +681,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
ret = await super().get_last_client_ip_by_device(user_id, device_id)
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
if not self._update_on_this_worker:
# Only the writing-worker has additional in-memory data to enhance
# the result
return ret
# Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
@ -707,9 +731,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
Only the latest user agent for each access token and IP address combination
is available.
"""
rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
if not self._update_on_this_worker:
# Only the writing-worker has additional in-memory data to enhance
# the result
return rows_from_db
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
for connection in await super().get_user_ip_and_agents(user, since_ts)
for connection in rows_from_db
}
# Overlay data that is pending insertion on top of the results from the

View file

@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
device_list_max = self._device_list_id_gen.get_current_token()
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache",
min_device_list_id,
prefilled_cache=device_list_prefill,
)
(
user_signature_stream_prefill,
user_signature_stream_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"user_signature_stream",
entity_column="from_user_id",
stream_column="stream_id",
max_value=device_list_max,
limit=1000,
)
self._user_signature_stream_cache = StreamChangeCache(
"UserSignatureStreamChangeCache",
user_signature_stream_list_id,
prefilled_cache=user_signature_stream_prefill,
)
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
@ -681,42 +731,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str]
self,
from_key: int,
user_ids: Optional[Iterable[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
from_key: The device lists stream token
user_ids: The user IDs to query for devices.
from_key: The minimum device lists stream token to query device list changes for,
exclusive.
user_ids: If provided, only check if these users have changed their device lists.
Otherwise changes from all users are returned.
to_key: The maximum device lists stream token to query device list changes for,
inclusive.
Returns:
The set of user_ids whose devices have changed since `from_key`
The set of user_ids whose devices have changed since `from_key` (exclusive)
until `to_key` (inclusive).
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if not to_check:
if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
sql = """
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args.append(to_key)
sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
WHERE {stream_id_where_clause}
AND
"""
for chunk in batch_iter(to_check, 100):
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, (from_key,) + tuple(args))
txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes
@ -788,6 +860,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
@ -1506,7 +1579,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
self,
user_id: str,
device_ids: Collection[str],
hosts: Optional[Collection[str]],
room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@ -1515,7 +1592,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
hosts: The remote destinations that should be notified of the change.
hosts: The remote destinations that should be notified of the change. If
None then the set of hosts have *not* been calculated, and will be
calculated later by a background task.
room_ids: The rooms that the user is in
Returns:
The maximum stream ID of device list updates that were added to the database, or
@ -1524,34 +1604,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
context = get_active_span_text_map()
def add_device_changes_txn(
txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
):
self._add_device_change_to_stream_txn(
txn,
user_id,
device_ids,
stream_ids,
stream_ids_for_device_change,
)
if not hosts:
return stream_ids[-1]
self._add_device_outbound_room_poke_txn(
txn,
user_id,
device_ids,
room_ids,
stream_ids_for_device_change,
context,
hosts_have_been_calculated=hosts is not None,
)
context = get_active_span_text_map()
async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
# If the set of hosts to send to has not been calculated yet (and so
# `hosts` is None) or there are no `hosts` to send to, then skip
# trying to persist them to the DB.
if not hosts:
return
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id,
device_ids,
hosts,
stream_ids,
stream_ids_for_outbound_pokes,
context,
)
# `device_lists_stream` wants a stream ID per device update.
num_stream_ids = len(device_ids)
if hosts:
# `device_lists_outbound_pokes` wants a different stream ID for
# each row, which is a row per host per device update.
num_stream_ids += len(hosts) * len(device_ids)
async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
stream_ids_for_device_change = stream_ids[: len(device_ids)]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
await self.db_pool.runInteraction(
"add_device_change_to_stream",
add_device_changes_txn,
stream_ids_for_device_change,
stream_ids_for_outbound_pokes,
)
return stream_ids[-1]
def _add_device_change_to_stream_txn(
@ -1595,7 +1703,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[str],
stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@ -1606,8 +1714,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
next_stream_id = iter(stream_ids)
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@ -1623,16 +1732,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
next(next_stream_id),
next(stream_id_iterator),
user_id,
device_id,
False,
now,
json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
stream_ids: List[str],
context: Dict[str, str],
hosts_have_been_calculated: bool,
) -> None:
"""Record the user in the room has updated their device.
Args:
hosts_have_been_calculated: True if `device_lists_outbound_pokes`
has been updated already with the updates.
"""
# We only need to convert to outbound pokes if they are our user.
converted_to_destinations = (
hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
)
encoded_context = json_encoder.encode(context)
# The `device_lists_changes_in_room.stream_id` column matches the
# corresponding `stream_id` of the update in the `device_lists_stream`
# table, i.e. all rows persisted for the same device update will have
# the same `stream_id` (but different room IDs).
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_changes_in_room",
keys=(
"user_id",
"device_id",
"room_id",
"stream_id",
"converted_to_destinations",
"opentracing_context",
),
values=[
(
user_id,
device_id,
room_id,
stream_id,
converted_to_destinations,
encoded_context,
)
for room_id in room_ids
for device_id, stream_id in zip(device_ids, stream_ids)
],
)
async def get_uncoverted_outbound_room_pokes(
self, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
Returns:
A list of user ID, device ID, room ID, stream ID and optional opentracing context.
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
WHERE NOT converted_to_destinations
ORDER BY stream_id
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(txn):
txn.execute(sql, (limit,))
return txn.fetchall()
return await self.db_pool.runInteraction(
"get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
)
async def add_device_list_outbound_pokes(
self,
user_id: str,
device_id: str,
room_id: str,
stream_id: int,
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled.
"""
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_ids=[device_id],
hosts=hosts,
stream_ids=stream_ids,
context=context,
)
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
[],
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
stream_ids,
)

View file

@ -197,12 +197,10 @@ class PersistEventsStore:
)
persist_event_counter.inc(len(events_and_contexts))
if stream < 0:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
events_and_contexts[-1][0].internal_metadata.stream_ordering
)
if not use_negative_stream_ordering:
# we don't want to set the event_persisted_position to a negative
# stream_ordering.
synapse.metrics.event_persisted_position.set(stream)
for event, context in events_and_contexts:
if context.app_service:

View file

@ -22,7 +22,6 @@ from typing import (
Dict,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
@ -1339,10 +1338,9 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
res = await self._have_seen_events_dict(((room_id, event_id),))
return res[(room_id, event_id)]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str

View file

@ -15,7 +15,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
if hs.config.redis.redis_enabled:
# If we're using Redis, we can shift this update process off to
# the background worker
self._update_on_this_worker = hs.config.worker.run_background_tasks
else:
# If we're NOT using Redis, this must be handled by the master
self._update_on_this_worker = hs.get_instance_name() == "master"
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._max_mau_value = hs.config.server.max_mau_value
self._mau_stats_only = hs.config.server.mau_stats_only
if self._update_on_this_worker:
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn:
threepids: List of threepid dicts to reserve
"""
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id: user to add/update
"""
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
# is_support_user which is complicated because I want to cache the result.
@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn (cursor):
user_id (str): user to add/update
"""
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
# Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple
@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id(str): the user_id to query
"""
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]

View file

@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
max_receipts_stream_id = self.get_max_receipt_stream_id()
receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
db_conn,
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
max_value=max_receipts_stream_id,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
"ReceiptsRoomChangeCache",
min_receipts_stream_id,
prefilled_cache=receipts_stream_prefill,
)
def get_max_receipt_stream_id(self) -> int:

View file

@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID, UserInfo
from synapse.types import JsonDict, UserID, UserInfo
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@ -79,7 +79,7 @@ class TokenLookupResult:
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
def _default_token_owner(self):
def _default_token_owner(self) -> str:
return self.user_id
@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
the account.
"""
def set_account_validity_for_user_txn(txn):
def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_txn(
txn=txn,
table="account_validity",
@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_renewal_token_for_user",
)
async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
A list of dictionaries, each with a user ID and expiration time (in milliseconds).
A list of tuples, each with a user ID and expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
def select_users_txn(
txn: LoggingTransaction, now_ms: int, renew_at: int
) -> List[Tuple[str, int]]:
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
def set_server_admin_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type: type of the user or None for a user without a type.
"""
def set_user_type_txn(txn):
def set_user_type_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
)
@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
def _query_for_auth(
self, txn: LoggingTransaction, token: str
) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
users.is_guest,
@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"is_support_user", self.is_support_user_txn, user_id
)
def is_real_user_txn(self, txn, user_id):
def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return res is None
def is_support_user_txn(self, txn, user_id):
def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A mapping of user_id -> password_hash.
"""
def f(txn):
def f(txn: LoggingTransaction) -> Dict[str, str]:
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
return dict(txn)
result = cast(List[Tuple[str, str]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def _replace_user_external_id_txn(
txn: LoggingTransaction,
):
) -> None:
_remove_user_external_ids_txn(txn, user_id)
for auth_provider, external_id in record_external_ids:
@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self):
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
def _count_users(txn):
def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"count_daily_user_type", _count_daily_user_type
)
async def count_nonbridged_users(self):
def _count_users(txn):
async def count_nonbridged_users(self) -> int:
def _count_users(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL
"""
)
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
async def count_real_users(self):
async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn):
def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn)
if rows:
@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return user_id
def get_user_id_by_threepid_txn(
self, txn, medium: str, address: str
self, txn: LoggingTransaction, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid
@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
):
) -> None:
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
assert address or sid
def get_threepid_validation_session_txn(txn):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_token",
@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
def cull_expired_threepid_validation_tokens_txn(
txn: LoggingTransaction, ts: int
) -> None:
sql = """
DELETE FROM threepid_validation_token WHERE
expires < ?
@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@wrap_as_background_process("account_validity_set_expiration_dates")
async def _set_expiration_date_when_missing(self):
async def _set_expiration_date_when_missing(self) -> None:
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""
def select_users_with_no_expiration_date_txn(txn):
def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
"""Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users.
"""
@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
select_users_with_no_expiration_date_txn,
)
def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
def set_expiration_date_for_user_txn(
self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
) -> None:
"""Sets an expiration date to the account with the given user ID.
Args:
@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token pending use
"""
def _set_registration_token_pending_txn(txn):
def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"pending": pending + 1},
)
return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token to be 'used'
"""
def _use_registration_token_txn(txn):
def _use_registration_token_txn(txn: LoggingTransaction) -> None:
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
},
)
return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A list of dicts, each containing details of a token.
"""
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
def select_registration_tokens_txn(
txn: LoggingTransaction, now: int, valid: Optional[bool]
) -> List[Dict[str, Any]]:
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Whether the row was inserted or not.
"""
def _create_registration_token_txn(txn):
def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A dict with all info about the token, or None if token doesn't exist.
"""
def _update_registration_token_txn(txn):
def _update_registration_token_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
try:
self.db_pool.simple_update_one_txn(
txn,
@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> Optional[RefreshTokenLookupResult]:
"""Lookup a refresh token with hints about its validity."""
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
def _lookup_refresh_token_txn(
txn: LoggingTransaction,
) -> Optional[RefreshTokenLookupResult]:
txn.execute(
"""
SELECT
@ -1745,6 +1762,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@ -1795,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
unique=False,
)
async def _background_update_set_deactivated_flag(self, progress, batch_size):
async def _background_update_set_deactivated_flag(
self, progress: JsonDict, batch_size: int
) -> int:
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
def _background_update_set_deactivated_flag_txn(txn):
def _background_update_set_deactivated_flag_txn(
txn: LoggingTransaction,
) -> Tuple[bool, int]:
txn.execute(
"""
SELECT
@ -1874,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
def set_user_deactivated_status_txn(
self, txn: LoggingTransaction, user_id: str, deactivated: bool
) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
@ -1887,18 +1922,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)
@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def _register_user(
self,
txn,
txn: LoggingTransaction,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
):
) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
pointless. Use flush_user separately.
"""
def user_set_password_hash_txn(txn):
def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
def f(txn):
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found
"""
def f(txn):
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn(
txn,
table="users",
@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def delete_access_token(self, access_token: str) -> None:
def f(txn):
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
async def delete_refresh_token(self, refresh_token: str) -> None:
def f(txn):
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
)
@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"""
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_session",
@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
longer be valid
"""
def start_or_continue_validation_session_txn(txn):
def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
# Create or update a validation session
self.db_pool.simple_upsert_txn(
txn,

View file

@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
@ -39,8 +40,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import RoomStreamToken, StreamToken
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@ -49,6 +49,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
Contains enough information about a related event in order to properly filter
events from ignored users.
"""
# The event ID of the related event.
event_id: str
# The sender of the related event.
sender: str
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@ -73,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> PaginationChunk:
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@ -90,8 +103,10 @@ class RelationsWorkerStore(SQLBaseStore):
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
A tuple of:
A list of related event IDs & their senders.
The next stream token, if one exists.
"""
# We don't use `event_id`, it's there so that we can cache based on
# it. The `event_id` must match the `event.event_id`.
@ -132,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
SELECT event_id, relation_type, topological_ordering, stream_ordering
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@ -146,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@ -156,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
events.append({"event_id": row[0]})
last_topo_id = row[2]
last_stream_id = row[3]
events.append(_RelatedEvent(row[0], row[2]))
last_topo_id = row[3]
last_stream_id = row[4]
# If there are more events, generate the next pagination key.
next_token = None
@ -179,9 +194,7 @@ class RelationsWorkerStore(SQLBaseStore):
groups_key=0,
)
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
)
return events[:limit], next_token
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
@ -252,15 +265,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
event_id: str,
room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[AggregationPaginationToken] = None,
to_token: Optional[AggregationPaginationToken] = None,
) -> PaginationChunk:
self, event_id: str, room_id: str, limit: int = 5
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@ -270,82 +276,96 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
where_args: List[Union[str, int]] = [
args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
limit,
]
sql = """
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
GROUP BY relation_type, type, aggregation_key
ORDER BY COUNT(*) DESC
LIMIT ?
"""
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> List[JsonDict]:
txn.execute(sql, args)
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
async def get_aggregation_groups_for_users(
self,
event_id: str,
room_id: str,
limit: int,
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
"""Fetch the partial aggregations for an event for specific users.
This is used, in conjunction with get_aggregation_groups_for_event, to
remove information from the results for ignored users.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
users: The users to fetch information for.
Returns:
A map of (event type, aggregation key) to a count of users.
"""
if not users:
return {}
args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]
if event_type:
where_clause.append("type = ?")
where_args.append(event_type)
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "sender", users
)
args.extend(users_args)
if direction == "b":
order = "DESC"
else:
order = "ASC"
if having_clause:
having_clause = "HAVING " + having_clause
else:
having_clause = ""
sql = """
SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
sql = f"""
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE {where_clause}
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
{having_clause}
ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
ORDER BY COUNT(*) DESC
LIMIT ?
""".format(
where_clause=" AND ".join(where_clause),
order=order,
having_clause=having_clause,
)
"""
def _get_aggregation_groups_for_event_txn(
def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])
) -> Dict[Tuple[str, str], int]:
txn.execute(sql, args + [limit])
next_batch = None
events = []
for row in txn:
events.append({"type": row[0], "key": row[1], "count": row[2]})
next_batch = AggregationPaginationToken(row[2], row[3])
if len(events) <= limit:
next_batch = None
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
@cached()
@ -574,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
return summaries
async def get_threaded_messages_per_user(
self,
event_ids: Collection[str],
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
"""Get the number of threaded replies for a set of users.
This is used, in conjunction with get_thread_summaries, to calculate an
accurate count of the replies to a thread by subtracting ignored users.
Args:
event_ids: The events to check for threaded replies.
users: The user to calculate the count of their replies.
Returns:
A map of the (event_id, sender) to the count of their replies.
"""
if not users:
return {}
# Fetch the number of threaded replies.
sql = """
SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
%s
AND %s
AND %s
GROUP BY parent.event_id, child.sender
"""
def _get_threaded_messages_per_user_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]:
users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "child.sender", users
)
events_clause, events_args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
if self._msc3440_enabled:
relations_clause = "(relation_type = ? OR relation_type = ?)"
relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
else:
relations_clause = "relation_type = ?"
relations_args = [RelationTypes.THREAD]
txn.execute(
sql % (users_sql, events_clause, relations_clause),
users_args + events_args + relations_args,
)
return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
"get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
)
@cached()
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()
@ -661,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
def _get_if_events_have_relations(txn) -> List[str]:
def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids

View file

@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
self, user_id: str, membership_list: Collection[str]
self,
user_id: str,
membership_list: Collection[str],
excluded_rooms: Optional[List[str]] = None,
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
excluded_rooms: A list of rooms to ignore.
Returns:
The RoomsForUser that the user matches the membership types.
@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership_list,
)
# Now we filter out forgotten rooms
forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
# Now we filter out forgotten and excluded rooms
rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
if excluded_rooms is not None:
rooms_to_exclude.update(set(excluded_rooms))
return [room for room in rooms if room.room_id not in rooms_to_exclude]
def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id: str, membership_list: List[str]
self,
txn,
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
@ -877,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return frozenset(cache.hosts_to_joined_users)
# Since we'll mutate the cache we need to lock.
with (await self._joined_host_linearizer.queue(room_id)):
async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on

View file

@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
def get_event_reference_hash(self, event_id):
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()

View file

@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -29,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StateMap
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
if not isinstance(predecessor, collections.abc.Mapping):
if not isinstance(predecessor, (dict, frozendict)):
return None
# The keys must be strings since the data is JSON.
return predecessor
async def get_create_event_for_room(self, room_id: str) -> EventBase:
@ -202,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The current state of the room.
"""
def _get_current_state_ids_txn(txn):
def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
async def _get_state_group_for_events(
self, event_ids: Collection[str]
) -> Dict[str, int]:
"""Returns mapping event_id -> state_group.
Raises:
RuntimeError if the state is unknown at any of the given events
"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
return {row["event_id"]: row["state_group"] for row in rows}
res = {row["event_id"]: row["state_group"] for row in rows}
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
return res
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
)
for user_id in potentially_left_users - joined_users:
await self.mark_remote_user_device_list_as_unsubscribed(user_id)
await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]
return batch_size

View file

@ -36,7 +36,17 @@ what sort order was used:
"""
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
List,
Optional,
Set,
Tuple,
cast,
)
import attr
from frozendict import frozendict
@ -585,7 +595,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
self,
user_id: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
excluded_rooms: Optional[List[str]] = None,
) -> List[EventBase]:
"""Fetch membership events for a given user.
@ -610,23 +624,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()
args: List[Any] = [user_id, min_from_id, max_to_id]
ignore_room_clause = ""
if excluded_rooms is not None and len(excluded_rooms) > 0:
ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
"?" for _ in excluded_rooms
)
args = args + excluded_rooms
sql = """
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
%s
ORDER BY e.stream_ordering ASC
"""
txn.execute(
sql,
(
user_id,
min_from_id,
max_to_id,
),
""" % (
ignore_room_clause,
)
txn.execute(sql, args)
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
@ -722,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A tuple of (stream ordering, topological ordering, event_id)
"""
def _f(txn):
def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@ -732,27 +752,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
return cast(Optional[Tuple[int, int, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
async def get_current_room_stream_token_for_room_id(
self, room_id: Optional[str] = None
) -> RoomStreamToken:
"""Returns the current position of the rooms stream.
By default, it returns the current global stream token. Specifying a
`room_id` causes it to return the current room specific topological
token.
By default, it returns a live token with the current global stream
token. Specifying a `room_id` causes it to return a historic token with
the room specific topological token.
"""
token = self.get_room_max_stream_ordering()
stream_ordering = self.get_room_max_stream_ordering()
if room_id is None:
return "s%d" % (token,)
return RoomStreamToken(None, stream_ordering)
else:
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
return "t%d-%d" % (topo, token)
return RoomStreamToken(topo, stream_ordering)
def get_stream_id_for_event_txn(
self,
@ -827,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
):
) -> None:
"""Inserts ordering information to events' internal metadata from
the DB rows.
@ -973,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
def get_all_new_events_stream_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
@ -1319,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
def _get_id_for_instance_txn(txn):
def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
instance_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="instance_map",

View file

@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
def get_tag_content(
txn: LoggingTransaction, tag_ids
txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_revision_txn(
self, txn, user_id: str, room_id: str, next_id: int
self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room.