Merge remote-tracking branch 'upstream/release-v1.56'

This commit is contained in:
Tulir Asokan 2022-03-29 13:30:30 +03:00
commit ef6a9c8a70
98 changed files with 6349 additions and 5436 deletions

View file

@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@ -286,13 +288,17 @@ class LoggingTransaction:
"""
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
self._do_execute(
lambda the_sql: execute_batch(self.txn, the_sql, args), sql
)
else:
self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
def execute_values(
self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
@ -300,10 +306,11 @@ class LoggingTransaction:
rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values # type: ignore
from psycopg2.extras import execute_values
return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
sql,
)
def execute(self, sql: str, *args: Any) -> None:
@ -732,34 +739,45 @@ class DatabasePool:
Returns:
The result of func
"""
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
return cast(R, result)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,

View file

@ -14,7 +14,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
cast,
)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
async def ignored_by(self, user_id: str) -> Set[str]:
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return:
The user IDs which ignore the given user.
"""
return set(
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
@cached(max_entries=5000, iterable=True)
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
"""
Get users which the given user ignores.
Params:
user_id: The user ID which is making the request.
Return:
The user IDs which are ignored by the given user.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignorer_user_id": user_id},
retcol="ignored_user_id",
desc="ignored_users",
)
)
def process_replication_rows(
self,
stream_name: str,
@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
# If the data has not changed, nothing to do.
if previously_ignored_users == currently_ignored_users:
return
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None:
"""

View file

@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_updated_caches_txn(txn):
def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row):
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
assert isinstance(data, EventsStreamCurrentStateRow)
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
stream_ordering: int,
event_id: str,
room_id: str,
etype: str,
state_key: Optional[str],
redacts: Optional[str],
relates_to: Optional[str],
backfilled: bool,
) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._get_membership_from_event_id.invalidate((event_id,))
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(self, txn, cache_func):
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
):
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
"invalidation_ts": self._clock.time_msec(),
},
)

View file

@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
/* Get the depth and stream_ordering of the prev_event_id from the events table */
INNER JOIN events
ON prev_event_id = events.event_id
/* exclude outliers from the results (we don't have the state, so cannot
* verify if the requesting server can see them).
*/
WHERE NOT events.outlier
/* Look for an edge which matches the given event_id */
WHERE event_edges.event_id = ?
AND event_edges.is_state = ?
AND event_edges.event_id = ? AND NOT event_edges.is_state
/* Because we can have many events at the same depth,
* we want to also tie-break and sort on stream_ordering */
ORDER BY depth DESC, stream_ordering DESC
@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(
connected_prev_event_query,
(event_id, False, limit),
(event_id, limit),
)
return [
BackfillQueueNavigationItem(

View file

@ -1745,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,),
)
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
txn.call_after(
self.store._get_membership_from_event_id.invalidate,
(event.event_id,),
)
# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#

View file

@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]:
# TODO: Pagination
keyvalues = {"group_id": group_id}
keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination
def _get_rooms_in_group_txn(txn):
def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category
"""
def _get_rooms_for_summary_txn(txn):
keyvalues = {"group_id": group_id}
def _get_rooms_for_summary_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
async def get_group_categories(self, group_id):
async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
async def get_group_category(self, group_id, category_id):
async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
async def get_group_roles(self, group_id):
async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
async def get_group_role(self, group_id, role_id):
async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
async def get_users_for_summary_by_role(self, group_id, include_private=False):
async def get_users_for_summary_by_role(
self, group_id: str, include_private: bool = False
) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request
Returns:
([users], [roles])
"""
def _get_users_for_summary_txn(txn):
keyvalues = {"group_id": group_id}
def _get_users_for_summary_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], JsonDict]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
async def get_users_membership_info_in_group(self, group_id, user_id):
async def get_users_membership_info_in_group(
self, group_id: str, user_id: str
) -> JsonDict:
"""Get a dict describing the membership of a user in a group.
Example if joined:
@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc
"""
def _get_users_membership_in_group_txn(txn):
def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
async def get_attestations_need_renewals(self, valid_until_ms):
async def get_attestations_need_renewals(
self, valid_until_ms: int
) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn):
def _get_attestations_need_renewals_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ?
@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
async def get_remote_attestation(self, group_id, user_id):
async def get_remote_attestation(
self, group_id: str, user_id: str
) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
async def get_groups_changes_for_user(
self, user_id: str, from_token: int, to_token: int
) -> List[JsonDict]:
has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token
)
if not has_changed:
return []
def _get_groups_changes_for_user_txn(txn):
def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed:
return [], current_id, False
def _get_all_groups_changes_txn(txn):
def _get_all_groups_changes_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
updates = cast(
List[Tuple[int, tuple]],
[
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
],
)
limited = False
upto_token = current_id
@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
room_id: str,
category_id: str,
order: int,
category_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn(
self,
txn,
txn: LoggingTransaction,
group_id: str,
room_id: str,
category_id: str,
order: int,
category_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
(order,) = txn.fetchone()
(order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id,
"room_id": room_id,
},
values=to_update,
updatevalues=to_update,
)
else:
if is_public is None:
@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_summary(
self, group_id: str, room_id: str, category_id: str
self, group_id: str, room_id: str, category_id: Optional[str]
) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/update room category for group"""
insertion_values = {}
update_values = {"category_id": category_id} # This cannot be empty
insertion_values: JsonDict = {}
update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/remove user role"""
insertion_values = {}
update_values = {"role_id": role_id} # This cannot be empty
insertion_values: JsonDict = {}
update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
user_id: str,
role_id: str,
order: int,
role_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn(
self,
txn,
txn: LoggingTransaction,
group_id: str,
user_id: str,
role_id: str,
order: int,
role_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
):
) -> None:
"""Add (or update) user's entry in summary.
Args:
@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
(order,) = txn.fetchone()
(order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id,
"user_id": user_id,
},
values=to_update,
updatevalues=to_update,
)
else:
if is_public is None:
@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_user_from_summary(
self, group_id: str, user_id: str, role_id: str
self, group_id: str, user_id: str, role_id: Optional[str]
) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server
"""
def _add_user_to_group_txn(txn):
def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
table="group_users",
@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn):
def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_users",
@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn):
def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {}
def _register_user_group_membership_txn(txn, next_id):
def _register_user_group_membership_txn(
txn: LoggingTransaction, next_id: int
) -> int:
# TODO: Upsert?
self.db_pool.simple_delete_txn(
txn,
@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
),
},
)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join.
@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
async with self._group_updates_id_gen.get_next() as next_id:
async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res
async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
self,
group_id: str,
user_id: str,
name: str,
avatar_url: str,
short_description: str,
long_description: str,
) -> None:
await self.db_pool.simple_insert(
table="groups",
@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
async def update_group_profile(self, group_id, profile):
async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal",
)
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete.
"""
def _delete_group_txn(txn):
def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [
"groups",
"group_users",

View file

@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media

View file

@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users
"""
def _count_users(txn):
def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users
sql = """
SELECT COUNT(*)
@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
"""
txn.execute(sql)
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
def _count_users_by_service(txn):
def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
txn.execute(sql)
result = txn.fetchall()
result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction(
@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
@wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self):
async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
def _reap_users(txn, reserved_users):
def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
reserved_users (tuple): reserved users to preserve
@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self._invalidate_all_cache_and_stream(
self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
txn:
threepids: List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
def upsert_monthly_active_user_txn(self, txn, user_id):
def upsert_monthly_active_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> None:
"""Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it
@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
async def populate_monthly_active_users(self, user_id):
async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id)
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest:
return
is_trial = await self.is_trial_user(user_id)

View file

@ -24,10 +24,9 @@ from typing import (
Optional,
Set,
Tuple,
cast,
)
from twisted.internet import defer
from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
@ -38,7 +37,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
return txn.fetchall()
return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not rows:
return []
content = {}
content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f
)
results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_linearized_receipts_for_all_rooms", f
)
results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
if last_id == current_id:
return defer.succeed([])
return []
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
)
limited = False
upper_bound = current_id
@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
async with self._receipts_id_gen.get_next() as stream_id:
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,

View file

@ -22,6 +22,7 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
DatabasePool,
@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
self.config = hs.config
self.config: HomeServerConfig = hs.config
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is

View file

@ -27,7 +27,6 @@ from typing import (
)
import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
async def _get_applicable_edits(
async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def _get_threads_participated(
async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
references = await self.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
edits = await self._get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(summaries.keys(), user_id)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
class RelationsStore(RelationsWorkerStore):
pass

View file

@ -34,6 +34,7 @@ import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
self.config = hs.config
self.config: HomeServerConfig = hs.config
async def store_room(
self,

View file

@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""
user_id: str
membership: str
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
desc="_get_joined_profiles_from_event_ids",
)
return {
@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
@cached(max_entries=5000)
async def _get_membership_from_event_id(
self, member_event_id: str
) -> Optional[EventIdMembership]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs."""
) -> Dict[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.
return await self.db_pool.simple_select_many_batch(
Returns:
Mapping from event ID to `EventIdMembership` if the event is a
membership event, otherwise the value is None.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
}
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:

View file

@ -14,7 +14,7 @@
import logging
import re
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
import attr
@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = (
args1 = (
(
entry.event_id,
entry.room_id,
@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
txn.execute_batch(sql, args)
txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
args2 = (
(
entry.event_id,
entry.room_id,
@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
)
for entry in entries
)
txn.execute_batch(sql, args)
txn.execute_batch(sql, args2)
else:
# This should be unreachable.
@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term)
args = []
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list(
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
limit,
limit: int,
pagination_token: Optional[str] = None,
) -> JsonDict:
"""Performs a full text search over events with given keys.
@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term)
args = []
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
if pagination_token:
try:
origin_server_ts, stream = pagination_token.split(",")
origin_server_ts = int(origin_server_ts)
stream = int(stream)
origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts_str)
stream = int(stream_str)
except Exception:
raise SynapseError(400, "Invalid pagination token")
@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list(
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Iterable, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -29,7 +29,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We delegate to the cached version
return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
sql = """
SELECT type, state_key, event_id FROM current_state_events
@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return
return None
event = await self.get_event(event_id, allow_none=True)
if not event:
return
return None
return event.content.get("canonical_alias")
@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self._background_remove_left_rooms,
)
async def _background_remove_left_rooms(self, progress, batch_size):
async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no
longer joined to.
@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
def _background_remove_left_rooms_txn(
txn: LoggingTransaction,
) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events

View file

@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled

View file

@ -26,6 +26,8 @@ from typing import (
cast,
)
from typing_extensions import TypedDict
from synapse.api.errors import StoreError
if TYPE_CHECKING:
@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.types import (
JsonDict,
UserProfile,
get_domain_from_id,
get_localpart_from_id,
)
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) -> None:
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",
@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
class SearchResult(TypedDict):
limited: bool
results: List[UserProfile]
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
async def get_shared_rooms_for_users(
async def get_mutual_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
"""
@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
def _get_shared_rooms_for_users_txn(
def _get_mutual_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return rows
rows = await self.db_pool.runInteraction(
"get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
"get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
)
return {row["room_id"] for row in rows}
@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
) -> SearchResult:
"""Searches for users in directory
Returns:
@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
results = await self.db_pool.execute(
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
results = cast(
List[UserProfile],
await self.db_pool.execute(
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
),
)
limited = len(results) > limit

View file

@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
import psycopg2 # type: ignore
import psycopg2
return PostgresEngine(psycopg2, database_config)

View file

@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
self.config = database_config
@property
def single_threaded(self) -> bool:
return False
def get_db_locale(self, txn):
txn.execute(
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
collation, ctype = txn.fetchone()
return collation, ctype
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000:
@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % (rows[0][0],)
)
txn.execute(
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
collation, ctype = txn.fetchone()
collation, ctype = self.get_db_locale(txn)
if collation != "C":
logger.warning(
"Database has incorrect collation of %r. Should be 'C'\n"
"See docs/postgres.md for more information.",
"Database has incorrect collation of %r. Should be 'C'",
collation,
)
if not allow_unsafe_locale:
raise IncorrectDatabaseSetup(
"Database has incorrect collation of %r. Should be 'C'\n"
"See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
collation,
)
if ctype != "C":
logger.warning(
"Database has incorrect ctype of %r. Should be 'C'\n"
"See docs/postgres.md for more information.",
ctype,
)
if not allow_unsafe_locale:
logger.warning(
"Database has incorrect ctype of %r. Should be 'C'",
ctype,
)
raise IncorrectDatabaseSetup(
"Database has incorrect ctype of %r. Should be 'C'\n"
"See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
ctype,
)
def check_new_database(self, txn):
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
txn.execute(
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
collation, ctype = txn.fetchone()
collation, ctype = self.get_db_locale(txn)
errors = []

View file

@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins.
if events_to_check:
rows = await self.main_store.get_membership_from_event_ids(events_to_check)
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
members = await self.main_store.get_membership_from_event_ids(
events_to_check
)
is_still_joined = any(
member and member.membership == Membership.JOIN
for member in members.values()
)
if is_still_joined:
return True
@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
row["user_id"] for row in rows if row["membership"] == Membership.JOIN
member.user_id
for member in members.values()
if member and member.membership == Membership.JOIN
)
return False