Merge remote-tracking branch 'upstream/release-v1.46'

Tulir Asokan 2021-10-27 15:42:34 +03:00
commit cf45cfd314
172 changed files with 5549 additions and 2350 deletions

View file

@ -19,6 +19,7 @@ from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
@ -392,7 +396,7 @@ class DatabasePool:
def __init__(
self,
hs,
hs: "HomeServer",
database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
):

View file

@ -13,33 +13,49 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class Databases:
DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
class Databases(Generic[DataStoreT]):
"""The various databases.
These are low level interfaces to physical databases.
Attributes:
main (DataStore)
databases
main
state
persist_events
"""
def __init__(self, main_store_class, hs):
databases: List[DatabasePool]
main: DataStoreT
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
# store.
self.databases = []
main = None
state = None
persist_events = None
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
persist_events: Optional[PersistEventsStore] = None
for database_config in hs.config.database.databases:
db_name = database_config.name
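The new Generic parameter makes the worker/main split visible to the type checker: whoever constructs Databases passes its concrete main-store class, and `databases.main` is then typed as that class. A minimal sketch of that wiring, assuming the class names referenced elsewhere in this diff (illustrative, not the homeserver's actual setup code):

from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore

def build_databases(hs) -> "Databases[DataStore]":
    # The main process passes DataStore; a worker would pass its own
    # slaved main-store subclass and get a correspondingly typed .main.
    return Databases(DataStore, hs)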

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -126,7 +129,7 @@ class DataStore(
LockStore,
SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -28,6 +28,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
`get_max_account_data_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):

View file

@ -15,7 +15,7 @@
import itertools
import logging
from typing import Any, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()

View file

@ -13,14 +13,26 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.types import UserID
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
class DeviceLastConnectionInfo(TypedDict):
"""Metadata for the last connection seen for a user and device combination"""
# These types must match the columns in the `devices` table
user_id: str
device_id: str
ip: Optional[str]
user_agent: Optional[str]
last_seen: Optional[int]
class LastConnectionInfo(TypedDict):
"""Metadata for the last connection seen for an access token and IP combination"""
# These types must match the columns in the `user_ips` table
access_token: str
ip: str
user_agent: str
last_seen: int
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
"devices_last_seen", self._devices_last_seen_update
)
async def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
async def _remove_user_ip_nonunique(
self, progress: JsonDict, batch_size: int
) -> int:
def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
)
return 1
async def _analyze_user_ip(self, progress, batch_size):
async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
#
# This will lock out the naive upserts to user_ips while it happens, but
# the analyze should be quick (28GB table takes ~10s)
def user_ips_analyze(txn):
def user_ips_analyze(txn: LoggingTransaction) -> None:
txn.execute("ANALYZE user_ips")
await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return 1
async def _remove_user_ip_dupes(self, progress, batch_size):
async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
# This function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
# are removed and replaced with a suitable row.
# Fetch the start of the batch
begin_last_seen = progress.get("last_seen", 0)
begin_last_seen: int = progress.get("last_seen", 0)
def get_last_seen(txn):
def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
txn.execute(
"""
SELECT last_seen FROM user_ips
@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
""",
(begin_last_seen, batch_size),
)
row = txn.fetchone()
row = cast(Optional[Tuple[int]], txn.fetchone())
if row:
return row[0]
else:
@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
end_last_seen,
)
def remove(txn):
def remove(txn: LoggingTransaction) -> None:
# This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range
# checking for any duplicates in the rest of the table (via a join).
@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# Define the search space, which requires handling the last batch in
# a different way
args: Tuple[int, ...]
if last:
clause = "? <= last_seen"
args = (begin_last_seen,)
else:
assert end_last_seen is not None
clause = "? <= last_seen AND last_seen < ?"
args = (begin_last_seen, end_last_seen)
@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
),
args,
)
res = txn.fetchall()
res = cast(
List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
)
# We've got some duplicates
for i in res:
@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _devices_last_seen_update(self, progress, batch_size):
async def _devices_last_seen_update(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to insert last seen info into devices table"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
last_user_id: str = progress.get("last_user_id", "")
last_device_id: str = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# we'll just end up updating the same device row multiple
# times, which is fine.
where_args: List[Union[str, int]]
where_clause, where_args = make_tuple_comparison_clause(
[("user_id", last_user_id), ("device_id", last_device_id)],
)
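For orientation, a sketch of the keyset-pagination clause this helper is expected to produce here, so the scan resumes strictly after the last (user_id, device_id) pair already processed; the exact SQL text is an assumption and may vary by engine:

from synapse.storage.database import make_tuple_comparison_clause

clause, args = make_tuple_comparison_clause(
    [("user_id", "@alice:example.com"), ("device_id", "ABCDEFGH")]
)
# clause: roughly "(user_id,device_id) > (?,?)"
# args:   ["@alice:example.com", "ABCDEFGH"]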
@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
if not rows:
return 0
@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age
@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
if self.user_ips_max_age is None:
@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
)
"""
timestamp = self.clock.time_msec() - self.user_ips_max_age
timestamp = self._clock.time_msec() - self.user_ips_max_age
def _prune_old_user_ips_txn(txn):
def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
txn.execute(sql, (timestamp,))
await self.db_pool.runInteraction(
@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
The result might be slightly out of date as client IPs are inserted in batches.
@ -423,26 +467,84 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
res = cast(
List[DeviceLastConnectionInfo],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
)
return {(d["user_id"], d["device_id"]): d for d in res}
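For reference, a hypothetical example of the mapping this returns, keyed by (user_id, device_id) and valued with DeviceLastConnectionInfo rows (all values invented):

example_result = {
    ("@alice:example.com", "ABCDEFGH"): {
        "user_id": "@alice:example.com",
        "device_id": "ABCDEFGH",
        "ip": "10.0.0.1",
        "user_agent": "Mozilla/5.0 (illustrative)",
        "last_seen": 1635300000000,
    },
}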
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
The result might be slightly out of date as client IPs are inserted in batches.
self.client_ip_last_seen = LruCache(
Args:
user: The user for which to fetch IP addresses and user agents.
since_ts: The timestamp after which to fetch IP addresses and user agents,
in milliseconds.
Returns:
A list of dictionaries, each containing:
* `access_token`: The access token used.
* `ip`: The IP address used.
* `user_agent`: The last user agent seen for this access token and IP
address combination.
* `last_seen`: The timestamp at which this access token and IP address
combination was last seen, in milliseconds.
Only the latest user agent for each access token and IP address combination
is available.
"""
user_id = user.to_string()
def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
txn.execute(
"""
SELECT access_token, ip, user_agent, last_seen FROM user_ips
WHERE last_seen >= ? AND user_id = ?
ORDER BY last_seen
DESC
""",
(since_ts, user_id),
)
return cast(List[Tuple[str, str, str, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
desc="get_user_ip_and_agents", func=get_recent
)
return [
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for access_token, ip, user_agent, last_seen in rows
]
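A hedged usage sketch of the worker-store method added above; the store handle and all values are illustrative:

from synapse.types import UserID

async def print_recent_connections(store) -> None:
    rows = await store.get_user_ip_and_agents(
        UserID.from_string("@alice:example.com"), since_ts=0
    )
    # Each row is a LastConnectionInfo, e.g.
    # {"access_token": "<token>", "ip": "10.0.0.1",
    #  "user_agent": "Mozilla/5.0 (illustrative)", "last_seen": 1635300000000}
    for row in rows:
        print(row["ip"], row["last_seen"])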
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._batch_row_update: Dict[
Tuple[str, str, str], Tuple[str, Optional[str], int]
] = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
@ -452,8 +554,14 @@ class ClientIpStore(ClientIpWorkerStore):
)
async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: Optional[str],
now: Optional[int] = None,
) -> None:
if not now:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
@ -485,7 +593,11 @@ class ClientIpStore(ClientIpWorkerStore):
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
def _update_client_ips_batch_txn(
self,
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
@ -525,7 +637,7 @@ class ClientIpStore(ClientIpWorkerStore):
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on
Args:
@ -561,50 +673,44 @@ class ClientIpStore(ClientIpWorkerStore):
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
) -> List[Dict[str, Union[str, int]]]:
"""
Fetch IP/User Agent connection since a given timestamp.
"""
user_id = user.to_string()
results = {}
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
Args:
user: The user for which to fetch IP addresses and user agents.
since_ts: The timestamp after which to fetch IP addresses and user agents,
in milliseconds.
Returns:
A list of dictionaries, each containing:
* `access_token`: The access token used.
* `ip`: The IP address used.
* `user_agent`: The last user agent seen for this access token and IP
address combination.
* `last_seen`: The timestamp at which this access token and IP address
combination was last seen, in milliseconds.
Only the latest user agent for each access token and IP address combination
is available.
"""
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
for connection in await super().get_user_ip_and_agents(user, since_ts)
}
# Overlay data that is pending insertion on top of the results from the
# database.
user_id = user.to_string()
for key in self._batch_row_update:
(
uid,
access_token,
ip,
) = key
uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
if last_seen >= since_ts:
results[(access_token, ip)] = (user_agent, last_seen)
results[(access_token, ip)] = {
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
def get_recent(txn):
txn.execute(
"""
SELECT access_token, ip, user_agent, last_seen FROM user_ips
WHERE last_seen >= ? AND user_id = ?
ORDER BY last_seen
DESC
""",
(since_ts, user_id),
)
return txn.fetchall()
rows = await self.db_pool.runInteraction(
desc="get_user_ip_and_agents", func=get_recent
)
results.update(
((access_token, ip), (user_agent, last_seen))
for access_token, ip, user_agent, last_seen in rows
)
return [
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]
return list(results.values())
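To make the overlay concrete: an entry still sitting in `_batch_row_update` wins over the database row for the same (access_token, ip) key simply because it is assigned into `results` afterwards. A toy illustration with invented values:

results = {
    ("<token>", "10.0.0.1"): {"access_token": "<token>", "ip": "10.0.0.1",
                              "user_agent": "old UA", "last_seen": 1000},
}
# A fresher, not-yet-flushed entry for the same key replaces the stale row:
results[("<token>", "10.0.0.1")] = {"access_token": "<token>", "ip": "10.0.0.1",
                                    "user_agent": "new UA", "last_seen": 2000}
assert [r["last_seen"] for r in results.values()] == [2000]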

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@ -26,11 +26,14 @@ from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(

View file

@ -15,7 +15,17 @@
# limitations under the License.
import abc
import logging
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies

View file

@ -14,7 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Gauge
@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
oldest_pdu_in_federation_staging = Gauge(
"synapse_federation_server_oldest_inbound_pdu_in_staging",
"The age in seconds since we received the oldest pdu in the federation staging area",
@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@ -906,7 +909,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_latest_event_ids_in_room",
)
async def get_min_depth(self, room_id: str) -> int:
async def get_min_depth(self, room_id: str) -> Optional[int]:
"""For the given room, get the minimum depth we have seen for it."""
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(

View file

@ -1710,6 +1710,7 @@ class PersistEventsStore:
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
RelationTypes.THREAD,
):
# Unknown relation type
return
@ -1740,6 +1741,9 @@ class PersistEventsStore:
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
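For context, a hypothetical event content that now takes this branch; only the relation type and parent event id matter, and the rel_type is written symbolically since its wire value is defined by RelationTypes.THREAD:

from synapse.api.constants import RelationTypes

threaded_reply_content = {
    "msgtype": "m.text",
    "body": "a reply inside the thread",
    "m.relates_to": {
        "rel_type": RelationTypes.THREAD,  # matched by the check above
        "event_id": "$thread_root",        # becomes parent_id / the cache key
    },
}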
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
@ -2069,12 +2073,14 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group
self.db_pool.simple_insert_many_txn(
self.db_pool.simple_upsert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in state_groups.items()
key_names=["event_id"],
key_values=[[event_id] for event_id, _ in state_groups.items()],
value_names=["state_group"],
value_values=[
[state_group_id] for _, state_group_id in state_groups.items()
],
)
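Switching from an insert to `simple_upsert_many_txn` makes re-persisting an event idempotent on its `event_id` key. On an engine with native upserts, the generated per-row statement is expected to look roughly like the SQL below (an assumption about the helper's output, not copied from it):

UPSERT_EVENT_TO_STATE_GROUPS_SQL = """
    INSERT INTO event_to_state_groups (event_id, state_group)
    VALUES (?, ?)
    ON CONFLICT (event_id) DO UPDATE SET state_group = EXCLUDED.state_group
"""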

View file

@ -13,19 +13,26 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import EventContentFields
from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -76,7 +83,7 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
@ -164,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
self.db_pool.updates.register_background_update_handler(
"event_thread_relation", self._event_thread_relation
)
################################################################################
# bg updates for replacing stream_ordering with a BIGINT
@ -1088,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store thread relations for existing events."""
last_event_id = progress.get("last_event_id", "")
def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT event_id, json FROM event_json
LEFT JOIN event_relations USING (event_id)
WHERE event_id > ? AND event_relations.event_id IS NULL
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
missing_thread_relations = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
except Exception as e:
logger.warning(
"Unable to load event %s (no relations will be updated): %s",
event_id,
e,
)
continue
# If there's no relation (or it is not a thread), skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
if relates_to.get("rel_type") != RelationTypes.THREAD:
continue
# Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
missing_thread_relations.append((event_id, parent_id))
# Insert the missing data.
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_relations",
values=[
{
"event_id": event_id,
"relates_to_Id": parent_id,
"relation_type": RelationTypes.THREAD,
}
for event_id, parent_id in missing_thread_relations
],
)
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
txn, "event_thread_relation", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
desc="event_thread_relation", func=_event_thread_relation_txn
)
if not num_rows:
await self.db_pool.updates._end_background_update("event_thread_relation")
return num_rows
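A worked example of one pass of this update, with hypothetical IDs: for an old event whose content carries a thread relation (see the shape sketched earlier), the back-filled `event_relations` row is:

from synapse.api.constants import RelationTypes

backfilled_row = {
    "event_id": "$threaded_reply",    # the old event found by the scan
    "relates_to_id": "$thread_root",  # taken from m.relates_to["event_id"]
    "relation_type": RelationTypes.THREAD,
}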
async def _background_populate_stream_ordering2(
self, progress: JsonDict, batch_size: int
) -> int:

View file

@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
@ -86,6 +87,47 @@ class _EventCacheEntry:
redacted_event: Optional[EventBase]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRow:
"""
An event, as pulled from the database.
Properties:
event_id: The event ID of the event.
stream_ordering: stream ordering for this event
json: json-encoded event structure
internal_metadata: json-encoded internal metadata dict
format_version: The format of the event. Hopefully one of EventFormatVersions.
'None' means the event predates EventFormatVersions (so the event is format V1).
room_version_id: The version of the room which contains the event. Hopefully
one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
rejected_reason: if the event was rejected, the reason why.
redactions: a list of event-ids which (claim to) redact this event.
outlier: True if this event is an outlier.
"""
event_id: str
stream_ordering: int
json: str
internal_metadata: str
format_version: Optional[int]
room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
class EventRedactBehaviour(Names):
"""
What to do when retrieving a redacted event from the database.
@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
def _do_fetch(self, conn):
def _do_fetch(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
self._fetch_event_list(conn, event_list)
def _fetch_event_list(self, conn, event_list):
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
Args:
conn (twisted.enterprise.adbapi.Connection): database connection
conn: database connection
event_list (list[Tuple[list[str], Deferred]]):
event_list:
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
redaction_ids.update(row["redactions"])
redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
for event_id, row in fetched_events.items():
if not row:
continue
assert row["event_id"] == event_id
assert row.event_id == event_id
rejected_reason = row["rejected_reason"]
rejected_reason = row.rejected_reason
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
d = db_to_json(row["json"])
d = db_to_json(row.json)
except ValueError:
logger.error("Unable to parse json from event: %s", event_id)
continue
try:
internal_metadata = db_to_json(row["internal_metadata"])
internal_metadata = db_to_json(row.internal_metadata)
except ValueError:
logger.error(
"Unable to parse internal_metadata from event: %s", event_id
)
continue
format_version = row["format_version"]
format_version = row.format_version
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
room_version_id = row["room_version_id"]
room_version_id = row.room_version_id
if not room_version_id:
# this should only happen for out-of-band membership events which
@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
original_ev.internal_metadata.outlier = row["outlier"]
original_ev.internal_metadata.stream_ordering = row.stream_ordering
original_ev.internal_metadata.outlier = row.outlier
event_map[event_id] = original_ev
@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id]["redactions"]
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
async def _enqueue_events(self, events):
async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Args:
events (Iterable[str]): events to be fetched.
events: events to be fetched.
Returns:
Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
A map from event id to row data from the database. May contain events
that weren't requested.
"""
events_d = defer.Deferred()
@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
return row_map
def _fetch_event_rows(self, txn, event_ids):
def _fetch_event_rows(
self, txn: LoggingTransaction, event_ids: Iterable[str]
) -> Dict[str, _EventRow]:
"""Fetch event rows from the database
Events which are not found are omitted from the result.
The returned per-event dicts contain the following keys:
* event_id (str)
* stream_ordering (int): stream ordering for this event
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
* format_version (int|None): The format of the event. Hopefully one
of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1).
* room_version_id (str|None): The version of the room which contains the event.
Hopefully one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
* rejected_reason (str|None): if the event was rejected, the reason
why.
* redactions (List[str]): a list of event-ids which (claim to) redact
this event.
Args:
txn (twisted.enterprise.adbapi.Connection):
event_ids (Iterable[str]): event IDs to fetch
txn: The database transaction.
event_ids: event IDs to fetch
Returns:
Dict[str, Dict]: a map from event id to event info.
A map from event id to event info.
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
for row in txn:
event_id = row[0]
event_dict[event_id] = {
"event_id": event_id,
"stream_ordering": row[1],
"internal_metadata": row[2],
"json": row[3],
"format_version": row[4],
"room_version_id": row[5],
"rejected_reason": row[6],
"redactions": [],
"outlier": row[7],
}
event_dict[event_id] = _EventRow(
event_id=event_id,
stream_ordering=row[1],
internal_metadata=row[2],
json=row[3],
format_version=row[4],
room_version_id=row[5],
rejected_reason=row[6],
redactions=[],
outlier=row[7],
)
# check for redactions
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
if d:
d["redactions"].append(redacter)
d.redactions.append(redacter)
return event_dict

View file

@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
if TYPE_CHECKING:
from synapse.server import HomeServer
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method"
)
@ -43,7 +46,7 @@ class MediaSortOrder(Enum):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname

View file

@ -14,7 +14,7 @@
import calendar
import logging
import time
from typing import Dict
from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# Collect metrics on the number of forward extremities that exist.
@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes

View file

@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
@ -354,27 +357,3 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
await self.upsert_monthly_active_user(user_id)
async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
"""
Removes a deactivated user from the monthly active user
table and resets affected caches.
Args:
user_id(str): the user_id to remove
"""
rows_deleted = await self.db_pool.simple_delete(
table="monthly_active_users",
keyvalues={"user_id": user_id},
desc="simple_delete",
)
if rows_deleted != 0:
await self.invalidate_cache_and_stream(
"user_last_seen_monthly_active", (user_id,)
)
await self.invalidate_cache_and_stream("get_monthly_active_count", ())
await self.invalidate_cache_and_stream(
"get_monthly_active_count_by_service", ()
)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
from typing import Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@ -33,6 +33,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -75,7 +78,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
@ -29,11 +29,14 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):

View file

@ -23,7 +23,11 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
class ExternalIDReuseException(Exception):
"""Exception if writing an external id for a user fails,
because this external id is already mapped to another user."""
pass
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
"""Result of looking up an access token.
@ -488,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
"""Sets the user type.
Args:
user: user ID of the user.
user_type: type of the user or None for a user without a type.
"""
def set_user_type_txn(txn):
self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user.to_string(),)
)
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@ -588,24 +617,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
Raises:
ExternalIDReuseException if the new external_id could not be mapped.
"""
await self.db_pool.simple_insert(
try:
await self.db_pool.runInteraction(
"record_user_external_id",
self._record_user_external_id_txn,
auth_provider,
external_id,
user_id,
)
except self.database_engine.module.IntegrityError:
raise ExternalIDReuseException()
def _record_user_external_id_txn(
self,
txn: LoggingTransaction,
auth_provider: str,
external_id: str,
user_id: str,
) -> None:
self.db_pool.simple_insert_txn(
txn,
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)
async def remove_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
"""Remove a mapping from an external user id to a mxid
If the mapping is not found, this method does nothing.
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
@ -621,6 +670,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="remove_user_external_id",
)
async def replace_user_external_id(
self,
record_external_ids: List[Tuple[str, str]],
user_id: str,
) -> None:
"""Replace mappings from external user ids to a mxid in a single transaction.
All mappings are deleted and the new ones are created.
Args:
record_external_ids:
List with tuple of auth_provider and external_id to record
user_id: complete mxid that it is mapped to
Raises:
ExternalIDReuseException if the new external_id could not be mapped.
"""
def _remove_user_external_ids_txn(
txn: LoggingTransaction,
user_id: str,
) -> None:
"""Remove all mappings from external user ids to a mxid
If these mappings are not found, this method does nothing.
Args:
user_id: complete mxid that it is mapped to
"""
self.db_pool.simple_delete_txn(
txn,
table="user_external_ids",
keyvalues={"user_id": user_id},
)
def _replace_user_external_id_txn(
txn: LoggingTransaction,
):
_remove_user_external_ids_txn(txn, user_id)
for auth_provider, external_id in record_external_ids:
self._record_user_external_id_txn(
txn,
auth_provider,
external_id,
user_id,
)
try:
await self.db_pool.runInteraction(
"replace_user_external_id",
_replace_user_external_id_txn,
)
except self.database_engine.module.IntegrityError:
raise ExternalIDReuseException()
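A hedged caller-side sketch of the new error contract: both `record_user_external_id` and `replace_user_external_id` now surface the unique-constraint violation as ExternalIDReuseException rather than letting the raw IntegrityError escape (the helper below is illustrative):

from synapse.storage.databases.main.registration import ExternalIDReuseException

async def try_link_sso_identity(
    store, auth_provider: str, external_id: str, user_id: str
) -> bool:
    try:
        await store.record_user_external_id(auth_provider, external_id, user_id)
    except ExternalIDReuseException:
        # The (auth_provider, external_id) pair is already bound to another mxid.
        return False
    return True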
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@ -2237,7 +2340,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# accident.
row = {"client_secret": None, "validated_at": None}
else:
raise ThreepidValidationError(400, "Unknown session_id")
raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
@ -2252,14 +2355,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if not row:
raise ThreepidValidationError(
400, "Validation token not found or has expired"
"Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id"
"This client_secret does not match the provided session_id"
)
# If the session is already validated, no need to revalidate
@ -2268,7 +2371,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if expires <= current_ts:
raise ThreepidValidationError(
400, "This token has expired. Please request a new one"
"This token has expired. Please request a new one"
)
# Looks good. Validate the session
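These call sites drop the explicit 400 because the status code presumably now lives inside the exception itself. A sketch of what the updated exception is assumed to look like, not the verbatim synapse.api.errors definition:

from synapse.api.errors import Codes, SynapseError

class ThreepidValidationError(SynapseError):
    """A 3PID validation session failed; always maps to an HTTP 400."""

    def __init__(self, msg: str, errcode: str = Codes.UNKNOWN):
        super().__init__(400, msg, errcode)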

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import Optional
from typing import Optional, Tuple
import attr
@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.get_event(edit_id, allow_none=True)
@cached()
async def get_thread_summary(
self, event_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
Args:
event_id: The original event ID
Returns:
The number of items in the thread and the most recent response, if any.
"""
def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events?
sql = """
SELECT event_id
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
latest_event_id = row[0]
sql = """
SELECT COALESCE(COUNT(event_id), 0)
FROM event_relations
WHERE
relates_to_id = ?
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
count = txn.fetchone()[0]
return count, latest_event_id
count, latest_event_id = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)
latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True)
return count, latest_event
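A small usage sketch of the new cached summary; the caller and field names are illustrative, but the count/latest pair is exactly what a bundled thread summary needs:

async def thread_summary_dict(store, root_event_id: str) -> dict:
    count, latest = await store.get_thread_summary(root_event_id)
    return {
        "count": count,
        "latest_event_id": latest.event_id if latest else None,
    }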
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:

View file

@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@ -32,6 +32,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -69,7 +72,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
@ -679,8 +682,8 @@ class RoomWorkerStore(SQLBaseStore):
# policy.
if not ret:
return {
"min_lifetime": self.config.server.retention_default_min_lifetime,
"max_lifetime": self.config.server.retention_default_max_lifetime,
"min_lifetime": self.config.retention.retention_default_min_lifetime,
"max_lifetime": self.config.retention.retention_default_max_lifetime,
}
row = ret[0]
@ -690,10 +693,10 @@ class RoomWorkerStore(SQLBaseStore):
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
if row["min_lifetime"] is None:
row["min_lifetime"] = self.config.server.retention_default_min_lifetime
row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.server.retention_default_max_lifetime
row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
return row
@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config

View file

@ -36,3 +36,16 @@ class RoomBatchStore(SQLBaseStore):
retcol="event_id",
allow_none=True,
)
async def store_state_group_id_for_event_id(
self, event_id: str, state_group_id: int
) -> Optional[str]:
await self.db_pool.simple_upsert(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
values={"state_group": state_group_id, "event_id": event_id},
# Unique constraint on event_id so we don't have to lock
lock=False,
)

View file

@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__)
@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:

View file

@ -15,7 +15,7 @@
import logging
import re
from collections import namedtuple
from typing import Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):

View file

@ -15,7 +15,7 @@
import collections.abc
import logging
from collections import namedtuple
from typing import Iterable, Optional, Set
from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -30,6 +30,9 @@ from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -53,7 +56,7 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

View file

@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import Counter
@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# these fields track absolutes (e.g. total number of rooms on the server)
@ -93,7 +96,7 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname

View file

@ -14,7 +14,7 @@
import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
db_binary_type = memoryview
logger = logging.getLogger(__name__)
@ -57,7 +60,7 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:

View file

@ -18,6 +18,7 @@ import itertools
import logging
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@ -56,6 +57,9 @@ from synapse.types import (
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# The number of times we are recalculating the current state
@ -272,7 +276,7 @@ class EventsPersistenceStorage:
current state and forward extremity changes.
"""
def __init__(self, hs, stores: Databases):
def __init__(self, hs: "HomeServer", stores: Databases):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.

View file

@ -549,6 +549,8 @@ def _apply_module_schemas(
database_engine:
config: application config
"""
# This is the old way for password_auth_provider modules to make changes
# to the database. This should instead be done using the module API
for (mod, _config) in config.authproviders.password_providers:
if not hasattr(mod, "get_db_schema_files"):
continue
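For contrast with the legacy hook, a hedged sketch of the module-API route mentioned above: a module can manage its own tables through ModuleApi.run_db_interaction instead of exposing get_db_schema_files (the module class and table are invented for illustration):

class MyPasswordAuthProvider:
    def __init__(self, config, api):
        self._api = api

    async def create_tables(self) -> None:
        def create_txn(txn):
            txn.execute(
                "CREATE TABLE IF NOT EXISTS my_provider_data"
                " (user_id TEXT PRIMARY KEY, data TEXT)"
            )

        await self._api.run_db_interaction("my_provider_create_tables", create_txn)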

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
SCHEMA_VERSION = 64 # remember to update the list below when updating
SCHEMA_VERSION = 65 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@ -41,6 +41,10 @@ Changes in SCHEMA_VERSION = 63:
Changes in SCHEMA_VERSION = 64:
- MSC2716: Rename related tables and columns from "chunks" to "batches".
Changes in SCHEMA_VERSION = 65:
- MSC2716: Remove unique event_id constraint from insertion_event_edges
because an insertion event can have multiple edges.
"""

View file

@ -0,0 +1,19 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Recreate the insertion_event_edges event_id index without the unique constraint
-- because an insertion event can have multiple edges.
DROP INDEX insertion_event_edges_event_id;
CREATE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);

View file

@ -0,0 +1,18 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(6502, 'event_thread_relation', '{}');