mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-18 00:00:10 -04:00
Merge remote-tracking branch 'upstream/release-v1.32.0'
This commit is contained in:
commit
dbe12b3eb1
207 changed files with 4883 additions and 1550 deletions
|
@ -488,7 +488,7 @@ class DatabasePool:
|
|||
exception_callbacks: List[_CallbackListEntry],
|
||||
func: "Callable[..., R]",
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> R:
|
||||
"""Start a new database transaction with the given connection.
|
||||
|
||||
|
@ -622,7 +622,7 @@ class DatabasePool:
|
|||
func: "Callable[..., R]",
|
||||
*args: Any,
|
||||
db_autocommit: bool = False,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> R:
|
||||
"""Starts a transaction on the database and runs a given function
|
||||
|
||||
|
@ -682,7 +682,7 @@ class DatabasePool:
|
|||
func: "Callable[..., R]",
|
||||
*args: Any,
|
||||
db_autocommit: bool = False,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> R:
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
|
@ -775,7 +775,7 @@ class DatabasePool:
|
|||
desc: str,
|
||||
decoder: Optional[Callable[[Cursor], R]],
|
||||
query: str,
|
||||
*args: Any
|
||||
*args: Any,
|
||||
) -> R:
|
||||
"""Runs a single query for a result set.
|
||||
|
||||
|
@ -900,7 +900,7 @@ class DatabasePool:
|
|||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
desc: str = "simple_upsert",
|
||||
lock: bool = True,
|
||||
) -> Optional[bool]:
|
||||
|
@ -927,6 +927,8 @@ class DatabasePool:
|
|||
Native upserts always return None. Emulated upserts return True if a
|
||||
new entry was created, False if an existing one was updated.
|
||||
"""
|
||||
insertion_values = insertion_values or {}
|
||||
|
||||
attempts = 0
|
||||
while True:
|
||||
try:
|
||||
|
@ -964,7 +966,7 @@ class DatabasePool:
|
|||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
lock: bool = True,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
|
@ -982,6 +984,8 @@ class DatabasePool:
|
|||
Native upserts always return None. Emulated upserts return True if a
|
||||
new entry was created, False if an existing one was updated.
|
||||
"""
|
||||
insertion_values = insertion_values or {}
|
||||
|
||||
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
|
||||
self.simple_upsert_txn_native_upsert(
|
||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||
|
@ -1003,7 +1007,7 @@ class DatabasePool:
|
|||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
lock: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
|
@ -1017,6 +1021,8 @@ class DatabasePool:
|
|||
Returns True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
"""
|
||||
insertion_values = insertion_values or {}
|
||||
|
||||
# We need to lock the table :(, unless we're *really* careful
|
||||
if lock:
|
||||
self.engine.lock_table(txn, table)
|
||||
|
@ -1077,7 +1083,7 @@ class DatabasePool:
|
|||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
insertion_values: Dict[str, Any] = {},
|
||||
insertion_values: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Use the native UPSERT functionality in recent PostgreSQL versions.
|
||||
|
@ -1090,7 +1096,7 @@ class DatabasePool:
|
|||
"""
|
||||
allvalues = {} # type: Dict[str, Any]
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(insertion_values)
|
||||
allvalues.update(insertion_values or {})
|
||||
|
||||
if not values:
|
||||
latter = "NOTHING"
|
||||
|
@ -1513,7 +1519,7 @@ class DatabasePool:
|
|||
column: str,
|
||||
iterable: Iterable[Any],
|
||||
retcols: Iterable[str],
|
||||
keyvalues: Dict[str, Any] = {},
|
||||
keyvalues: Optional[Dict[str, Any]] = None,
|
||||
desc: str = "simple_select_many_batch",
|
||||
batch_size: int = 100,
|
||||
) -> List[Any]:
|
||||
|
@ -1531,6 +1537,8 @@ class DatabasePool:
|
|||
desc: description of the transaction, for logging and metrics
|
||||
batch_size: the number of rows for each select query
|
||||
"""
|
||||
keyvalues = keyvalues or {}
|
||||
|
||||
results = [] # type: List[Dict[str, Any]]
|
||||
|
||||
if not iterable:
|
||||
|
@ -2059,69 +2067,18 @@ def make_in_list_sql_clause(
|
|||
KV = TypeVar("KV")
|
||||
|
||||
|
||||
def make_tuple_comparison_clause(
|
||||
database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
|
||||
) -> Tuple[str, List[KV]]:
|
||||
def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
|
||||
"""Returns a tuple comparison SQL clause
|
||||
|
||||
Depending what the SQL engine supports, builds a SQL clause that looks like either
|
||||
"(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)".
|
||||
Builds a SQL clause that looks like "(a, b) > (?, ?)"
|
||||
|
||||
Args:
|
||||
database_engine
|
||||
keys: A set of (column, value) pairs to be compared.
|
||||
|
||||
Returns:
|
||||
A tuple of SQL query and the args
|
||||
"""
|
||||
if database_engine.supports_tuple_comparison:
|
||||
return (
|
||||
"(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
|
||||
[k[1] for k in keys],
|
||||
)
|
||||
|
||||
# we want to build a clause
|
||||
# (a > ?) OR
|
||||
# (a == ? AND b > ?) OR
|
||||
# (a == ? AND b == ? AND c > ?)
|
||||
# ...
|
||||
# (a == ? AND b == ? AND ... AND z > ?)
|
||||
#
|
||||
# or, equivalently:
|
||||
#
|
||||
# (a > ? OR (a == ? AND
|
||||
# (b > ? OR (b == ? AND
|
||||
# ...
|
||||
# (y > ? OR (y == ? AND
|
||||
# z > ?
|
||||
# ))
|
||||
# ...
|
||||
# ))
|
||||
# ))
|
||||
#
|
||||
# which itself is equivalent to (and apparently easier for the query optimiser):
|
||||
#
|
||||
# (a >= ? AND (a > ? OR
|
||||
# (b >= ? AND (b > ? OR
|
||||
# ...
|
||||
# (y >= ? AND (y > ? OR
|
||||
# z > ?
|
||||
# ))
|
||||
# ...
|
||||
# ))
|
||||
# ))
|
||||
#
|
||||
#
|
||||
|
||||
clause = ""
|
||||
args = [] # type: List[KV]
|
||||
for k, v in keys[:-1]:
|
||||
clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
|
||||
args.extend([v, v])
|
||||
|
||||
(k, v) = keys[-1]
|
||||
clause += "%s > ?" % (k,)
|
||||
args.append(v)
|
||||
|
||||
clause += "))" * (len(keys) - 1)
|
||||
return clause, args
|
||||
return (
|
||||
"(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
|
||||
[k[1] for k in keys],
|
||||
)
|
||||
|
|
|
@ -21,6 +21,7 @@ from typing import List, Optional, Tuple
|
|||
from synapse.api.constants import PresenceState
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import (
|
||||
IdGenerator,
|
||||
|
@ -292,6 +293,8 @@ class DataStore(
|
|||
name: Optional[str] = None,
|
||||
guests: bool = True,
|
||||
deactivated: bool = False,
|
||||
order_by: UserSortOrder = UserSortOrder.USER_ID.value,
|
||||
direction: str = "f",
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
|
@ -304,6 +307,8 @@ class DataStore(
|
|||
name: search for local part of user_id or display name
|
||||
guests: whether to in include guest users
|
||||
deactivated: whether to include deactivated users
|
||||
order_by: the sort order of the returned list
|
||||
direction: sort ascending or descending
|
||||
Returns:
|
||||
A tuple of a list of mappings from user to information and a count of total users.
|
||||
"""
|
||||
|
@ -312,6 +317,14 @@ class DataStore(
|
|||
filters = []
|
||||
args = [self.hs.config.server_name]
|
||||
|
||||
# Set ordering
|
||||
order_by_column = UserSortOrder(order_by).value
|
||||
|
||||
if direction == "b":
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
||||
# `name` is in database already in lower case
|
||||
if name:
|
||||
filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
|
||||
|
@ -339,10 +352,15 @@ class DataStore(
|
|||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
sql = (
|
||||
"SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
|
||||
+ sql_base
|
||||
+ " ORDER BY u.name LIMIT ? OFFSET ?"
|
||||
sql = """
|
||||
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
|
||||
{sql_base}
|
||||
ORDER BY {order_by_column} {order}, u.name ASC
|
||||
LIMIT ? OFFSET ?
|
||||
""".format(
|
||||
sql_base=sql_base,
|
||||
order_by_column=order_by_column,
|
||||
order=order,
|
||||
)
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
|
|
|
@ -298,7 +298,6 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
|||
# times, which is fine.
|
||||
|
||||
where_clause, where_args = make_tuple_comparison_clause(
|
||||
self.database_engine,
|
||||
[("user_id", last_user_id), ("device_id", last_device_id)],
|
||||
)
|
||||
|
||||
|
|
|
@ -985,7 +985,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
def _txn(txn):
|
||||
clause, args = make_tuple_comparison_clause(
|
||||
self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
|
||||
[(x, last_row[x]) for x in KEY_COLS]
|
||||
)
|
||||
sql = """
|
||||
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
|
||||
|
|
|
@ -320,8 +320,8 @@ class PersistEventsStore:
|
|||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool,
|
||||
state_delta_for_room: Dict[str, DeltaState] = {},
|
||||
new_forward_extremeties: Dict[str, List[str]] = {},
|
||||
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
|
||||
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
"""Insert some number of room events into the necessary database tables.
|
||||
|
||||
|
@ -342,6 +342,9 @@ class PersistEventsStore:
|
|||
extremities.
|
||||
|
||||
"""
|
||||
state_delta_for_room = state_delta_for_room or {}
|
||||
new_forward_extremeties = new_forward_extremeties or {}
|
||||
|
||||
all_events_and_contexts = events_and_contexts
|
||||
|
||||
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
|
||||
|
|
|
@ -838,7 +838,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
|
||||
# comparison, but that is not supported on older SQLite versions
|
||||
tuple_clause, tuple_args = make_tuple_comparison_clause(
|
||||
self.database_engine,
|
||||
[
|
||||
("events.room_id", last_room_id),
|
||||
("topological_ordering", last_depth),
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import logging
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, overload
|
||||
from typing import Container, Dict, Iterable, List, Optional, Tuple, overload
|
||||
|
||||
from constantly import NamedConstant, Names
|
||||
from typing_extensions import Literal
|
||||
|
@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
async def get_stripped_room_state_from_event_context(
|
||||
self,
|
||||
context: EventContext,
|
||||
state_types_to_include: List[EventTypes],
|
||||
state_types_to_include: Container[str],
|
||||
membership_user_id: Optional[str] = None,
|
||||
) -> List[JsonDict]:
|
||||
"""
|
||||
|
|
|
@ -1027,8 +1027,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
user_id: str,
|
||||
is_admin: bool = False,
|
||||
is_public: bool = True,
|
||||
local_attestation: dict = None,
|
||||
remote_attestation: dict = None,
|
||||
local_attestation: Optional[dict] = None,
|
||||
remote_attestation: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Add a user to the group server.
|
||||
|
||||
|
@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
user_id: str,
|
||||
membership: str,
|
||||
is_admin: bool = False,
|
||||
content: JsonDict = {},
|
||||
content: Optional[JsonDict] = None,
|
||||
local_attestation: Optional[dict] = None,
|
||||
remote_attestation: Optional[dict] = None,
|
||||
is_publicised: bool = False,
|
||||
|
@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
is_publicised: Whether this should be publicised.
|
||||
"""
|
||||
|
||||
content = content or {}
|
||||
|
||||
def _register_user_group_membership_txn(txn, next_id):
|
||||
# TODO: Upsert?
|
||||
self.db_pool.simple_delete_txn(
|
||||
|
|
|
@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool
|
|||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
|
||||
"media_repository_drop_index_wo_method"
|
||||
)
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
|
||||
"media_repository_drop_index_wo_method_2"
|
||||
)
|
||||
|
||||
|
||||
class MediaSortOrder(Enum):
|
||||
|
@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
|||
unique=True,
|
||||
)
|
||||
|
||||
# the original impl of _drop_media_index_without_method was broken (see
|
||||
# https://github.com/matrix-org/synapse/issues/8649), so we replace the original
|
||||
# impl with a no-op and run the fixed migration as
|
||||
# media_repository_drop_index_wo_method_2.
|
||||
self.db_pool.updates.register_noop_background_update(
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
|
||||
)
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
|
||||
self._drop_media_index_without_method,
|
||||
)
|
||||
|
||||
async def _drop_media_index_without_method(self, progress, batch_size):
|
||||
"""background update handler which removes the old constraints.
|
||||
|
||||
Note that this is only run on postgres.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
txn.execute(
|
||||
"ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
|
||||
)
|
||||
txn.execute(
|
||||
"ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
|
||||
"ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("drop_media_indices_without_method", f)
|
||||
await self.db_pool.updates._end_background_update(
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
|
||||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
|
||||
)
|
||||
return 1
|
||||
|
||||
|
|
|
@ -521,13 +521,11 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def get_ratelimit_for_user(self, user_id):
|
||||
"""Check if there are any overrides for ratelimiting for the given
|
||||
user
|
||||
async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
|
||||
"""Check if there are any overrides for ratelimiting for the given user
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
|
||||
user_id: user ID of the user
|
||||
Returns:
|
||||
RatelimitOverride if there is an override, else None. If the contents
|
||||
of RatelimitOverride are None or 0 then ratelimitng has been
|
||||
|
@ -549,6 +547,62 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
else:
|
||||
return None
|
||||
|
||||
async def set_ratelimit_for_user(
|
||||
self, user_id: str, messages_per_second: int, burst_count: int
|
||||
) -> None:
|
||||
"""Sets whether a user is set an overridden ratelimit.
|
||||
Args:
|
||||
user_id: user ID of the user
|
||||
messages_per_second: The number of actions that can be performed in a second.
|
||||
burst_count: How many actions that can be performed before being limited.
|
||||
"""
|
||||
|
||||
def set_ratelimit_txn(txn):
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="ratelimit_override",
|
||||
keyvalues={"user_id": user_id},
|
||||
values={
|
||||
"messages_per_second": messages_per_second,
|
||||
"burst_count": burst_count,
|
||||
},
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_ratelimit_for_user, (user_id,)
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
|
||||
|
||||
async def delete_ratelimit_for_user(self, user_id: str) -> None:
|
||||
"""Delete an overridden ratelimit for a user.
|
||||
Args:
|
||||
user_id: user ID of the user
|
||||
"""
|
||||
|
||||
def delete_ratelimit_txn(txn):
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="ratelimit_override",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["user_id"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if not row:
|
||||
return
|
||||
|
||||
# They are there, delete them.
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, "ratelimit_override", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_ratelimit_for_user, (user_id,)
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
|
||||
|
||||
@cached()
|
||||
async def get_retention_policy_for_room(self, room_id):
|
||||
"""Get the retention policy for a given room.
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- drop old constraints on remote_media_cache_thumbnails
|
||||
--
|
||||
-- This was originally part of 57.07, but it was done wrong, per
|
||||
-- https://github.com/matrix-org/synapse/issues/8649, so we do it again.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||
(5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx');
|
||||
|
|
@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
# FIXME: how should this be cached?
|
||||
async def get_filtered_current_state_ids(
|
||||
self, room_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state event of a given type for a room based on the
|
||||
current_state_events table. This may not be as up-to-date as the result
|
||||
|
@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
Map from type/state_key to event ID.
|
||||
"""
|
||||
|
||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||
where_clause, where_args = (
|
||||
state_filter or StateFilter.all()
|
||||
).make_sql_filter_clause()
|
||||
|
||||
if not where_clause:
|
||||
# We delegate to the cached version
|
||||
|
|
|
@ -66,18 +66,37 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
|
|||
class UserSortOrder(Enum):
|
||||
"""
|
||||
Enum to define the sorting method used when returning users
|
||||
with get_users_media_usage_paginate
|
||||
with get_users_paginate in __init__.py
|
||||
and get_users_media_usage_paginate in stats.py
|
||||
|
||||
MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
|
||||
MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
|
||||
When moves this to __init__.py gets `builtins.ImportError` with
|
||||
`most likely due to a circular import`
|
||||
|
||||
MEDIA_LENGTH = ordered by size of uploaded media.
|
||||
MEDIA_COUNT = ordered by number of uploaded media.
|
||||
USER_ID = ordered alphabetically by `user_id`.
|
||||
NAME = ordered alphabetically by `user_id`. This is for compatibility reasons,
|
||||
as the user_id is returned in the name field in the response in list users admin API.
|
||||
DISPLAYNAME = ordered alphabetically by `displayname`
|
||||
GUEST = ordered by `is_guest`
|
||||
ADMIN = ordered by `admin`
|
||||
DEACTIVATED = ordered by `deactivated`
|
||||
USER_TYPE = ordered alphabetically by `user_type`
|
||||
AVATAR_URL = ordered alphabetically by `avatar_url`
|
||||
SHADOW_BANNED = ordered by `shadow_banned`
|
||||
"""
|
||||
|
||||
MEDIA_LENGTH = "media_length"
|
||||
MEDIA_COUNT = "media_count"
|
||||
USER_ID = "user_id"
|
||||
NAME = "name"
|
||||
DISPLAYNAME = "displayname"
|
||||
GUEST = "is_guest"
|
||||
ADMIN = "admin"
|
||||
DEACTIVATED = "deactivated"
|
||||
USER_TYPE = "user_type"
|
||||
AVATAR_URL = "avatar_url"
|
||||
SHADOW_BANNED = "shadow_banned"
|
||||
|
||||
|
||||
class StatsStore(StateDeltasStore):
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
|
|||
return count
|
||||
|
||||
def _get_state_groups_from_groups_txn(
|
||||
self, txn, groups, state_filter=StateFilter.all()
|
||||
self, txn, groups, state_filter: Optional[StateFilter] = None
|
||||
):
|
||||
state_filter = state_filter or StateFilter.all()
|
||||
|
||||
results = {group: {} for group in groups}
|
||||
|
||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Set, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||
|
||||
async def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[int, MutableStateMap[str]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
state_filter = state_filter or StateFilter.all()
|
||||
|
||||
member_filter, non_member_filter = state_filter.get_member_split()
|
||||
|
||||
|
|
|
@ -42,14 +42,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def supports_tuple_comparison(self) -> bool:
|
||||
"""
|
||||
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def supports_using_any_list(self) -> bool:
|
||||
|
|
|
@ -47,8 +47,8 @@ class PostgresEngine(BaseDatabaseEngine):
|
|||
self._version = db_conn.server_version
|
||||
|
||||
# Are we on a supported PostgreSQL version?
|
||||
if not allow_outdated_version and self._version < 90500:
|
||||
raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
|
||||
if not allow_outdated_version and self._version < 90600:
|
||||
raise RuntimeError("Synapse requires PostgreSQL 9.6 or above.")
|
||||
|
||||
with db_conn.cursor() as txn:
|
||||
txn.execute("SHOW SERVER_ENCODING")
|
||||
|
@ -129,13 +129,6 @@ class PostgresEngine(BaseDatabaseEngine):
|
|||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_tuple_comparison(self):
|
||||
"""
|
||||
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_using_any_list(self):
|
||||
"""Do we support using `a = ANY(?)` and passing a list"""
|
||||
|
|
|
@ -56,14 +56,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
"""
|
||||
return self.module.sqlite_version_info >= (3, 24, 0)
|
||||
|
||||
@property
|
||||
def supports_tuple_comparison(self):
|
||||
"""
|
||||
Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
|
||||
SQLite 3.15+.
|
||||
"""
|
||||
return self.module.sqlite_version_info >= (3, 15, 0)
|
||||
|
||||
@property
|
||||
def supports_using_any_list(self):
|
||||
"""Do we support using `a = ANY(?)` and passing a list"""
|
||||
|
@ -72,8 +64,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||
if not allow_outdated_version:
|
||||
version = self.module.sqlite_version_info
|
||||
if version < (3, 11, 0):
|
||||
raise RuntimeError("Synapse requires sqlite 3.11 or above.")
|
||||
# Synapse is untested against older SQLite versions, and we don't want
|
||||
# to let users upgrade to a version of Synapse with broken support for their
|
||||
# sqlite version, because it risks leaving them with a half-upgraded db.
|
||||
if version < (3, 22, 0):
|
||||
raise RuntimeError("Synapse requires sqlite 3.22 or above.")
|
||||
|
||||
def check_new_database(self, txn):
|
||||
"""Gets called when setting up a brand new database. This allows us to
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import imp
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
@ -454,8 +454,13 @@ def _upgrade_existing_database(
|
|||
)
|
||||
|
||||
module_name = "synapse.storage.v%d_%s" % (v, root_name)
|
||||
with open(absolute_path) as python_file:
|
||||
module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name, absolute_path
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
|
||||
logger.info("Running script %s", relative_path)
|
||||
module.run_create(cur, database_engine) # type: ignore
|
||||
if not is_empty:
|
||||
|
|
|
@ -449,7 +449,7 @@ class StateGroupStorage:
|
|||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
async def get_state_for_events(
|
||||
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
|
||||
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[EventBase]]:
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
|
@ -465,7 +465,7 @@ class StateGroupStorage:
|
|||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
|
@ -485,7 +485,7 @@ class StateGroupStorage:
|
|||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_ids_for_events(
|
||||
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
|
||||
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[str]]:
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
|
@ -502,7 +502,7 @@ class StateGroupStorage:
|
|||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
|
@ -513,7 +513,7 @@ class StateGroupStorage:
|
|||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_for_event(
|
||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[EventBase]:
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
@ -525,11 +525,13 @@ class StateGroupStorage:
|
|||
Returns:
|
||||
A dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = await self.get_state_for_events([event_id], state_filter)
|
||||
state_map = await self.get_state_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
async def get_state_ids_for_event(
|
||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
@ -541,11 +543,13 @@ class StateGroupStorage:
|
|||
Returns:
|
||||
A dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = await self.get_state_ids_for_events([event_id], state_filter)
|
||||
state_map = await self.get_state_ids_for_events(
|
||||
[event_id], state_filter or StateFilter.all()
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
@ -558,7 +562,9 @@ class StateGroupStorage:
|
|||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
||||
return self.stores.state._get_state_for_groups(
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
async def store_state_group(
|
||||
self,
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
import threading
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -91,7 +91,14 @@ class StreamIdGenerator:
|
|||
# ... persist event ...
|
||||
"""
|
||||
|
||||
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
|
||||
def __init__(
|
||||
self,
|
||||
db_conn,
|
||||
table,
|
||||
column,
|
||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||
step=1,
|
||||
):
|
||||
assert step != 0
|
||||
self._lock = threading.Lock()
|
||||
self._step = step
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue