Merge remote-tracking branch 'upstream/release-v1.35'

This commit is contained in:
Tulir Asokan 2021-05-25 13:13:39 +03:00
commit 4740b83c39
98 changed files with 6440 additions and 2105 deletions

View file

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
@ -44,7 +43,6 @@ class SQLBaseStore(metaclass=ABCMeta):
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()
def process_replication_rows(
self,

View file

@ -67,7 +67,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@ -83,7 +83,7 @@ class DataStore(
StreamStore,
ProfileStore,
PresenceStore,
TransactionStore,
TransactionWorkerStore,
DirectoryStore,
KeyStore,
StateStore,

View file

@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = LruCache(
cache_name="client_ip_last_seen", keylen=4, max_size=50000
cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)

View file

@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = LruCache(
cache_name="device_id_exists", keylen=2, max_size=10000
cache_name="device_id_exists", max_size=10000
)
async def store_device(

View file

@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
num_args=1,
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
user_ids: List[str],
user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if

View file

@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache = LruCache(
cache_name="*getEvent*",
keylen=3,
max_size=hs.config.caches.event_cache_size,
)

View file

@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
db_conn, "presence_stream", "stream_id"
)
self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
@ -96,6 +97,15 @@ class PresenceStore(SQLBaseStore):
)
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Delete old rows to stop database from getting really big
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
for states in batch_iter(presence_states, 50):
clause, args = make_in_list_sql_clause(
self.database_engine, "user_id", [s.user_id for s in states]
)
txn.execute(sql + clause, [stream_id] + list(args))
# Actually insert new rows
self.db_pool.simple_insert_many_txn(
txn,
@ -116,15 +126,6 @@ class PresenceStore(SQLBaseStore):
],
)
# Delete old rows to stop database from getting really big
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
for states in batch_iter(presence_states, 50):
clause, args = make_in_list_sql_clause(
self.database_engine, "user_id", [s.user_id for s in states]
)
txn.execute(sql + clause, [stream_id] + list(args))
async def get_all_presence_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore):
return {row["user_id"]: UserPresenceState(**row) for row in rows}
async def should_user_receive_full_presence_with_token(
self,
user_id: str,
from_token: int,
) -> bool:
"""Check whether the given user should receive full presence using the stream token
they're updating from.
Args:
user_id: The ID of the user to check.
from_token: The stream token included in their /sync token.
Returns:
True if the user should have full presence sent to them, False otherwise.
"""
def _should_user_receive_full_presence_with_token_txn(txn):
sql = """
SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ?
AND presence_stream_id >= ?
"""
txn.execute(sql, (user_id, from_token))
return bool(txn.fetchone())
return await self.db_pool.runInteraction(
"should_user_receive_full_presence_with_token",
_should_user_receive_full_presence_with_token_txn,
)
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
"""Adds to the list of users who should receive a full snapshot of presence
upon their next sync.
Args:
user_ids: An iterable of user IDs.
"""
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
await self.db_pool.simple_upsert_many(
table="users_to_send_full_presence_to",
key_names=("user_id",),
key_values=[(user_id,) for user_id in user_ids],
value_names=("presence_stream_id",),
# We save the current presence stream ID token along with the user ID entry so
# that when a user /sync's, even if they syncing multiple times across separate
# devices at different times, each device will receive full presence once - when
# the presence stream ID in their sync token is less than the one in the table
# for their user ID.
value_values=(
(self._presence_id_gen.get_current_token(),) for _ in user_ids
),
desc="add_users_to_send_full_presence_to",
)
async def get_presence_for_all_users(
self,
include_offline: bool = True,

View file

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@ -997,7 +998,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
expiration_ts = self.rand.randrange(
expiration_ts = random.randrange(
expiration_ts - self._account_validity_startup_job_max_delta,
expiration_ts,
)

View file

@ -16,13 +16,15 @@ import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.descriptors import cached
db_binary_type = memoryview
@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple(
"_TransactionRow", ("response_code", "response_json")
)
SENTINEL = object()
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
"""The current destination retry timing info for a remote server."""
# The first time we tried and failed to reach the remote server, in ms.
failure_ts: int
# The last time we tried and failed to reach the remote server, in ms.
retry_last_ts: int
# How long since the last time we tried to reach the remote server before
# trying again, in ms.
retry_interval: int
class TransactionWorkerStore(SQLBaseStore):
class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore):
"_cleanup_transactions", _cleanup_transactions_txn
)
class TransactionStore(TransactionWorkerStore):
"""A collection of queries for handling PDUs."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self._destination_retry_cache = ExpiringCache(
cache_name="get_destination_retry_timings",
clock=self._clock,
expiry_ms=5 * 60 * 1000,
)
async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore):
desc="set_received_txn_response",
)
async def get_destination_retry_timings(self, destination):
@cached(max_entries=10000)
async def get_destination_retry_timings(
self,
destination: str,
) -> Optional[DestinationRetryTimings]:
"""Gets the current retry timings (if any) for a given destination.
Args:
@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore):
Otherwise a dict for the retry scheme
"""
result = self._destination_retry_cache.get(destination, SENTINEL)
if result is not SENTINEL:
return result
result = await self.db_pool.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
)
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
self._destination_retry_cache[destination] = result
return result
def _get_destination_retry_timings(self, txn, destination):
def _get_destination_retry_timings(
self, txn, destination: str
) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
return result
return DestinationRetryTimings(**result)
else:
return None
@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore):
retry_interval: how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
if self.database_engine.can_native_upsert:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore):
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
self._invalidate_cache_and_stream(
txn, self.get_destination_retry_timings, (destination,)
)
def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore):
},
)
self._invalidate_cache_and_stream(
txn, self.get_destination_retry_timings, (destination,)
)
async def store_destination_rooms_entries(
self,
destinations: Iterable[str],

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
"""
Checks which users in a list have requested erasure
Args:
user_ids (iterable[str]): full user id to check
user_ids: full user ids to check
Returns:
dict[str, bool]:
for each user, whether the user has requested erasure.
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",

View file

@ -0,0 +1,34 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a table that keeps track of a list of users who should, upon their next
-- sync request, receive presence for all currently online users that they are
-- "interested" in.
-- The motivation for a DB table over an in-memory list is so that this list
-- can be added to and retrieved from by any worker. Specifically, we don't
-- want to duplicate work across multiple sync workers.
CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to(
-- The user ID to send full presence to.
user_id TEXT PRIMARY KEY,
-- A presence stream ID token - the current presence stream token when the row was last upserted.
-- If a user calls /sync and this token is part of the update they're to receive, we also include
-- full user presence in the response.
-- This allows multiple devices for a user to receive full presence whenever they next call /sync.
presence_stream_id BIGINT,
FOREIGN KEY (user_id)
REFERENCES users (name)
);

View file

@ -540,7 +540,7 @@ class StateGroupStorage:
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event_id
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()