Merge remote-tracking branch 'upstream/release-v1.27.0'

Commit 7f7fb9b566 by Tulir Asokan, 2021-02-02 16:06:33 +02:00
146 changed files with 4893 additions and 1484 deletions

View file

@ -262,13 +262,18 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
"""Similar to `executemany`, except `txn.rowcount` will not be correct
afterwards.
More efficient than `executemany` on PostgreSQL.
"""
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)
self.executemany(sql, args)
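The many call-site changes later in this diff all follow the same shape: `execute_batch` takes the same `(sql, args)` pair as `executemany`, so callers only swap the method name. A hedged sketch with illustrative values, assuming `txn` is a LoggingTransaction inside a runInteraction:
def _prune_device_stream_sketch(txn):
    # Illustrative rows; real callers build these from query results.
    rows = [("@alice:example.org", "DEV1"), ("@bob:example.org", "DEV2")]
    txn.execute_batch(
        "DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ?",
        rows,
    )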
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
@ -888,7 +893,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
txn.executemany(sql, vals)
txn.execute_batch(sql, vals)
async def simple_upsert(
self,

View file

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@ -118,6 +119,7 @@ class DataStore(
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
EventForwardExtremitiesStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs

View file

@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
txn.executemany(sql, ((row[0], row[1]) for row in rows))
txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about
# when the latest change happened.
txn.executemany(
txn.execute_batch(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?

View file

@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Dict[str, dict]]:
) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database.
Args:

View file

@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
txn.execute_batch(
sql,
(
_gen_entry(user_id, actions)
@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
txn.executemany(
txn.execute_batch(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?

View file

@ -473,8 +473,9 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
@staticmethod
@classmethod
def _add_chain_cover_index(
cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
@ -614,60 +615,17 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for:
return
# We now calculate the chain IDs/sequence numbers for the events. We
# do this by looking at the chain ID and sequence number of any auth
# event with the same type/state_key and incrementing the sequence
# number by one. If there was no match or the chain ID/sequence
# number is already taken we generate a new chain.
#
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events
# before the event itself.
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check it's
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
if not new_chain_tuple:
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
events_to_calc_chain_id_for,
chain_map,
)
chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn(
txn,
@ -794,6 +752,137 @@ class PersistEventsStore:
],
)
@staticmethod
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
"""Allocates, but does not persist, chain ID/sequence numbers for the
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
for info on args)
"""
# We now calculate the chain IDs/sequence numbers for the events. We do
# this by looking at the chain ID and sequence number of any auth event
# with the same type/state_key and incrementing the sequence number by
# one. If there was no match or the chain ID/sequence number is already
# taken we generate a new chain.
#
# We try to reduce the number of times that we hit the database by
# batching up calls, to make this more efficient when persisting large
# numbers of state events (e.g. during joins).
#
# We do this by:
# 1. Calculating for each event which auth event will be used to
# inherit the chain ID, i.e. converting the auth chain graph to a
# tree that we can allocate chains on. We also keep track of which
# existing chain IDs have been referenced.
# 2. Fetching the max allocated sequence number for each referenced
# existing chain ID, generating a map from chain ID to the max
# allocated sequence number.
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
# new event, by incrementing the sequence number from the
# referenced event's chain ID/seq no. and checking that the
# incremented sequence number hasn't already been allocated (by
# looking in the map generated in the previous step). We generate a
# new chain if the sequence number has already been allocated.
#
existing_chains = set() # type: Set[int]
tree = [] # type: List[Tuple[str, Optional[str]]]
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
# the event itself.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map.get(auth_id)
if existing_chain_id:
existing_chains.add(existing_chain_id[0])
tree.append((event_id, auth_id))
break
else:
tree.append((event_id, None))
# Fetch the current max sequence number for each existing referenced chain.
sql = """
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
WHERE %s
GROUP BY chain_id
"""
clause, args = make_in_list_sql_clause(
db_pool.engine, "chain_id", existing_chains
)
txn.execute(sql % (clause,), args)
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
# Allocate the new events chain ID/sequence numbers.
#
# To reduce the number of calls to the database we don't allocate a
# chain ID number in the loop, instead we use a temporary `object()` for
# each new chain ID. Once we've done the loop we generate the necessary
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.
unallocated_chain_ids = set() # type: Set[object]
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
# `new_chain_tuples` map.
existing_chain_id = None
if auth_event_id:
existing_chain_id = new_chain_tuples.get(auth_event_id)
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]
new_chain_tuple = None # type: Optional[Tuple[Any, int]]
if existing_chain_id:
# We found a chain ID/sequence number candidate, check it's
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
# If we need to start a new chain we allocate a temporary chain ID.
if not new_chain_tuple:
new_chain_tuple = (object(), 1)
unallocated_chain_ids.add(new_chain_tuple[0])
new_chain_tuples[event_id] = new_chain_tuple
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
# Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int]
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
return {
event_id: (chain_id_to_allocated_map[chain_id], seq)
for event_id, (chain_id, seq) in new_chain_tuples.items()
}
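The temporary-`object()` trick described in the comments above can be seen in isolation; a minimal standalone illustration (toy event IDs and a pretend ID allocator, not Synapse code):
# Phase 1: while walking the tree, hand out opaque placeholders rather than
# hitting the database once per new chain.
new_chain_tuples = {}                  # event_id -> (chain id or placeholder, seq)
unallocated_chain_ids = set()
for event_id in ("$a", "$b", "$c"):    # toy events, each starting a new chain
    placeholder = object()
    unallocated_chain_ids.add(placeholder)
    new_chain_tuples[event_id] = (placeholder, 1)
# Phase 2: allocate all real chain IDs in a single call (pretend the sequence
# generator handed back 100, 101, 102) and swap them in for the placeholders.
newly_allocated_chain_ids = [100, 101, 102]
chain_id_to_allocated_map = dict(zip(unallocated_chain_ids, newly_allocated_chain_ids))
resolved = {
    event_id: (chain_id_to_allocated_map[chain_id], seq)
    for event_id, (chain_id, seq) in new_chain_tuples.items()
}
# resolved now maps each toy event to a real (chain_id, sequence_number) pair.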
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
@ -876,7 +965,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
txn.executemany(
txn.execute_batch(
sql,
(
(
@ -895,7 +984,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
txn.executemany(
txn.execute_batch(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@ -907,7 +996,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
txn.executemany(
txn.execute_batch(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@ -927,7 +1016,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
txn.executemany(
txn.execute_batch(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@ -938,7 +1027,7 @@ class PersistEventsStore:
)
if to_insert:
txn.executemany(
txn.execute_batch(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@ -1738,7 +1827,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
txn.executemany(
txn.execute_batch(
sql,
(
(
@ -1767,7 +1856,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
txn.executemany(
txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@ -1886,7 +1975,7 @@ class PersistEventsStore:
" )"
)
txn.executemany(
txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@ -1900,7 +1989,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.executemany(
txn.execute_batch(
query,
[
(ev.event_id, ev.room_id)

View file

@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
clump = update_rows[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
txn.execute_batch(sql, update_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
txn.execute_batch(sql, rows_to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,

View file

@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class EventForwardExtremitiesStore(SQLBaseStore):
async def delete_forward_extremities_for_room(self, room_id: str) -> int:
"""Delete any extra forward extremities for a room.
Invalidates the "get_latest_event_ids_in_room" cache if any forward
extremities were deleted.
Returns the number of forward extremities deleted.
"""
def delete_forward_extremities_for_room_txn(txn):
# First we need to get the event_id to not delete
sql = """
SELECT event_id FROM event_forward_extremities
INNER JOIN events USING (room_id, event_id)
WHERE room_id = ?
ORDER BY stream_ordering DESC
LIMIT 1
"""
txn.execute(sql, (room_id,))
rows = txn.fetchall()
try:
event_id = rows[0][0]
logger.debug(
"Found event_id %s as the forward extremity to keep for room %s",
event_id,
room_id,
)
except IndexError:
msg = "No forward extremity event found for room %s" % room_id
logger.warning(msg)
raise SynapseError(400, msg)
# Now delete the extra forward extremities
sql = """
DELETE FROM event_forward_extremities
WHERE event_id != ? AND room_id = ?
"""
txn.execute(sql, (event_id, room_id))
logger.info(
"Deleted %s extra forward extremities for room %s",
txn.rowcount,
room_id,
)
if txn.rowcount > 0:
# Invalidate the cache
self._invalidate_cache_and_stream(
txn, self.get_latest_event_ids_in_room, (room_id,),
)
return txn.rowcount
return await self.db_pool.runInteraction(
"delete_forward_extremities_for_room",
delete_forward_extremities_for_room_txn,
)
async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
"""Get list of forward extremities for a room."""
def get_forward_extremities_for_room_txn(txn):
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
INNER JOIN event_to_state_groups USING (event_id)
INNER JOIN events USING (room_id, event_id)
WHERE room_id = ?
"""
txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
)
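A hedged usage sketch of the two new store methods (the `store` object and room ID are assumptions; this diff only adds the storage layer):
async def resolve_extremities(store, room_id: str = "!abc:example.org") -> int:
    # store is assumed to be the homeserver's DataStore.
    extremities = await store.get_forward_extremities_for_room(room_id)
    if len(extremities) <= 1:
        return 0
    # Keeps the extremity with the highest stream ordering and deletes the rest.
    return await store.delete_forward_extremities_for_room(room_id)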

View file

@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
txn.executemany(
txn.execute_batch(
sql,
(
(time_ms, media_origin, media_id)
@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn

View file

@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
async def count_daily_e2ee_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""
def _count_messages(txn):
sql = """
SELECT COALESCE(COUNT(*), 0) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self):
def _count_messages(txn):
# This is good enough: if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
sql = """
SELECT COALESCE(COUNT(*), 0) FROM events
WHERE type = 'm.room.encrypted'
AND sender LIKE ?
AND stream_ordering > ?
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction(
"count_daily_sent_e2ee_messages", _count_messages
)
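To make the `like_clause` reasoning above concrete (hostname and senders are made up; `str.endswith` plays the role of the SQL `LIKE '%:<hostname>'`):
like_clause = "%:" + "example.org"                        # hypothetical self.hs.hostname
assert "@alice:example.org".endswith(like_clause[1:])     # local sender: counted as sent
assert not "@bob:other.org".endswith(like_clause[1:])     # remote sender: not counted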
async def count_daily_active_e2ee_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction(
"count_daily_active_e2ee_rooms", _count
)
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.

View file

@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremities
txn.executemany(
txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],

View file

@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
self.db_pool.simple_delete_one_txn(
# It is expected that there is exactly one pusher to delete, but
# if it isn't there (or there are multiple) delete them all.
self.db_pool.simple_delete_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},

View file

@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
"""Sets whether a user shadow-banned.
Args:
user: user ID of the user to test
shadow_banned: true iff the user is to be shadow-banned, false otherwise.
"""
def set_shadow_banned_txn(txn):
self.db_pool.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"shadow_banned": shadow_banned},
)
# In order for this to apply immediately, clear the cache for this user.
tokens = self.db_pool.simple_select_onecol_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user.to_string()},
retcol="token",
)
for token in tokens:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@ -443,6 +472,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
"""Record a mapping from an external user id to a mxid
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@ -1104,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
@ -1371,26 +1420,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
"""Record a mapping from an external user id to a mxid
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)
async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:

View file

@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
INSERT_CLUMP_SIZE = 1000
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
clump = to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)
txn.execute_batch(to_update_sql, to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,

View file

@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")

View file

@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import Collection
logger = logging.getLogger(__name__)
@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
txn.executemany(sql, args)
txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
txn.executemany(sql, args)
txn.execute_batch(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
async def search_rooms(
self,
room_ids: List[str],
room_ids: Collection[str],
search_term: str,
keys: List[str],
limit,

View file

@ -15,11 +15,12 @@
# limitations under the License.
import logging
from collections import Counter
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
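The Counter import now comes from typing_extensions so that it can be subscripted in annotations such as `Counter[str]` below; a minimal sketch of the pattern (instantiating via collections.Counter is an assumption for illustration, not part of this diff):
from collections import Counter as RealCounter
from typing_extensions import Counter

def bump(counts: Counter[str], room_id: str) -> None:
    counts[room_id] += 1

counts = RealCounter()  # type: Counter[str]
bump(counts, "!room:example.org")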
@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
async def get_earliest_token_for_stats(
self, stats_type: str, id: str
) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
)
async def bulk_update_stats_delta(
self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
) -> Dict[str, Dict[str, int]]:
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
max_pos,
)
def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
def get_changes_room_total_events_and_bytes_txn(
self, txn, low_pos: int, high_pos: int
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Gets the total_events and total_event_bytes counts for rooms and
senders, in a range of stream_orderings (including backfilled events).
Args:
txn
low_pos (int): Low stream ordering
high_pos (int): High stream ordering
low_pos: Low stream ordering
high_pos: High stream ordering
Returns:
tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
room and user deltas for total_events/total_event_bytes in the
The room and user deltas for total_events/total_event_bytes in the
format of `stats_id` -> fields
"""

View file

@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
async def update_user_directory_stream_pos(self, stream_id: str) -> None:
async def update_user_directory_stream_pos(self, stream_id: int) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},

View file

@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
txn.executemany(
txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
txn.executemany(
txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)

View file

@ -15,12 +15,11 @@
import heapq
import logging
import threading
from collections import deque
from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque() # type: Deque[int]
# We use this as an ordered set, as we want to efficiently append items,
# remove items and get the first item. Since we insert IDs in order, the
# insertion ordering ensures it's in the correct order.
#
# The key and values are the same, but we never look at the values.
self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
self._unfinished_ids.append(next_id)
self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
self._unfinished_ids.remove(next_id)
self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids.append(next_id)
self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
return self._unfinished_ids[0] - self._step
return next(iter(self._unfinished_ids)) - self._step
return self._current
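The OrderedDict-as-ordered-set pattern adopted above, in isolation (values are illustrative): insertion order is preserved, arbitrary members can be removed in O(1), and the oldest entry is cheap to read, which is what get_current_token relies on.
from collections import OrderedDict

unfinished = OrderedDict()             # type: OrderedDict[int, int]
for stream_id in (7, 8, 9):
    unfinished[stream_id] = stream_id  # "append"
unfinished.pop(8)                      # remove an arbitrary member without an O(n) scan
oldest = next(iter(unfinished))        # -> 7, the first unfinished ID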

View file

@ -69,6 +69,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence"""
...
@abc.abstractmethod
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
"""Get the next `n` IDs in the sequence"""
...
@abc.abstractmethod
def check_consistency(
self,
@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None
first_id = self._current_max_id + 1
self._current_max_id += n
return [first_id + i for i in range(n)]
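A hedged sketch of the new batched allocation on the local generator (the import path, constructor callback, and the `None` cursor are assumptions for illustration only):
from synapse.storage.util.sequence import LocalSequenceGenerator

gen = LocalSequenceGenerator(lambda txn: 10)  # pretend the current maximum ID is 10
ids = gen.get_next_mult_txn(None, 3)          # -> [11, 12, 13], consecutive and unique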
def check_consistency(
self,
db_conn: Connection,