Merge remote-tracking branch 'upstream/release-v1.46'

Commit cf45cfd314: 172 changed files with 5549 additions and 2350 deletions.
Mirror of https://git.anonymousland.org/anonymousland/synapse.git (synced 2025-08-17 11:50:15 -04:00)
@@ -19,6 +19,7 @@ from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Collection,

@@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor

if TYPE_CHECKING:
    from synapse.server import HomeServer

# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1


@@ -392,7 +396,7 @@ class DatabasePool:

    def __init__(
        self,
        hs,
        hs: "HomeServer",
        database_config: DatabaseConnectionConfig,
        engine: BaseDatabaseEngine,
    ):
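Most hunks in this merge apply the same typing pattern: import HomeServer only under TYPE_CHECKING and annotate the constructor argument with a string forward reference. A minimal sketch of the pattern (illustrative only; ExampleStore is a made-up name, not code from the commit):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from synapse.server import HomeServer  # seen by type checkers, never imported at runtime


class ExampleStore:
    def __init__(self, hs: "HomeServer") -> None:
        # The quoted annotation avoids an import cycle while still giving mypy
        # the concrete type to check against.
        self.hs = hs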
@@ -13,33 +13,49 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class Databases:
DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)


class Databases(Generic[DataStoreT]):
    """The various databases.

    These are low level interfaces to physical databases.

    Attributes:
        main (DataStore)
        databases
        main
        state
        persist_events
    """

    def __init__(self, main_store_class, hs):
    databases: List[DatabasePool]
    main: DataStoreT
    state: StateGroupDataStore
    persist_events: Optional[PersistEventsStore]

    def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
        # Note we pass in the main store class here as workers use a different main
        # store.

        self.databases = []
        main = None
        state = None
        persist_events = None
        main: Optional[DataStoreT] = None
        state: Optional[StateGroupDataStore] = None
        persist_events: Optional[PersistEventsStore] = None

        for database_config in hs.config.database.databases:
            db_name = database_config.name
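The Databases change above makes the container generic over its main store, so workers, which pass a different main-store class, keep precise types. A standalone sketch of that shape (BaseStore, StoreT and Container are illustrative names, not Synapse code):

from typing import Generic, Type, TypeVar


class BaseStore:
    """Stand-in for SQLBaseStore."""


# Bound and covariant, mirroring DataStoreT in the hunk above.
StoreT = TypeVar("StoreT", bound=BaseStore, covariant=True)


class Container(Generic[StoreT]):
    main: StoreT

    def __init__(self, main_store_class: Type[StoreT]) -> None:
        # The concrete store class is only known at the call site; Generic
        # preserves it, so Container[WorkerStore].main is typed as WorkerStore.
        self.main = main_store_class()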
@@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool

@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -126,7 +129,7 @@ class DataStore(
    LockStore,
    SessionStore,
):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine
@@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker

@@ -28,6 +28,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
    `get_max_account_data_stream_id` which can be called in the initializer.
    """

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        self._instance_name = hs.get_instance_name()

        if isinstance(database.engine, PostgresEngine):
@@ -15,7 +15,7 @@

import itertools
import logging
from typing import Any, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream

@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"


class CacheInvalidationWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()
@@ -13,14 +13,26 @@
# limitations under the License.

import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast

from typing_extensions import TypedDict

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.types import UserID
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Number of msec of granularity to store the user IP 'last seen' time. Smaller

@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000


class DeviceLastConnectionInfo(TypedDict):
    """Metadata for the last connection seen for a user and device combination"""

    # These types must match the columns in the `devices` table
    user_id: str
    device_id: str

    ip: Optional[str]
    user_agent: Optional[str]
    last_seen: Optional[int]


class LastConnectionInfo(TypedDict):
    """Metadata for the last connection seen for an access token and IP combination"""

    # These types must match the columns in the `user_ips` table
    access_token: str
    ip: str

    user_agent: str
    last_seen: int


class ClientIpBackgroundUpdateStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
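The two TypedDicts above give dict-shaped query results a checked schema. A small self-contained sketch of how such a type is declared and consumed (illustrative only, not code from the commit):

from typing import Optional

from typing_extensions import TypedDict


class DeviceLastConnectionInfo(TypedDict):
    # Keys mirror the columns the store selects from the `devices` table.
    user_id: str
    device_id: str
    ip: Optional[str]
    user_agent: Optional[str]
    last_seen: Optional[int]


def describe(info: DeviceLastConnectionInfo) -> str:
    # mypy verifies keys and value types; at runtime this is a plain dict.
    return f"{info['user_id']}/{info['device_id']} last seen {info['last_seen']}"


example: DeviceLastConnectionInfo = {
    "user_id": "@alice:example.org",
    "device_id": "ABCDEF",
    "ip": "10.0.0.1",
    "user_agent": "client/1.0",
    "last_seen": 1_600_000_000_000,
}
print(describe(example))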
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
            "devices_last_seen", self._devices_last_seen_update
        )

    async def _remove_user_ip_nonunique(self, progress, batch_size):
        def f(conn):
    async def _remove_user_ip_nonunique(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        def f(conn: LoggingDatabaseConnection) -> None:
            txn = conn.cursor()
            txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
            txn.close()

@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
        )
        return 1

    async def _analyze_user_ip(self, progress, batch_size):
    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
        # Background update to analyze user_ips table before we run the
        # deduplication background update. The table may not have been analyzed
        # for ages due to the table locks.
        #
        # This will lock out the naive upserts to user_ips while it happens, but
        # the analyze should be quick (28GB table takes ~10s)
        def user_ips_analyze(txn):
        def user_ips_analyze(txn: LoggingTransaction) -> None:
            txn.execute("ANALYZE user_ips")

        await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)

@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):

        return 1

    async def _remove_user_ip_dupes(self, progress, batch_size):
    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
        # This works function works by scanning the user_ips table in batches
        # based on `last_seen`. For each row in a batch it searches the rest of
        # the table to see if there are any duplicates, if there are then they
        # are removed and replaced with a suitable row.

        # Fetch the start of the batch
        begin_last_seen = progress.get("last_seen", 0)
        begin_last_seen: int = progress.get("last_seen", 0)

        def get_last_seen(txn):
        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
            txn.execute(
                """
                SELECT last_seen FROM user_ips

@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                """,
                (begin_last_seen, batch_size),
            )
            row = txn.fetchone()
            row = cast(Optional[Tuple[int]], txn.fetchone())
            if row:
                return row[0]
            else:
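Several hunks here wrap txn.fetchone()/txn.fetchall() in typing.cast because the DB-API cursor is typed as returning Any. A minimal sketch of the idea against sqlite3 (illustrative, not from the commit; the query and table are stand-ins):

import sqlite3
from typing import Optional, Tuple, cast


def get_max_last_seen(conn: sqlite3.Connection) -> Optional[int]:
    txn = conn.cursor()
    txn.execute("SELECT MAX(last_seen) FROM user_ips")
    # fetchone() is typed as Any; cast() tells the type checker the row shape
    # without changing runtime behaviour.
    row = cast(Optional[Tuple[Optional[int]]], txn.fetchone())
    txn.close()
    return row[0] if row else None


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen INTEGER)")
conn.execute("INSERT INTO user_ips VALUES ('@a:x', 123)")
print(get_max_last_seen(conn))  # 123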
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                end_last_seen,
            )

        def remove(txn):
        def remove(txn: LoggingTransaction) -> None:
            # This works by looking at all entries in the given time span, and
            # then for each (user_id, access_token, ip) tuple in that range
            # checking for any duplicates in the rest of the table (via a join).

@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):

            # Define the search space, which requires handling the last batch in
            # a different way
            args: Tuple[int, ...]
            if last:
                clause = "? <= last_seen"
                args = (begin_last_seen,)
            else:
                assert end_last_seen is not None
                clause = "? <= last_seen AND last_seen < ?"
                args = (begin_last_seen, end_last_seen)

@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                ),
                args,
            )
            res = txn.fetchall()
            res = cast(
                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
            )

            # We've got some duplicates
            for i in res:

@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):

        return batch_size

    async def _devices_last_seen_update(self, progress, batch_size):
    async def _devices_last_seen_update(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Background update to insert last seen info into devices table"""

        last_user_id = progress.get("last_user_id", "")
        last_device_id = progress.get("last_device_id", "")
        last_user_id: str = progress.get("last_user_id", "")
        last_device_id: str = progress.get("last_device_id", "")

        def _devices_last_seen_update_txn(txn):
        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
            # This consists of two queries:
            #
            # 1. The sub-query searches for the next N devices and joins

@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
            # we'll just end up updating the same device row multiple
            # times, which is fine.

            where_args: List[Union[str, int]]
            where_clause, where_args = make_tuple_comparison_clause(
                [("user_id", last_user_id), ("device_id", last_device_id)],
            )

@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
            }
            txn.execute(sql, where_args + [batch_size])

            rows = txn.fetchall()
            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
            if not rows:
                return 0


@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):


class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.user_ips_max_age = hs.config.server.user_ips_max_age

@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)

    @wrap_as_background_process("prune_old_user_ips")
    async def _prune_old_user_ips(self):
    async def _prune_old_user_ips(self) -> None:
        """Removes entries in user IPs older than the configured period."""

        if self.user_ips_max_age is None:

@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
            )
        """

        timestamp = self.clock.time_msec() - self.user_ips_max_age
        timestamp = self._clock.time_msec() - self.user_ips_max_age

        def _prune_old_user_ips_txn(txn):
        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
            txn.execute(sql, (timestamp,))

        await self.db_pool.runInteraction(

@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):

    async def get_last_client_ip_by_device(
        self, user_id: str, device_id: Optional[str]
    ) -> Dict[Tuple[str, str], dict]:
    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
        """For each device_id listed, give the user_ip it was last seen on.

        The result might be slightly out of date as client IPs are inserted in batches.

@@ -423,26 +467,84 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
        if device_id is not None:
            keyvalues["device_id"] = device_id

        res = await self.db_pool.simple_select_list(
            table="devices",
            keyvalues=keyvalues,
            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
        res = cast(
            List[DeviceLastConnectionInfo],
            await self.db_pool.simple_select_list(
                table="devices",
                keyvalues=keyvalues,
                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
            ),
        )

        return {(d["user_id"], d["device_id"]): d for d in res}

    async def get_user_ip_and_agents(
        self, user: UserID, since_ts: int = 0
    ) -> List[LastConnectionInfo]:
        """Fetch the IPs and user agents for a user since the given timestamp.

class ClientIpStore(ClientIpWorkerStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        The result might be slightly out of date as client IPs are inserted in batches.

        self.client_ip_last_seen = LruCache(
        Args:
            user: The user for which to fetch IP addresses and user agents.
            since_ts: The timestamp after which to fetch IP addresses and user agents,
                in milliseconds.

        Returns:
            A list of dictionaries, each containing:
             * `access_token`: The access token used.
             * `ip`: The IP address used.
             * `user_agent`: The last user agent seen for this access token and IP
               address combination.
             * `last_seen`: The timestamp at which this access token and IP address
               combination was last seen, in milliseconds.

            Only the latest user agent for each access token and IP address combination
            is available.
        """
        user_id = user.to_string()

        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
            txn.execute(
                """
                SELECT access_token, ip, user_agent, last_seen FROM user_ips
                WHERE last_seen >= ? AND user_id = ?
                ORDER BY last_seen
                DESC
                """,
                (since_ts, user_id),
            )
            return cast(List[Tuple[str, str, str, int]], txn.fetchall())

        rows = await self.db_pool.runInteraction(
            desc="get_user_ip_and_agents", func=get_recent
        )

        return [
            {
                "access_token": access_token,
                "ip": ip,
                "user_agent": user_agent,
                "last_seen": last_seen,
            }
            for access_token, ip, user_agent, last_seen in rows
        ]


class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):

        # (user_id, access_token, ip,) -> last_seen
        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
            cache_name="client_ip_last_seen", max_size=50000
        )

        super().__init__(database, db_conn, hs)

        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
        self._batch_row_update = {}
        self._batch_row_update: Dict[
            Tuple[str, str, str], Tuple[str, Optional[str], int]
        ] = {}

        self._client_ip_looper = self._clock.looping_call(
            self._update_client_ips_batch, 5 * 1000
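The pruning and batching hunks above keep the existing shape: a looping call fires every few seconds and a wrapped coroutine does the work. A rough sketch of that wiring with plain Twisted (the 5-second interval mirrors looping_call(..., 5 * 1000) in milliseconds; the reactor setup and print are illustrative, not Synapse's Clock API):

from twisted.internet import reactor, task


def prune_old_user_ips() -> None:
    # Placeholder for the real DELETE ... WHERE last_seen < ? transaction.
    print("pruning user_ips older than the configured max age")


loop = task.LoopingCall(prune_old_user_ips)
loop.start(5.0)  # Twisted takes seconds here, the store passes milliseconds

reactor.callLater(12, reactor.stop)  # run a couple of iterations, then exit
reactor.run()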
@@ -452,8 +554,14 @@ class ClientIpStore(ClientIpWorkerStore):
        )

    async def insert_client_ip(
        self, user_id, access_token, ip, user_agent, device_id, now=None
    ):
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: Optional[str],
        now: Optional[int] = None,
    ) -> None:
        if not now:
            now = int(self._clock.time_msec())
        key = (user_id, access_token, ip)
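client_ip_last_seen above becomes an LruCache[Tuple[str, str, str], int], i.e. the key and value types are pinned at construction. A cut-down generic-cache sketch in the same spirit (a toy stand-in, not Synapse's LruCache API):

from collections import OrderedDict
from typing import Generic, Optional, Tuple, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class TinyLruCache(Generic[K, V]):
    def __init__(self, max_size: int) -> None:
        self._data: "OrderedDict[K, V]" = OrderedDict()
        self._max_size = max_size

    def get(self, key: K) -> Optional[V]:
        if key in self._data:
            self._data.move_to_end(key)
            return self._data[key]
        return None

    def set(self, key: K, value: V) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self._max_size:
            self._data.popitem(last=False)  # evict the least-recently used entry


# Keyed like the store's cache: (user_id, access_token, ip) -> last_seen ms.
cache: TinyLruCache[Tuple[str, str, str], int] = TinyLruCache(max_size=50000)
cache.set(("@alice:example.org", "token", "10.0.0.1"), 1_600_000_000_000)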
@@ -485,7 +593,11 @@ class ClientIpStore(ClientIpWorkerStore):
            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
        )

    def _update_client_ips_batch_txn(self, txn, to_update):
    def _update_client_ips_batch_txn(
        self,
        txn: LoggingTransaction,
        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
    ) -> None:
        if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
            not self.database_engine.can_native_upsert
        ):

@@ -525,7 +637,7 @@ class ClientIpStore(ClientIpWorkerStore):

    async def get_last_client_ip_by_device(
        self, user_id: str, device_id: Optional[str]
    ) -> Dict[Tuple[str, str], dict]:
    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
        """For each device_id listed, give the user_ip it was last seen on

        Args:

@@ -561,50 +673,44 @@ class ClientIpStore(ClientIpWorkerStore):

    async def get_user_ip_and_agents(
        self, user: UserID, since_ts: int = 0
    ) -> List[Dict[str, Union[str, int]]]:
        """
        Fetch IP/User Agent connection since a given timestamp.
        """
        user_id = user.to_string()
        results = {}
    ) -> List[LastConnectionInfo]:
        """Fetch the IPs and user agents for a user since the given timestamp.

        Args:
            user: The user for which to fetch IP addresses and user agents.
            since_ts: The timestamp after which to fetch IP addresses and user agents,
                in milliseconds.

        Returns:
            A list of dictionaries, each containing:
             * `access_token`: The access token used.
             * `ip`: The IP address used.
             * `user_agent`: The last user agent seen for this access token and IP
               address combination.
             * `last_seen`: The timestamp at which this access token and IP address
               combination was last seen, in milliseconds.

            Only the latest user agent for each access token and IP address combination
            is available.
        """
        results: Dict[Tuple[str, str], LastConnectionInfo] = {
            (connection["access_token"], connection["ip"]): connection
            for connection in await super().get_user_ip_and_agents(user, since_ts)
        }

        # Overlay data that is pending insertion on top of the results from the
        # database.
        user_id = user.to_string()
        for key in self._batch_row_update:
            (
                uid,
                access_token,
                ip,
            ) = key
            uid, access_token, ip = key
            if uid == user_id:
                user_agent, _, last_seen = self._batch_row_update[key]
                if last_seen >= since_ts:
                    results[(access_token, ip)] = (user_agent, last_seen)
                    results[(access_token, ip)] = {
                        "access_token": access_token,
                        "ip": ip,
                        "user_agent": user_agent,
                        "last_seen": last_seen,
                    }

        def get_recent(txn):
            txn.execute(
                """
                SELECT access_token, ip, user_agent, last_seen FROM user_ips
                WHERE last_seen >= ? AND user_id = ?
                ORDER BY last_seen
                DESC
                """,
                (since_ts, user_id),
            )
            return txn.fetchall()

        rows = await self.db_pool.runInteraction(
            desc="get_user_ip_and_agents", func=get_recent
        )

        results.update(
            ((access_token, ip), (user_agent, last_seen))
            for access_token, ip, user_agent, last_seen in rows
        )
        return [
            {
                "access_token": access_token,
                "ip": ip,
                "user_agent": user_agent,
                "last_seen": last_seen,
            }
            for (access_token, ip), (user_agent, last_seen) in results.items()
        ]
        return list(results.values())
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace

@@ -26,11 +26,14 @@ from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

@@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
@@ -15,7 +15,17 @@
# limitations under the License.
import abc
import logging
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
)

from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (

@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (

@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"


class DeviceWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.run_background_tasks:

@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):


class DeviceBackgroundUpdateStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(

@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):


class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        # Map of (user_id, device_id) -> bool. If there is an entry that implies
@@ -14,7 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple

from prometheus_client import Counter, Gauge

@@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
    from synapse.server import HomeServer

oldest_pdu_in_federation_staging = Gauge(
    "synapse_federation_server_oldest_inbound_pdu_in_staging",
    "The age in seconds since we received the oldest pdu in the federation staging area",

@@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):


class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.run_background_tasks:

@@ -906,7 +909,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
            desc="get_latest_event_ids_in_room",
        )

    async def get_min_depth(self, room_id: str) -> int:
    async def get_min_depth(self, room_id: str) -> Optional[int]:
        """For the given room, get the minimum depth we have seen for it."""
        return await self.db_pool.runInteraction(
            "get_min_depth", self._get_min_depth_interaction, room_id

@@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):

    EVENT_AUTH_STATE_ONLY = "event_auth_state_only"

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import attr

@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):


class EventPushActionsWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        # These get correctly set by _find_stream_orderings_for_times_txn

@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
@@ -1710,6 +1710,7 @@ class PersistEventsStore:
            RelationTypes.ANNOTATION,
            RelationTypes.REFERENCE,
            RelationTypes.REPLACE,
            RelationTypes.THREAD,
        ):
            # Unknown relation type
            return

@@ -1740,6 +1741,9 @@ class PersistEventsStore:
        if rel_type == RelationTypes.REPLACE:
            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))

        if rel_type == RelationTypes.THREAD:
            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))

    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
        """Handles keeping track of insertion events and edges/connections.
        Part of MSC2716.

@@ -2069,12 +2073,14 @@ class PersistEventsStore:

            state_groups[event.event_id] = context.state_group

        self.db_pool.simple_insert_many_txn(
        self.db_pool.simple_upsert_many_txn(
            txn,
            table="event_to_state_groups",
            values=[
                {"state_group": state_group_id, "event_id": event_id}
                for event_id, state_group_id in state_groups.items()
            key_names=["event_id"],
            key_values=[[event_id] for event_id, _ in state_groups.items()],
            value_names=["state_group"],
            value_values=[
                [state_group_id] for _, state_group_id in state_groups.items()
            ],
        )
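The last hunk swaps simple_insert_many_txn for simple_upsert_many_txn so that re-persisting an event updates its state group instead of hitting a unique constraint. A plain-SQL sketch of the same idea (SQLite syntax; the table and columns mirror the hunk, everything else is illustrative):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_to_state_groups (event_id TEXT PRIMARY KEY, state_group INTEGER)"
)

upsert_sql = """
    INSERT INTO event_to_state_groups (event_id, state_group) VALUES (?, ?)
    ON CONFLICT (event_id) DO UPDATE SET state_group = excluded.state_group
"""

# First persist: plain inserts.
conn.executemany(upsert_sql, [("$event1", 10), ("$event2", 11)])
# Second persist of the same event: the existing row is updated, not duplicated.
conn.executemany(upsert_sql, [("$event1", 12)])

print(conn.execute("SELECT * FROM event_to_state_groups ORDER BY event_id").fetchall())
# [('$event1', 12), ('$event2', 11)]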
@@ -13,19 +13,26 @@
# limitations under the License.

import logging
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import attr

from synapse.api.constants import EventContentFields
from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingTransaction,
    make_tuple_comparison_clause,
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -76,7 +83,7 @@ class _CalculateChainCover:


class EventsBackgroundUpdatesStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(

@@ -164,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
            self._purged_chain_cover_index,
        )

        self.db_pool.updates.register_background_update_handler(
            "event_thread_relation", self._event_thread_relation
        )

        ################################################################################

        # bg updates for replacing stream_ordering with a BIGINT

@@ -1088,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):

        return result

    async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
        """Background update handler which will store thread relations for existing events."""
        last_event_id = progress.get("last_event_id", "")

        def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
            txn.execute(
                """
                SELECT event_id, json FROM event_json
                LEFT JOIN event_relations USING (event_id)
                WHERE event_id > ? AND event_relations.event_id IS NULL
                ORDER BY event_id LIMIT ?
                """,
                (last_event_id, batch_size),
            )

            results = list(txn)
            missing_thread_relations = []
            for (event_id, event_json_raw) in results:
                try:
                    event_json = db_to_json(event_json_raw)
                except Exception as e:
                    logger.warning(
                        "Unable to load event %s (no relations will be updated): %s",
                        event_id,
                        e,
                    )
                    continue

                # If there's no relation (or it is not a thread), skip!
                relates_to = event_json["content"].get("m.relates_to")
                if not relates_to or not isinstance(relates_to, dict):
                    continue
                if relates_to.get("rel_type") != RelationTypes.THREAD:
                    continue

                # Get the parent ID.
                parent_id = relates_to.get("event_id")
                if not isinstance(parent_id, str):
                    continue

                missing_thread_relations.append((event_id, parent_id))

            # Insert the missing data.
            self.db_pool.simple_insert_many_txn(
                txn=txn,
                table="event_relations",
                values=[
                    {
                        "event_id": event_id,
                        "relates_to_Id": parent_id,
                        "relation_type": RelationTypes.THREAD,
                    }
                    for event_id, parent_id in missing_thread_relations
                ],
            )

            if results:
                latest_event_id = results[-1][0]
                self.db_pool.updates._background_update_progress_txn(
                    txn, "event_thread_relation", {"last_event_id": latest_event_id}
                )

            return len(results)

        num_rows = await self.db_pool.runInteraction(
            desc="event_thread_relation", func=_event_thread_relation_txn
        )

        if not num_rows:
            await self.db_pool.updates._end_background_update("event_thread_relation")

        return num_rows

    async def _background_populate_stream_ordering2(
        self, progress: JsonDict, batch_size: int
    ) -> int:
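_event_thread_relation follows Synapse's usual background-update shape: register a named handler, process a batch keyed off stored progress, persist the new progress, and finish when a batch comes back empty. A condensed sketch of that control flow (standalone, with an in-memory progress dict and fake event list rather than the real BackgroundUpdater API):

from typing import Dict, List

progress: Dict[str, str] = {}
EVENTS: List[str] = [f"$event{i}" for i in range(7)]  # pretend table, ordered by id
BATCH_SIZE = 3


def run_batch() -> int:
    last_event_id = progress.get("last_event_id", "")
    batch = [e for e in EVENTS if e > last_event_id][:BATCH_SIZE]
    for event_id in batch:
        pass  # real code would backfill event_relations rows here
    if batch:
        progress["last_event_id"] = batch[-1]  # persist progress for the next run
    return len(batch)


while True:
    if run_batch() == 0:
        # An empty batch means the update is complete; the real store calls
        # _end_background_update() at this point.
        break
print("done, progress =", progress)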
@@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id

@@ -86,6 +87,47 @@ class _EventCacheEntry:
    redacted_event: Optional[EventBase]


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRow:
    """
    An event, as pulled from the database.

    Properties:
        event_id: The event ID of the event.

        stream_ordering: stream ordering for this event

        json: json-encoded event structure

        internal_metadata: json-encoded internal metadata dict

        format_version: The format of the event. Hopefully one of EventFormatVersions.
            'None' means the event predates EventFormatVersions (so the event is format V1).

        room_version_id: The version of the room which contains the event. Hopefully
            one of RoomVersions.

            Due to historical reasons, there may be a few events in the database which
            do not have an associated room; in this case None will be returned here.

        rejected_reason: if the event was rejected, the reason why.

        redactions: a list of event-ids which (claim to) redact this event.

        outlier: True if this event is an outlier.
    """

    event_id: str
    stream_ordering: int
    json: str
    internal_metadata: str
    format_version: Optional[int]
    room_version_id: Optional[int]
    rejected_reason: Optional[str]
    redactions: List[str]
    outlier: bool


class EventRedactBehaviour(Names):
    """
    What to do when retrieving a redacted event from the database.

@@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
            for e in state_to_include.values()
        ]

    def _do_fetch(self, conn):
    def _do_fetch(self, conn: Connection) -> None:
        """Takes a database connection and waits for requests for events from
        the _event_fetch_list queue.
        """

@@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):

            self._fetch_event_list(conn, event_list)

    def _fetch_event_list(self, conn, event_list):
    def _fetch_event_list(
        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
    ) -> None:
        """Handle a load of requests from the _event_fetch_list queue

        Args:
            conn (twisted.enterprise.adbapi.Connection): database connection
            conn: database connection

            event_list (list[Tuple[list[str], Deferred]]):
            event_list:
                The fetch requests. Each entry consists of a list of event
                ids to be fetched, and a deferred to be completed once the
                events have been fetched.

@@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
            row = row_map.get(event_id)
            fetched_events[event_id] = row
            if row:
                redaction_ids.update(row["redactions"])
                redaction_ids.update(row.redactions)

        events_to_fetch = redaction_ids.difference(fetched_events.keys())
        if events_to_fetch:

@@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
        for event_id, row in fetched_events.items():
            if not row:
                continue
            assert row["event_id"] == event_id
            assert row.event_id == event_id

            rejected_reason = row["rejected_reason"]
            rejected_reason = row.rejected_reason

            # If the event or metadata cannot be parsed, log the error and act
            # as if the event is unknown.
            try:
                d = db_to_json(row["json"])
                d = db_to_json(row.json)
            except ValueError:
                logger.error("Unable to parse json from event: %s", event_id)
                continue
            try:
                internal_metadata = db_to_json(row["internal_metadata"])
                internal_metadata = db_to_json(row.internal_metadata)
            except ValueError:
                logger.error(
                    "Unable to parse internal_metadata from event: %s", event_id
                )
                continue

            format_version = row["format_version"]
            format_version = row.format_version
            if format_version is None:
                # This means that we stored the event before we had the concept
                # of a event format version, so it must be a V1 event.
                format_version = EventFormatVersions.V1

            room_version_id = row["room_version_id"]
            room_version_id = row.room_version_id

            if not room_version_id:
                # this should only happen for out-of-band membership events which

@@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
                internal_metadata_dict=internal_metadata,
                rejected_reason=rejected_reason,
            )
            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
            original_ev.internal_metadata.outlier = row["outlier"]
            original_ev.internal_metadata.stream_ordering = row.stream_ordering
            original_ev.internal_metadata.outlier = row.outlier

            event_map[event_id] = original_ev


@@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
        # the cache entries.
        result_map = {}
        for event_id, original_ev in event_map.items():
            redactions = fetched_events[event_id]["redactions"]
            redactions = fetched_events[event_id].redactions
            redacted_event = self._maybe_redact_event_row(
                original_ev, redactions, event_map
            )

@@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):

        return result_map

    async def _enqueue_events(self, events):
    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
        """Fetches events from the database using the _event_fetch_list. This
        allows batch and bulk fetching of events - it allows us to fetch events
        without having to create a new transaction for each request for events.

        Args:
            events (Iterable[str]): events to be fetched.
            events: events to be fetched.

        Returns:
            Dict[str, Dict]: map from event id to row data from the database.
            May contain events that weren't requested.
            A map from event id to row data from the database. May contain events
            that weren't requested.
        """

        events_d = defer.Deferred()

@@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):

        return row_map

    def _fetch_event_rows(self, txn, event_ids):
    def _fetch_event_rows(
        self, txn: LoggingTransaction, event_ids: Iterable[str]
    ) -> Dict[str, _EventRow]:
        """Fetch event rows from the database

        Events which are not found are omitted from the result.

        The returned per-event dicts contain the following keys:

          * event_id (str)

          * stream_ordering (int): stream ordering for this event

          * json (str): json-encoded event structure

          * internal_metadata (str): json-encoded internal metadata dict

          * format_version (int|None): The format of the event. Hopefully one
            of EventFormatVersions. 'None' means the event predates
            EventFormatVersions (so the event is format V1).

          * room_version_id (str|None): The version of the room which contains the event.
            Hopefully one of RoomVersions.

            Due to historical reasons, there may be a few events in the database which
            do not have an associated room; in this case None will be returned here.

          * rejected_reason (str|None): if the event was rejected, the reason
            why.

          * redactions (List[str]): a list of event-ids which (claim to) redact
            this event.

        Args:
            txn (twisted.enterprise.adbapi.Connection):
            event_ids (Iterable[str]): event IDs to fetch
            txn: The database transaction.
            event_ids: event IDs to fetch

        Returns:
            Dict[str, Dict]: a map from event id to event info.
            A map from event id to event info.
        """
        event_dict = {}
        for evs in batch_iter(event_ids, 200):

@@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):

            for row in txn:
                event_id = row[0]
                event_dict[event_id] = {
                    "event_id": event_id,
                    "stream_ordering": row[1],
                    "internal_metadata": row[2],
                    "json": row[3],
                    "format_version": row[4],
                    "room_version_id": row[5],
                    "rejected_reason": row[6],
                    "redactions": [],
                    "outlier": row[7],
                }
                event_dict[event_id] = _EventRow(
                    event_id=event_id,
                    stream_ordering=row[1],
                    internal_metadata=row[2],
                    json=row[3],
                    format_version=row[4],
                    room_version_id=row[5],
                    rejected_reason=row[6],
                    redactions=[],
                    outlier=row[7],
                )

            # check for redactions
            redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "

@@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
            for (redacter, redacted) in txn:
                d = event_dict.get(redacted)
                if d:
                    d["redactions"].append(redacter)
                    d.redactions.append(redacter)

        return event_dict
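_EventRow above replaces a loose per-row dict with a frozen attrs class, so the fetch path gets attribute access and type checking. A cut-down sketch of the same pattern (three fields only, purely illustrative):

from typing import List, Optional

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventRowSketch:
    """A database row with a fixed, typed shape instead of a plain dict."""

    event_id: str
    rejected_reason: Optional[str]
    redactions: List[str]


row = EventRowSketch(event_id="$abc", rejected_reason=None, redactions=[])
row.redactions.append("$redacting-event")  # list contents stay mutable, as in the diff
print(row.event_id, row.redactions)
# row.event_id = "$other"  # would raise attr.exceptions.FrozenInstanceError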
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool

if TYPE_CHECKING:
    from synapse.server import HomeServer

BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
    "media_repository_drop_index_wo_method"
)

@@ -43,7 +46,7 @@ class MediaSortOrder(Enum):


class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(

@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
    """Persistence for attachments and avatars"""

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)
        self.server_name = hs.hostname
@@ -14,7 +14,7 @@
import calendar
import logging
import time
from typing import Dict
from typing import TYPE_CHECKING, Dict

from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process

@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
    EventPushActionsWorkerStore,
)

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Collect metrics on the number of forward extremities that exist.

@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
    stats and prometheus metrics.
    """

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        # Read the extrems every 60 minutes
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Number of msec of granularity to store the monthly_active_user timestamp

@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000


class MonthlyActiveUsersWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)
        self._clock = hs.get_clock()
        self.hs = hs

@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):


class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self._mau_stats_only = hs.config.server.mau_stats_only

@@ -354,27 +357,3 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
            await self.upsert_monthly_active_user(user_id)
        elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
            await self.upsert_monthly_active_user(user_id)

    async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
        """
        Removes a deactivated user from the monthly active user
        table and resets affected caches.

        Args:
            user_id(str): the user_id to remove
        """

        rows_deleted = await self.db_pool.simple_delete(
            table="monthly_active_users",
            keyvalues={"user_id": user_id},
            desc="simple_delete",
        )

        if rows_deleted != 0:
            await self.invalidate_cache_and_stream(
                "user_last_seen_monthly_active", (user_id,)
            )
            await self.invalidate_cache_and_stream("get_monthly_active_count", ())
            await self.invalidate_cache_and_stream(
                "get_monthly_active_count_by_service", ()
            )
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
from typing import Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules

@@ -33,6 +33,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -75,7 +78,7 @@ class PushRulesWorkerStore(
    `get_max_push_rules_stream_id` which can be called in the initializer.
    """

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.worker_app is None:
@@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from twisted.internet import defer

@@ -29,11 +29,14 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        self._instance_name = hs.get_instance_name()

        if isinstance(database.engine, PostgresEngine):
@ -23,7 +23,11 @@ import attr
|
|||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.stats import StatsStore
|
||||
from synapse.storage.types import Cursor
|
||||
|
@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExternalIDReuseException(Exception):
|
||||
"""Exception if writing an external id for a user fails,
|
||||
because this external id is given to an other user."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class TokenLookupResult:
|
||||
"""Result of looking up an access token.
|
||||
|
@ -488,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
|
||||
|
||||
async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
|
||||
"""Sets the user type.
|
||||
|
||||
Args:
|
||||
user: user ID of the user.
|
||||
user_type: type of the user or None for a user without a type.
|
||||
"""
|
||||
|
||||
def set_user_type_txn(txn):
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user.to_string(),)
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
||||
sql = """
|
||||
SELECT users.name as user_id,
|
||||
|
@ -588,24 +617,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
            auth_provider: identifier for the remote auth provider
            external_id: id on that system
            user_id: complete mxid that it is mapped to
        Raises:
            ExternalIDReuseException if the new external_id could not be mapped.
        """
        await self.db_pool.simple_insert(

        try:
            await self.db_pool.runInteraction(
                "record_user_external_id",
                self._record_user_external_id_txn,
                auth_provider,
                external_id,
                user_id,
            )
        except self.database_engine.module.IntegrityError:
            raise ExternalIDReuseException()

    def _record_user_external_id_txn(
        self,
        txn: LoggingTransaction,
        auth_provider: str,
        external_id: str,
        user_id: str,
    ) -> None:

        self.db_pool.simple_insert_txn(
            txn,
            table="user_external_ids",
            values={
                "auth_provider": auth_provider,
                "external_id": external_id,
                "user_id": user_id,
            },
            desc="record_user_external_id",
        )

    async def remove_user_external_id(
        self, auth_provider: str, external_id: str, user_id: str
    ) -> None:
        """Remove a mapping from an external user id to a mxid

        If the mapping is not found, this method does nothing.

        Args:
            auth_provider: identifier for the remote auth provider
            external_id: id on that system
@ -621,6 +670,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
            desc="remove_user_external_id",
        )

    async def replace_user_external_id(
        self,
        record_external_ids: List[Tuple[str, str]],
        user_id: str,
    ) -> None:
        """Replace mappings from external user ids to a mxid in a single transaction.
        All mappings are deleted and the new ones are created.

        Args:
            record_external_ids:
                List with tuple of auth_provider and external_id to record
            user_id: complete mxid that it is mapped to
        Raises:
            ExternalIDReuseException if the new external_id could not be mapped.
        """

        def _remove_user_external_ids_txn(
            txn: LoggingTransaction,
            user_id: str,
        ) -> None:
            """Remove all mappings from external user ids to a mxid
            If these mappings are not found, this method does nothing.

            Args:
                user_id: complete mxid that it is mapped to
            """

            self.db_pool.simple_delete_txn(
                txn,
                table="user_external_ids",
                keyvalues={"user_id": user_id},
            )

        def _replace_user_external_id_txn(
            txn: LoggingTransaction,
        ):
            _remove_user_external_ids_txn(txn, user_id)

            for auth_provider, external_id in record_external_ids:
                self._record_user_external_id_txn(
                    txn,
                    auth_provider,
                    external_id,
                    user_id,
                )

        try:
            await self.db_pool.runInteraction(
                "replace_user_external_id",
                _replace_user_external_id_txn,
            )
        except self.database_engine.module.IntegrityError:
            raise ExternalIDReuseException()

    async def get_user_by_external_id(
        self, auth_provider: str, external_id: str
    ) -> Optional[str]:
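A sketch (not part of the commit) of how a caller might use the new transactional replace_user_external_id, turning the ExternalIDReuseException raised on an IntegrityError into a user-facing error; the helper name, the ids and the error message are assumptions:

from synapse.api.errors import Codes, SynapseError
from synapse.storage.databases.main.registration import ExternalIDReuseException

async def remap_sso_identities(store, user_id: str) -> None:
    # Delete every existing mapping for the user and write the new ones in a
    # single database transaction, so a failure cannot leave partial state.
    try:
        await store.replace_user_external_id(
            [("oidc", "alice@idp.example"), ("saml", "alice")],
            user_id,
        )
    except ExternalIDReuseException:
        # One of the external ids is already mapped to a different mxid.
        raise SynapseError(409, "External id is already in use", Codes.UNKNOWN)
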
@ -2237,7 +2340,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                # accident.
                row = {"client_secret": None, "validated_at": None}
            else:
                raise ThreepidValidationError(400, "Unknown session_id")
                raise ThreepidValidationError("Unknown session_id")

            retrieved_client_secret = row["client_secret"]
            validated_at = row["validated_at"]
@ -2252,14 +2355,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):

            if not row:
                raise ThreepidValidationError(
                    400, "Validation token not found or has expired"
                    "Validation token not found or has expired"
                )
            expires = row["expires"]
            next_link = row["next_link"]

            if retrieved_client_secret != client_secret:
                raise ThreepidValidationError(
                    400, "This client_secret does not match the provided session_id"
                    "This client_secret does not match the provided session_id"
                )

            # If the session is already validated, no need to revalidate
@ -2268,7 +2371,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):

            if expires <= current_ts:
                raise ThreepidValidationError(
                    400, "This token has expired. Please request a new one"
                    "This token has expired. Please request a new one"
                )

            # Looks good. Validate the session

@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Optional
from typing import Optional, Tuple

import attr
@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore):

        return await self.get_event(edit_id, allow_none=True)

    @cached()
    async def get_thread_summary(
        self, event_id: str
    ) -> Tuple[int, Optional[EventBase]]:
        """Get the number of threaded replies, the senders of those replies, and
        the latest reply (if any) for the given event.

        Args:
            event_id: The original event ID

        Returns:
            The number of items in the thread and the most recent response, if any.
        """

        def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
            # Fetch the count of threaded events and the latest event ID.
            # TODO Should this only allow m.room.message events.
            sql = """
                SELECT event_id
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND relation_type = ?
                ORDER BY topological_ordering DESC, stream_ordering DESC
                LIMIT 1
            """

            txn.execute(sql, (event_id, RelationTypes.THREAD))
            row = txn.fetchone()
            if row is None:
                return 0, None

            latest_event_id = row[0]

            sql = """
                SELECT COALESCE(COUNT(event_id), 0)
                FROM event_relations
                WHERE
                    relates_to_id = ?
                    AND relation_type = ?
            """
            txn.execute(sql, (event_id, RelationTypes.THREAD))
            count = txn.fetchone()[0]

            return count, latest_event_id

        count, latest_event_id = await self.db_pool.runInteraction(
            "get_thread_summary", _get_thread_summary_txn
        )

        latest_event = None
        if latest_event_id:
            latest_event = await self.get_event(latest_event_id, allow_none=True)

        return count, latest_event

    async def has_user_annotated_event(
        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
    ) -> bool:

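An illustrative caller for the new cached get_thread_summary (not part of the diff; the store variable and event id are assumptions):

async def describe_thread(store, thread_root_id: str) -> str:
    # Returns the reply count plus the latest threaded reply. Both values are
    # computed with two queries inside a single runInteraction and then cached
    # per thread-root event id by the @cached() decorator.
    count, latest_event = await store.get_thread_summary(thread_root_id)
    if latest_event is None:
        return "no replies yet"
    return f"{count} replies, latest from {latest_event.sender}"
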
@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@ -32,6 +32,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)
@ -69,7 +72,7 @@ class RoomSortOrder(Enum):


class RoomWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.config = hs.config
@ -679,8 +682,8 @@ class RoomWorkerStore(SQLBaseStore):
        # policy.
        if not ret:
            return {
                "min_lifetime": self.config.server.retention_default_min_lifetime,
                "max_lifetime": self.config.server.retention_default_max_lifetime,
                "min_lifetime": self.config.retention.retention_default_min_lifetime,
                "max_lifetime": self.config.retention.retention_default_max_lifetime,
            }

        row = ret[0]
@ -690,10 +693,10 @@ class RoomWorkerStore(SQLBaseStore):
        # The default values will be None if no default policy has been defined, or if one
        # of the attributes is missing from the default policy.
        if row["min_lifetime"] is None:
            row["min_lifetime"] = self.config.server.retention_default_min_lifetime
            row["min_lifetime"] = self.config.retention.retention_default_min_lifetime

        if row["max_lifetime"] is None:
            row["max_lifetime"] = self.config.server.retention_default_max_lifetime
            row["max_lifetime"] = self.config.retention.retention_default_max_lifetime

        return row
@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (


class RoomBackgroundUpdateStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.config = hs.config
@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):


class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.config = hs.config

@ -36,3 +36,16 @@ class RoomBatchStore(SQLBaseStore):
            retcol="event_id",
            allow_none=True,
        )

    async def store_state_group_id_for_event_id(
        self, event_id: str, state_group_id: int
    ) -> Optional[str]:
        await self.db_pool.simple_upsert(
            table="event_to_state_groups",
            keyvalues={"event_id": event_id},
            values={"state_group": state_group_id, "event_id": event_id},
            # Unique constraint on event_id so we don't have to lock
            lock=False,
        )

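For illustration (not part of the diff), the same upsert helper could be called directly as sketched below; the table and column names are the real ones from the hunk, the values and the wrapper function are invented. My reading is that lock=False skips the table lock the emulated upsert path would otherwise take, and the unique constraint on event_id is what makes that safe; the diff's inline comment says as much, but the helper's internals are not shown here.

async def record_state_group(db_pool, event_id: str, state_group_id: int) -> None:
    # Upsert rather than insert: re-processing the same historical event
    # overwrites the row instead of failing on the event_id constraint.
    await db_pool.simple_upsert(
        table="event_to_state_groups",
        keyvalues={"event_id": event_id},
        values={"state_group": state_group_id, "event_id": event_id},
        lock=False,
    )
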
@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure

if TYPE_CHECKING:
    from synapse.server import HomeServer
    from synapse.state import _StateCacheEntry

logger = logging.getLogger(__name__)
@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"


class RoomMemberWorkerStore(EventsWorkerStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):


class RoomMemberBackgroundUpdateStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)
        self.db_pool.updates.register_background_update_handler(
            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):


class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

    async def forget(self, user_id: str, room_id: str) -> None:

@ -15,7 +15,7 @@
import logging
import re
from collections import namedtuple
from typing import Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set

from synapse.api.errors import SynapseError
from synapse.events import EventBase
@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

SearchEntry = namedtuple(
@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
    EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        if not hs.config.server.enable_search:
@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):


class SearchStore(SearchBackgroundUpdateStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

    async def search_msgs(self, room_ids, search_term, keys):

@ -15,7 +15,7 @@
import collections.abc
import logging
from collections import namedtuple
from typing import Iterable, Optional, Set
from typing import TYPE_CHECKING, Iterable, Optional, Set

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -30,6 +30,9 @@ from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)
@ -53,7 +56,7 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
    """The parts of StateGroupStore that can be called from workers."""

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

    async def get_room_version(self, room_id: str) -> RoomVersion:
@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
    EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
    DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.server_name = hs.hostname
@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
    * `state_groups_state`: Maps state group to state events.
    """

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from typing_extensions import Counter
@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# these fields track absolutes (e.g. total number of rooms on the server)
@ -93,7 +96,7 @@ class UserSortOrder(Enum):


class StatsStore(StateDeltasStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        self.server_name = hs.hostname

@ -14,7 +14,7 @@

import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
    from synapse.server import HomeServer

db_binary_type = memoryview

logger = logging.getLogger(__name__)
@ -57,7 +60,7 @@ class DestinationRetryTimings:


class TransactionWorkerStore(CacheInvalidationWorkerStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.run_background_tasks:

@ -18,6 +18,7 @@ import itertools
import logging
from collections import deque
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
@ -56,6 +57,9 @@ from synapse.types import (
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# The number of times we are recalculating the current state
@ -272,7 +276,7 @@ class EventsPersistenceStorage:
    current state and forward extremity changes.
    """

    def __init__(self, hs, stores: Databases):
    def __init__(self, hs: "HomeServer", stores: Databases):
        # We ultimately want to split out the state store from the main store,
        # so we use separate variables here even though they point to the same
        # store for now.

@ -549,6 +549,8 @@ def _apply_module_schemas(
        database_engine:
        config: application config
    """
    # This is the old way for password_auth_provider modules to make changes
    # to the database. This should instead be done using the module API
    for (mod, _config) in config.authproviders.password_providers:
        if not hasattr(mod, "get_db_schema_files"):
            continue

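The new comment above points password-provider modules at the module API instead of get_db_schema_files. A hedged sketch of what that migration could look like for a module author; run_db_interaction is assumed to be the relevant ModuleApi helper, and the table and class names are invented:

class MyPasswordProvider:
    def __init__(self, config, api):
        # `api` is the synapse.module_api.ModuleApi instance handed to modules.
        self._api = api

    async def create_tables(self) -> None:
        def _create_txn(txn):
            # Hypothetical module-owned table; schema changes go through the
            # module API rather than a get_db_schema_files() hook.
            txn.execute(
                """
                CREATE TABLE IF NOT EXISTS my_provider_state (
                    user_id TEXT NOT NULL PRIMARY KEY,
                    data TEXT
                )
                """
            )

        # run_db_interaction wraps the function in a database transaction.
        await self._api.run_db_interaction("my_provider_create_tables", _create_txn)
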
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

SCHEMA_VERSION = 64  # remember to update the list below when updating
SCHEMA_VERSION = 65  # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema

This should be incremented whenever the codebase changes its requirements on the
@ -41,6 +41,10 @@ Changes in SCHEMA_VERSION = 63:

Changes in SCHEMA_VERSION = 64:
    - MSC2716: Rename related tables and columns from "chunks" to "batches".

Changes in SCHEMA_VERSION = 65:
    - MSC2716: Remove unique event_id constraint from insertion_event_edges
      because an insertion event can have multiple edges.
"""

@ -0,0 +1,19 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- Recreate the insertion_event_edges event_id index without the unique constraint
-- because an insertion event can have multiple edges.
DROP INDEX insertion_event_edges_event_id;
CREATE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);
synapse/storage/schema/main/delta/65/02_thread_relations.sql (new file, 18 lines)

@ -0,0 +1,18 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
  (6502, 'event_thread_relation', '{}');
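The INSERT above only queues the work; a background update handler registered under the same name ('event_thread_relation') does the actual backfill. A minimal sketch of the registration side, following the register_background_update_handler pattern visible earlier in this diff (the class name and handler body are invented placeholders, not Synapse's implementation):

class ThreadRelationBackfillStore:
    def __init__(self, database, db_conn, hs):
        self.db_pool = database
        # The name must match the update_name inserted by the delta file.
        self.db_pool.updates.register_background_update_handler(
            "event_thread_relation", self._backfill_thread_relations
        )

    async def _backfill_thread_relations(self, progress: dict, batch_size: int) -> int:
        # Placeholder: scan a batch of old events for m.thread relations,
        # record them, and return how many rows were processed so the
        # background updater can pace itself.
        processed = 0
        return processed
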