Merge remote-tracking branch 'upstream/release-v1.56'

Tulir Asokan 2022-03-29 13:30:30 +03:00
commit ef6a9c8a70
98 changed files with 6349 additions and 5436 deletions

View file

@ -68,7 +68,7 @@ try:
except ImportError:
pass
__version__ = "1.55.0"
__version__ = "1.56.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "metrics" and self.config.metrics.enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
metrics_resource: Resource = MetricsResource(RegistryProxy)
if compress:
metrics_resource = gz_wrap(metrics_resource)
resources[METRICS_PREFIX] = metrics_resource
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
@ -348,6 +351,23 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
if config.server.gc_seconds:
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
if (
config.registration.enable_registration
and not config.registration.enable_registration_without_verification
):
if (
not config.captcha.enable_registration_captcha
and not config.registration.registrations_require_3pid
and not config.registration.registration_requires_token
):
raise ConfigError(
"You have enabled open registration without any verification. This is a known vector for "
"spam and abuse. If you would like to allow public registration, please consider adding email, "
"captcha, or token-based verification. Otherwise this check can be removed by setting the "
"`enable_registration_without_verification` config option to `true`."
)
hs = SynapseHomeServer(
config.server.server_name,
config=config,
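
For operators hit by the new startup check above, a minimal homeserver.yaml sketch (illustrative: any one of the captcha, 3PID, or token requirements named in the error satisfies it) that keeps registration open with token-based verification:

enable_registration: true
registration_requires_token: true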

View file

@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\
# 'txn_limit' gives the maximum number of transactions to run per connection
# before reconnecting. Defaults to 0, which means no limit.
#
# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
# https://wiki.postgresql.org/wiki/Locale_data_changes
#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
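
A hedged sketch of a database section using this override (placement alongside 'txn_limit' and 'args' follows the list above; the connection arguments are illustrative, and the option remains unrecommended):

database:
  name: psycopg2
  txn_limit: 10000
  allow_unsafe_locale: true
  args:
    user: synapse_user
    database: synapse
    cp_min: 5
    cp_max: 10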

View file

@ -33,6 +33,10 @@ class RegistrationConfig(Config):
str(config["disable_registration"])
)
self.enable_registration_without_verification = strtobool(
str(config.get("enable_registration_without_verification", False))
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
@ -207,10 +211,18 @@ class RegistrationConfig(Config):
# Registration can be rate-limited using the parameters in the "Ratelimiting"
# section of this file.
# Enable registration for new users.
# Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
# you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
# without any verification, you must also set `enable_registration_without_verification`, found below.
#
#enable_registration: false
# Enable registration without email or captcha verification. Note: this option is *not* recommended,
# as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
# unless `enable_registration` is also enabled.
#
#enable_registration_without_verification: true
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.

View file

@ -676,6 +676,10 @@ class ServerConfig(Config):
):
raise ConfigError("'custom_template_directory' must be a string")
self.use_account_validity_in_account_status: bool = (
config.get("use_account_validity_in_account_status") or False
)
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
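
Since ServerConfig reads the flag directly from the root of the parsed config, a hedged homeserver.yaml sketch is a single top-level line:

use_account_validity_in_account_status: true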

View file

@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
supporting Synapse's generic modules system is available.
For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
supporting Synapse's generic modules system is available. For more information, please
see https://matrix-org.github.io/synapse/latest/modules/index.html
---------------------------------------------------------------------------------------"""

View file

@ -21,7 +21,6 @@ from typing import (
Awaitable,
Callable,
Collection,
Dict,
List,
Optional,
Tuple,
@ -31,7 +30,7 @@ from typing import (
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias
from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
@ -383,7 +382,7 @@ class SpamChecker:
return True
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
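
A hedged sketch of a module implementing the new UserProfile-based callback (registered via register_spam_checker_callbacks as in Synapse's modules system; the blocking rule itself is illustrative):

from synapse.module_api import ModuleApi
from synapse.types import UserProfile


class ExampleSpamChecker:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        api.register_spam_checker_callbacks(
            check_username_for_spam=self.check_username_for_spam,
        )

    async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
        # UserProfile carries "user_id", "display_name" and "avatar_url";
        # returning True hides the user from directory search results.
        display_name = user_profile.get("display_name") or ""
        return "spam" in display_name.lower()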

View file

@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'

View file

@ -22,7 +22,6 @@ from typing import (
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
@ -577,10 +576,10 @@ class FederationServer(FederationBase):
async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
pdus: Collection[EventBase]
if event_id:
pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
room_id, event_id
)
event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
pdus = await self.store.get_events_as_list(event_ids)
else:
pdus = (await self.state.get_current_state(room_id)).values()
@ -1093,7 +1092,7 @@ class FederationServer(FederationBase):
# has started processing).
while True:
async with lock:
logger.info("handling received PDU: %s", event)
logger.info("handling received PDU in room %s: %s", room_id, event)
try:
with nested_logging_context(event.event_id):
await self._federation_event_handler.on_receive_pdu(

View file

@ -26,6 +26,10 @@ class AccountHandler:
self._main_store = hs.get_datastores().main
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
self._use_account_validity_in_account_status = (
hs.config.server.use_account_validity_in_account_status
)
self._account_validity_handler = hs.get_account_validity_handler()
async def get_account_statuses(
self,
@ -106,6 +110,13 @@ class AccountHandler:
"deactivated": userinfo.is_deactivated,
}
if self._use_account_validity_in_account_status:
status[
"org.matrix.expired"
] = await self._account_validity_handler.is_user_expired(
user_id.to_string()
)
return status
async def _get_remote_account_statuses(
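
With the flag enabled, the status entry returned for a local user takes roughly this shape (a sketch: the "exists" key is an assumption about the surrounding dict, and the values are illustrative):

{
    "exists": True,
    "deactivated": False,
    "org.matrix.expired": False,
}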

View file

@ -950,54 +950,35 @@ class FederationHandler:
return event
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
results = {(e.type, e.state_key): e for e in state}
if event.is_state():
# Get previous state
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
res = list(results.values())
return res
else:
return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
results = state
# get_state_groups_ids should return exactly one result
assert len(state_groups) == 1
if event.is_state():
# Get previous state
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
results[(event.type, event.state_key)] = prev_id
else:
results.pop((event.type, event.state_key), None)
state_map = next(iter(state_groups.values()))
return list(results.values())
else:
return []
state_key = event.get_state_key()
if state_key is not None:
# the event was not rejected (get_event raises a NotFoundError for rejected
# events) so the state at the event should include the event itself.
assert (
state_map.get((event.type, state_key)) == event.event_id
), "State at event did not include event itself"
# ... but we need the state *before* that event
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
state_map[(event.type, state_key)] = prev_id
else:
del state_map[(event.type, state_key)]
return list(state_map.values())
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int

View file

@ -495,6 +495,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
outlier: bool = False,
historical: bool = False,
@ -529,6 +530,15 @@ class EventCreationHandler:
If non-None, prev_event_ids must also be provided.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is with insertion events which float at
the beginning of a historical batch and don't have any `prev_events` to
derive from; we add all of these state events as the explicit state so the
rest of the historical batch can inherit the same state and state_group.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
require_consent: Whether to check if the requester has
consented to the privacy policy.
@ -614,6 +624,7 @@ class EventCreationHandler:
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
depth=depth,
)
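
Putting the docstring together, a hedged sketch of creating a floating MSC2716 insertion event with explicit state (this mirrors the room_batch call sites later in this commit; the variable names are illustrative):

event, context = await event_creation_handler.create_event(
    requester,
    event_dict,
    # The insertion event floats: there are no prev_events to derive state from.
    allow_no_prev_events=True,
    prev_event_ids=[],
    # So hand it the full state at the batch start explicitly.
    state_event_ids=state_ids_at_batch_start,
    historical=True,
    depth=inherited_depth,
)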
@ -773,7 +784,7 @@ class EventCreationHandler:
event_dict: dict,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
@ -797,12 +808,14 @@ class EventCreationHandler:
The event IDs to use as the prev events.
Should normally be left as None to automatically request them
from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
If non-None, prev_event_ids must also be provided.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is with insertion events which float at
the beginning of a historical batch and don't have any `prev_events` to
derive from; we add all of these state events as the explicit state so the
rest of the historical batch can inherit the same state and state_group.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
@ -858,8 +871,9 @@ class EventCreationHandler:
requester,
event_dict,
txn_id=txn_id,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
outlier=outlier,
historical=historical,
depth=depth,
@ -895,6 +909,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@ -917,6 +932,15 @@ class EventCreationHandler:
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is with insertion events which float at
the beginning of a historical batch and don't have any `prev_events` to
derive from; we add all of these state events as the explicit state so the
rest of the historical batch can inherit the same state and state_group.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@ -924,31 +948,26 @@ class EventCreationHandler:
Returns:
Tuple of created event, context
"""
# Strip down the auth_event_ids to only what we need to auth the event.
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member events that don't match event.sender
full_state_ids_at_event = None
if auth_event_ids is not None:
# If auth events are provided, prev events must be also.
if state_event_ids is not None:
# Do a quick check to make sure that prev_event_ids is present to
# make the type-checking around `builder.build` happy.
# prev_event_ids could be an empty array though.
assert prev_event_ids is not None
# Copy the full auth state before it stripped down
full_state_ids_at_event = auth_event_ids.copy()
temp_event = await builder.build(
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
auth_event_ids=state_event_ids,
depth=depth,
)
auth_events = await self.store.get_events_as_list(auth_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str]
auth_event_state_map = {
(e.type, e.state_key): e.event_id for e in auth_events
}
# Actually strip down and use the necessary auth events
state_map = {(e.type, e.state_key): e.event_id for e in state_events}
# Actually strip down and only use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
current_state_ids=auth_event_state_map,
current_state_ids=state_map,
for_verification=False,
)
@ -991,12 +1010,16 @@ class EventCreationHandler:
context = EventContext.for_outlier()
elif (
event.type == EventTypes.MSC2716_INSERTION
and full_state_ids_at_event
and state_event_ids
and builder.internal_metadata.is_historical()
):
# Add explicit state to the insertion event so it has state to derive
# from even though it's floating with no `prev_events`. The rest of
# the batch can derive from this state and state_group.
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(full_state_ids_at_event)
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
else:
context = await self.state.compute_event_context(event)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr
@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there is currently an active purge *or delete* operation.
@ -422,7 +423,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
) -> Dict[str, Any]:
) -> JsonDict:
"""Get messages in a room.
Args:
@ -431,6 +432,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
Returns:
Pagination API results
"""
@ -538,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
aggregations = await self.store.get_bundled_aggregations(events, user_id)
aggregations = await self._relations_handler.get_bundled_aggregations(
events, user_id
)
time_now = self.clock.time_msec()

View file

@ -336,12 +336,18 @@ class ProfileHandler:
"""Check that the size and content type of the avatar at the given MXC URI are
within the configured limits.
If the given `mxc` is empty, no checks are performed. (Users are always able to
unset their avatar.)
Args:
mxc: The MXC URI at which the avatar can be found.
Returns:
A boolean indicating whether the file can be allowed to be set as an avatar.
"""
if mxc == "":
return True
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
return True

View file

@ -0,0 +1,271 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._storage = hs.get_storage()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
"""
user_id = requester.user.to_string()
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self._event_handler.get_event(requester.user, room_id, event_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
to_token=to_token,
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
events = await filter_events_for_client(
self._storage, user_id, events, is_peeking=(member_event_id is None)
)
now = self._clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return return_value
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
edits = await self._main_store.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participation for the limited selection of events that
# actually had summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
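
A hedged usage sketch of the new handler, following the call sites added elsewhere in this commit (the serializer and clock wiring is illustrative):

relations_handler = hs.get_relations_handler()
aggregations = await relations_handler.get_bundled_aggregations(
    events, requester.user.to_string()
)
serialized = event_serializer.serialize_events(
    events, clock.time_msec(), bundle_aggregations=aggregations
)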

View file

@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@ -1128,6 +1128,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context(
self,
@ -1200,7 +1201,7 @@ class RoomContextHandler:
event = filtered[0]
# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)

View file

@ -123,12 +123,11 @@ class RoomBatchHandler:
return create_requester(user_id, app_service=app_service)
async def get_most_recent_auth_event_ids_from_event_id_list(
async def get_most_recent_full_state_ids_from_event_id_list(
self, event_ids: List[str]
) -> List[str]:
"""Find the most recent auth event ids (derived from state events) that
allowed that message to be sent. We will use this as a base
to auth our historical messages against.
"""Find the most recent event_id and grab the full state at that event.
We will use this as a base to auth our historical messages against.
Args:
event_ids: List of event ID's to look at
@ -138,38 +137,37 @@ class RoomBatchHandler:
"""
(
most_recent_prev_event_id,
most_recent_event_id,
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id
most_recent_event_id
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
full_state_ids = list(prev_state_map.values())
return auth_event_ids
return full_state_ids
async def persist_state_events_at_start(
self,
state_events_at_start: List[JsonDict],
room_id: str,
initial_auth_event_ids: List[str],
initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
them as floating state events which don't resolve into the current room state.
They are floating because they reference a fake prev_event which doesn't connect
to the normal DAG at all.
them in a floating state event chain which doesn't resolve into the current room
state. They are floating because they reference no prev_events and are marked
as outliers, which disconnects them from the normal DAG.
Args:
state_events_at_start:
room_id: Room where you want the events persisted in.
initial_auth_event_ids: These will be the auth_events for the first
state event created. Each event created afterwards will be
added to the list of auth events for the next state event
created.
initial_state_event_ids:
The base set of state for the historical batch which the floating
state chain will derive from. This should probably be the state
from the `prev_event` defined by `/batch_send?prev_event_id=$abc`.
app_service_requester: The requester of an application service.
Returns:
@ -178,7 +176,7 @@ class RoomBatchHandler:
assert app_service_requester.app_service
state_event_ids_at_start = []
auth_event_ids = initial_auth_event_ids.copy()
state_event_ids = initial_state_event_ids.copy()
# Make the state events float off on their own by specifying no
# prev_events for the first one in the chain so we don't have a bunch of
@ -191,9 +189,7 @@ class RoomBatchHandler:
)
logger.debug(
"RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
state_event,
auth_event_ids,
"RoomBatchSendEventRestServlet inserting state_event=%s", state_event
)
event_dict = {
@ -219,16 +215,26 @@ class RoomBatchHandler:
room_id=room_id,
action=membership,
content=event_dict["content"],
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True,
historical=True,
# Only the first event in the chain should be floating.
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop. Otherwise the event would keep a reference
# to the same list and pick up the IDs we append below.
auth_event_ids=auth_event_ids.copy(),
state_event_ids=state_event_ids.copy(),
)
else:
# TODO: Add some complement tests that adds state that is not member joins
@ -242,21 +248,31 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service
),
event_dict,
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True,
historical=True,
# Only the first event in the chain should be floating.
# Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
#
# Make sure to use a copy of this list because we modify it
# later in the loop. Otherwise the event would keep a reference
# to the same list and pick up the IDs we append below.
auth_event_ids=auth_event_ids.copy(),
state_event_ids=state_event_ids.copy(),
)
event_id = event.event_id
state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
state_event_ids.append(event_id)
# Connect all the state in a floating chain
prev_event_ids_for_state_chain = [event_id]
@ -267,7 +283,7 @@ class RoomBatchHandler:
events_to_create: List[JsonDict],
room_id: str,
inherited_depth: int,
auth_event_ids: List[str],
initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> List[str]:
"""Create and persists all events provided sequentially. Handles the
@ -283,8 +299,10 @@ class RoomBatchHandler:
room_id: Room where you want the events persisted in.
inherited_depth: The depth to create the events at (which you will
probably get by calling inherit_depth_from_prev_ids(...)).
auth_event_ids: Define which events allow you to create the given
event in the room.
initial_state_event_ids:
This is used to set explicit state for the insertion event at
the start of the historical batch since it's floating with no
prev_events to derive state from automatically.
app_service_requester: The requester of an application service.
Returns:
@ -292,6 +310,11 @@ class RoomBatchHandler:
"""
assert app_service_requester.app_service
# We expect the first event in a historical batch to be an insertion event
assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION
# We expect the last event in a historical batch to be a batch event
assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH
# Make the historical event chain float off on its own by specifying no
# prev_events for the first event in the chain which causes the HS to
# ask for the state at the start of the batch later.
@ -323,11 +346,16 @@ class RoomBatchHandler:
ev["sender"], app_service_requester.app_service
),
event_dict,
# Only the first event in the chain should be floating.
# The rest should hang off each other in a chain.
# Only the first event (which is the insertion event) in the
# chain should be floating. The rest should hang off each other
# in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=event_dict.get("prev_events"),
auth_event_ids=auth_event_ids,
# Since the first event (which is the insertion event) in the
# chain is floating with no `prev_events`, it can't derive state
# from anywhere automatically. So we need to set some state
# explicitly.
state_event_ids=initial_state_event_ids if index == 0 else None,
historical=True,
depth=inherited_depth,
)
@ -345,10 +373,9 @@ class RoomBatchHandler:
)
logger.debug(
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s",
event,
prev_event_ids,
auth_event_ids,
)
events_to_persist.append((event, context))
@ -378,12 +405,12 @@ class RoomBatchHandler:
room_id: str,
batch_id_to_connect_to: str,
inherited_depth: int,
auth_event_ids: List[str],
initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> Tuple[List[str], str]:
"""
Handles creating and persisting all of the historical events as well
as insertion and batch meta events to make the batch navigable in the DAG.
Handles creating and persisting all of the historical events as well as
insertion and batch meta events to make the batch navigable in the DAG.
Args:
events_to_create: List of historical events to create in JSON
@ -393,8 +420,13 @@ class RoomBatchHandler:
want this batch to connect to.
inherited_depth: The depth to create the events at (which you will
probably get by calling inherit_depth_from_prev_ids(...)).
auth_event_ids: Define which events allow you to create the given
event in the room.
initial_state_event_ids:
This is used to set explicit state for the insertion event at
the start of the historical batch since it's floating with no
prev_events to derive state from automatically. This should
probably be the state from the `prev_event` defined by
`/batch_send?prev_event_id=$abc` plus the outcome of
`persist_state_events_at_start`
app_service_requester: The requester of an application service.
Returns:
@ -440,7 +472,7 @@ class RoomBatchHandler:
events_to_create=events_to_create,
room_id=room_id,
inherited_depth=inherited_depth,
auth_event_ids=auth_event_ids,
initial_state_event_ids=initial_state_event_ids,
app_service_requester=app_service_requester,
)

View file

@ -271,7 +271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
membership: str,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@ -294,10 +294,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special
cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is the historical `state_events_at_start`;
since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
have any `state_ids` set and therefore can't derive any state even though the
prev_events are set, so we need to set the state ourselves via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
txn_id:
ratelimit:
@ -352,7 +356,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
require_consent=require_consent,
outlier=outlier,
historical=historical,
@ -455,7 +459,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical: bool = False,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@ -483,10 +487,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special
cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is the historical `state_events_at_start`;
since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
have any `state_ids` set and therefore can't derive any state even though the
prev_events are set, so we need to set the state ourselves via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@ -525,7 +533,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
)
return result
@ -547,7 +555,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical: bool = False,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@ -577,10 +585,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special
cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is the historical `state_events_at_start`;
since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
have any `state_ids` set and therefore can't derive any state even though the
prev_events are set, so we need to set the state ourselves via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@ -707,7 +719,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
content=content,
require_consent=require_consent,
outlier=outlier,
@ -931,7 +943,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
content=content,
require_consent=require_consent,
outlier=outlier,

View file

@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None
if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(

View file

@ -28,16 +28,16 @@ from typing import (
import attr
from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@ -638,8 +639,10 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations(
recents, sync_config.user.to_string()
bundled_aggregations = (
await self._relations_handler.get_bundled_aggregations(
recents, sync_config.user.to_string()
)
)
return TimelineBatch(
@ -1600,7 +1603,7 @@ class SyncHandler:
return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response.
ignored_users = await self._get_ignored_users(user_id)
ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@ -1626,7 +1629,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@ -1656,29 +1658,6 @@ class SyncHandler:
newly_left_users,
)
async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
"""Retrieve the users ignored by the given user from their global account_data.
Returns an empty set if
- there is no global account_data entry for ignored_users
- there is such an entry, but it's not a JSON object.
"""
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
)
)
# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
ignored_users = frozenset(ignored_users_data.keys())
return ignored_users
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@ -2021,7 +2000,6 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@ -2050,7 +2028,6 @@ class SyncHandler:
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been

View file

@ -19,8 +19,8 @@ import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def search_users(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
) -> SearchResult:
"""Searches for users in directory
Returns:
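
For reference, a hedged sketch of the SearchResult shape behind the new annotation (field names assumed from the user directory store's TypedDict; values illustrative):

{
    "limited": False,
    "results": [
        {
            "user_id": "@alice:example.org",
            "display_name": "Alice",
            "avatar_url": "mxc://example.org/abc123",
        }
    ],
}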

View file

@ -111,6 +111,7 @@ from synapse.types import (
StateMap,
UserID,
UserInfo,
UserProfile,
create_requester,
)
from synapse.util import Clock
@ -150,6 +151,7 @@ __all__ = [
"EventBase",
"StateMap",
"ProfileInfo",
"UserProfile",
]
logger = logging.getLogger(__name__)
@ -609,15 +611,18 @@ class ModuleApi:
localpart: str,
displayname: Optional[str] = None,
emails: Optional[List[str]] = None,
admin: bool = False,
) -> "defer.Deferred[str]":
"""Registers a new user with given localpart and optional displayname, emails.
Added in Synapse v1.2.0.
Changed in Synapse v1.56.0: add 'admin' argument to register the user as admin.
Args:
localpart: The localpart of the new user.
displayname: The displayname of the new user.
emails: Emails to bind to the new user.
admin: True if the user should be registered as a server admin.
Raises:
SynapseError if there is an error performing the registration. Check the
@ -631,6 +636,7 @@ class ModuleApi:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails or [],
admin=admin,
)
)
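
A hedged usage sketch from module code (assumes api is the module's ModuleApi instance):

user_id = await api.register_user(
    localpart="alice",
    displayname="Alice",
    admin=True,  # new in Synapse v1.56.0
)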
@ -665,7 +671,8 @@ class ModuleApi:
def record_user_external_id(
self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
) -> defer.Deferred:
"""Record a mapping from an external user id to a mxid
"""Record a mapping between an external user id from a single sign-on provider
and an mxid.
Added in Synapse v1.9.0.
@ -1280,6 +1287,30 @@ class ModuleApi:
"""
await self._registration_handler.check_username(username)
async def store_remote_3pid_association(
self, user_id: str, medium: str, address: str, id_server: str
) -> None:
"""Stores an existing association between a user ID and a third-party identifier.
The association must already exist on the remote identity server.
Added in Synapse v1.56.0.
Args:
user_id: The user ID that's been associated with the 3PID.
medium: The medium of the 3PID (currently supported values are "msisdn" and
"email").
address: The address of the 3PID.
id_server: The identity server the 3PID association has been registered on.
This should only be the domain (or IP address, optionally with the port
number) for the identity server. This will be used to reach out to the
identity server using HTTPS (unless specified otherwise by Synapse's
configuration) when attempting to unbind the third-party identifier.
"""
await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
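
A hedged usage sketch (assumes api is a ModuleApi instance and, per the docstring, that the association already exists on the identity server):

await api.store_remote_3pid_association(
    user_id="@alice:example.org",
    medium="email",
    address="alice@example.org",
    id_server="matrix.org",
)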
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room

View file

@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
@ -213,7 +214,7 @@ class BulkPushRuleEvaluator:
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
ignorers = set()
ignorers = frozenset()
for uid, rules in rules_by_user.items():
if event.sender == uid:
@ -292,7 +293,7 @@ def _condition_checker(
return True
MemberMap = Dict[str, Tuple[str, str]]
MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores.
"""
# event_id -> (user_id, state)
# event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
@ -447,11 +448,10 @@ class RulesForRoom:
res = self.data.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
rules = self.data.rules_by_user.get(user_id, None)
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
ret_rules_by_user[res.user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
@ -502,24 +502,26 @@ class RulesForRoom:
"""
sequence = self.data.sequence
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
# If the event is a join event then it will be in current state evnts
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}
logger.debug("Joined: %r", joined_user_ids)

View file

@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
from markupsafe import Markup
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import StoreError
@ -867,7 +868,7 @@ class Mailer:
)
def safe_markup(raw_html: str) -> jinja2.Markup:
def safe_markup(raw_html: str) -> Markup:
"""
Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
@ -877,7 +878,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
Returns:
A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
return Markup(
bleach.linkify(
bleach.clean(
raw_html,
@ -891,7 +892,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
)
def safe_text(raw_text: str) -> jinja2.Markup:
def safe_text(raw_text: str) -> Markup:
"""
Sanitise text (escape any HTML tags), and then linkify any bare URLs.
@ -901,7 +902,7 @@ def safe_text(raw_text: str) -> jinja2.Markup:
Returns:
A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
return Markup(
bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
)

View file

@ -74,7 +74,10 @@ REQUIREMENTS = [
# Note: 21.1.0 broke `/sync`, see #9936
"attrs>=19.2.0,!=21.1.0",
"netaddr>=0.7.18",
"Jinja2>=2.9",
# Jinja 2.x is incompatible with MarkupSafe>=2.1. To ensure that admins do not
# end up with a broken installation, with recent MarkupSafe but old Jinja, we
# add a lower bound to the Jinja2 dependency.
"Jinja2>=3.0",
"bleach>=1.4.3",
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
"typing-extensions>=3.10.0",

View file

@ -32,6 +32,7 @@ from synapse.rest.client import (
knock,
login as v1_login,
logout,
mutual_rooms,
notifications,
openid,
password_policy,
@ -49,7 +50,6 @@ from synapse.rest.client import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
shared_rooms,
sync,
tags,
thirdparty,
@ -132,4 +132,4 @@ class ClientRestResource(JsonResource):
admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
shared_rooms.register_servlets(hs, client_resource)
mutual_rooms.register_servlets(hs, client_resource)

View file

@ -28,13 +28,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class UserSharedRoomsServlet(RestServlet):
class UserMutualRoomsServlet(RestServlet):
"""
GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
"""
PATTERNS = client_patterns(
"/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
"/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
releases=(), # This is an unstable feature
)
@ -42,17 +42,19 @@ class UserSharedRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.user_directory_active = hs.config.server.update_user_directory
self.user_directory_search_enabled = (
hs.config.userdirectory.user_directory_search_enabled
)
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not self.user_directory_active:
if not self.user_directory_search_enabled:
raise SynapseError(
code=400,
msg="The user directory is disabled on this server. Cannot determine shared rooms.",
errcode=Codes.FORBIDDEN,
msg="User directory searching is disabled. Cannot determine shared rooms.",
errcode=Codes.UNKNOWN,
)
UserID.from_string(user_id)
@ -64,7 +66,8 @@ class UserSharedRoomsServlet(RestServlet):
msg="You cannot request a list of shared rooms with yourself",
errcode=Codes.FORBIDDEN,
)
rooms = await self.store.get_shared_rooms_for_users(
rooms = await self.store.get_mutual_rooms_for_users(
requester.user.to_string(), user_id
)
@ -72,4 +75,4 @@ class UserSharedRoomsServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server)
UserMutualRoomsServlet(hs).register(http_server)
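For reference, the renamed unstable endpoint can be exercised as follows; the homeserver URL, access token, and user ID below are placeholders, and the response shape follows MSC2666:
import requests
resp = requests.get(
    "https://homeserver.example/_matrix/client/unstable"
    "/uk.half-shot.msc2666/user/mutual_rooms/@bob:example.com",
    headers={"Authorization": "Bearer <access_token>"},
)
print(resp.json())  # e.g. {"joined": ["!abc123:example.com"]}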

View file

@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string(), allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5)
direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
pagination_chunk = await self.store.get_relations_for_event(
result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return 200, return_value
return 200, result
class RelationAggregationPaginationServlet(RestServlet):
@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id,
requester.user.to_string(),
allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
result = await self.store.get_relations_for_event(
result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)
return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events
return 200, return_value
return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View file

@ -649,6 +649,7 @@ class RoomEventServlet(RestServlet):
self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth()
async def on_GET(
@ -667,7 +668,7 @@ class RoomEventServlet(RestServlet):
if event:
# Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string()
)

View file

@ -124,14 +124,14 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
# For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list(
# find the most recent state events that allowed that message to be
# sent. We will use that as a base to auth our historical messages
# against.
state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list(
prev_event_ids_from_query
)
if not auth_event_ids:
if not state_event_ids:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
@ -148,13 +148,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
await self.room_batch_handler.persist_state_events_at_start(
state_events_at_start=body["state_events_at_start"],
room_id=room_id,
initial_auth_event_ids=auth_event_ids,
initial_state_event_ids=state_event_ids,
app_service_requester=requester,
)
)
# Update our ongoing auth event ID list with all of the new state we
# just created
auth_event_ids.extend(state_event_ids_at_start)
state_event_ids.extend(state_event_ids_at_start)
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
@ -196,7 +196,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
),
base_insertion_event_dict,
prev_event_ids=base_insertion_event_dict.get("prev_events"),
auth_event_ids=auth_event_ids,
# Also set the explicit state here because we want to resolve
# any `state_events_at_start` here too. It's not strictly
# necessary to accomplish anything but if someone asks for the
# state at this point, we probably want to show them the
# historical state that was part of this batch.
state_event_ids=state_event_ids,
historical=True,
depth=inherited_depth,
)
@ -212,7 +217,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
room_id=room_id,
batch_id_to_connect_to=batch_id_to_connect_to,
inherited_depth=inherited_depth,
auth_event_ids=auth_event_ids,
initial_state_event_ids=state_event_ids,
app_service_requester=requester,
)

View file

@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.types import JsonMapping
from ._base import client_patterns
@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
"""Searches for users in directory
Returns:

View file

@ -16,7 +16,6 @@ import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
from urllib import parse as urlparse
if TYPE_CHECKING:
from lxml import etree
@ -144,9 +143,7 @@ def decode_body(
return etree.fromstring(body, parser)
def parse_html_to_open_graph(
tree: "etree.Element", media_uri: str
) -> Dict[str, Optional[str]]:
def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
@ -155,7 +152,6 @@ def parse_html_to_open_graph(
Args:
tree: The parsed HTML document.
media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
@ -209,7 +205,7 @@ def parse_html_to_open_graph(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og["og:image"] = rebase_url(meta_image[0], media_uri)
og["og:image"] = meta_image[0]
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@ -320,37 +316,6 @@ def _iterate_over_text(
)
def rebase_url(url: str, base: str) -> str:
"""
Resolves a potentially relative `url` against an absolute `base` URL.
For example:
>>> rebase_url("subpage", "https://example.com/foo/")
'https://example.com/foo/subpage'
>>> rebase_url("sibling", "https://example.com/foo")
'https://example.com/sibling'
>>> rebase_url("/bar", "https://example.com/foo/")
'https://example.com/bar'
>>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
'https://alice.com/a/'
"""
base_parts = urlparse.urlparse(base)
# Convert the parsed URL to a list for (potential) modification.
url_parts = list(urlparse.urlparse(url))
# Add a scheme, if one does not exist.
if not url_parts[0]:
url_parts[0] = base_parts.scheme or "http"
# Fix up the hostname, if this is not a data URL.
if url_parts[0] != "data" and not url_parts[1]:
url_parts[1] = base_parts.netloc
# If the path does not start with a /, nest it under the base path's last
# directory.
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
return urlparse.urlunparse(url_parts)
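The bespoke resolver above is replaced by the standard library's RFC 3986 resolution, which covers the same documented cases. A quick sanity check:
from urllib.parse import urljoin
assert urljoin("https://example.com/foo/", "subpage") == "https://example.com/foo/subpage"
assert urljoin("https://example.com/foo", "sibling") == "https://example.com/sibling"
assert urljoin("https://example.com/foo/", "/bar") == "https://example.com/bar"
# data: URLs are still special-cased explicitly by the calling code before
# any rebasing is attempted.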
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:

View file

@ -22,7 +22,7 @@ import shutil
import sys
import traceback
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
from urllib import parse as urlparse
from urllib.parse import urljoin, urlparse, urlsplit
from urllib.request import urlopen
import attr
@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.rest.media.v1.preview_html import (
decode_body,
parse_html_to_open_graph,
rebase_url,
)
from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url)
url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete.
og_from_html = parse_html_to_open_graph(tree, media_info.uri)
og_from_html = parse_html_to_open_graph(tree)
# Compile the Open Graph response by using the scraped
# information from the HTML and overlaying any information
@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
if "og:image" not in og or not og["og:image"]:
return
# The image URL from the HTML might be relative to the previewed page,
# convert it to an URL which can be requested directly.
image_url = og["og:image"]
url_parts = urlparse(image_url)
if url_parts.scheme != "data":
image_url = urljoin(media_info.uri, image_url)
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
image_info = await self._handle_url(
rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
)
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images

View file

@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import (
RoomContextHandler,
RoomCreationHandler,
@ -719,6 +720,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_pagination_handler(self) -> PaginationHandler:
return PaginationHandler(self)
@cache_in_self
def get_relations_handler(self) -> RelationsHandler:
return RelationsHandler(self)
@cache_in_self
def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)

View file

@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@ -286,13 +288,17 @@ class LoggingTransaction:
"""
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
self._do_execute(
lambda the_sql: execute_batch(self.txn, the_sql, args), sql
)
else:
self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
def execute_values(
self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
@ -300,10 +306,11 @@ class LoggingTransaction:
rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values # type: ignore
from psycopg2.extras import execute_values
return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
sql,
)
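As a usage reference for the tightened signature, a hedged sketch (the table and rows are invented for illustration; `fetch=False` is required for statements that return no rows):
# Given a postgres-backed LoggingTransaction `txn`:
txn.execute_values(
    "INSERT INTO example_table (event_id, user_id) VALUES %s",
    [("$ev1", "@alice:test"), ("$ev2", "@bob:test")],
    fetch=False,
)
# psycopg2 expands the single %s placeholder into one multi-row INSERT.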
def execute(self, sql: str, *args: Any) -> None:
@ -732,34 +739,45 @@ class DatabasePool:
Returns:
The result of func
"""
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
return cast(R, result)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
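A conceptual sketch of the cancellation behaviour relied on here, expressed in plain asyncio terms (`delay_cancellation` itself is a Synapse helper; this is an analogy, not its implementation):
import asyncio
async def run_shielded(coro):
    task = asyncio.ensure_future(coro)
    try:
        return await asyncio.shield(task)
    except asyncio.CancelledError:
        # The work keeps running regardless of the caller's cancellation;
        # wait for it (and its callbacks) to finish, then re-raise.
        await task
        raise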
async def runWithConnection(
self,

View file

@ -14,7 +14,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
cast,
)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
async def ignored_by(self, user_id: str) -> Set[str]:
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return:
The user IDs which ignore the given user.
"""
return set(
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
@cached(max_entries=5000, iterable=True)
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
"""
Get users which the given user ignores.
Params:
user_id: The user ID which is making the request.
Return:
The user IDs which are ignored by the given user.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignorer_user_id": user_id},
retcol="ignored_user_id",
desc="ignored_users",
)
)
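Both methods now return `frozenset` rather than `set`: `@cached` hands the same object to every caller, so an immutable type guarantees that no caller can mutate the shared cached value. Illustration:
ignorers = frozenset({"@spammer:example.com"})
# ignorers.add("@other:example.com")  # AttributeError: 'frozenset' object
#                                     # has no attribute 'add'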
def process_replication_rows(
self,
stream_name: str,
@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
# If the data has not changed, nothing to do.
if previously_ignored_users == currently_ignored_users:
return
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
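The `^` above is set symmetric difference, i.e. exactly the users whose ignore status changed in either direction, so only their `ignored_by` entries are invalidated. For example:
previously = {"@a:example.com", "@b:example.com"}
currently = {"@b:example.com", "@c:example.com"}
assert previously ^ currently == {"@a:example.com", "@c:example.com"}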
async def purge_account_data_for_user(self, user_id: str) -> None:
"""

View file

@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_updated_caches_txn(txn):
def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row):
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
assert isinstance(data, EventsStreamCurrentStateRow)
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
stream_ordering: int,
event_id: str,
room_id: str,
etype: str,
state_key: Optional[str],
redacts: Optional[str],
relates_to: Optional[str],
backfilled: bool,
) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._get_membership_from_event_id.invalidate((event_id,))
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(self, txn, cache_func):
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
):
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
"invalidation_ts": self._clock.time_msec(),
},
)

View file

@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
/* Get the depth and stream_ordering of the prev_event_id from the events table */
INNER JOIN events
ON prev_event_id = events.event_id
/* exclude outliers from the results (we don't have the state, so cannot
* verify if the requesting server can see them).
*/
WHERE NOT events.outlier
/* Look for an edge which matches the given event_id */
WHERE event_edges.event_id = ?
AND event_edges.is_state = ?
AND event_edges.event_id = ? AND NOT event_edges.is_state
/* Because we can have many events at the same depth,
* we want to also tie-break and sort on stream_ordering */
ORDER BY depth DESC, stream_ordering DESC
@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(
connected_prev_event_query,
(event_id, False, limit),
(event_id, limit),
)
return [
BackfillQueueNavigationItem(

View file

@ -1745,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,),
)
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
txn.call_after(
self.store._get_membership_from_event_id.invalidate,
(event.event_id,),
)
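`txn.call_after` defers the invalidation until the transaction commits, via the `after_callbacks` machinery shown in `database.py` above, so a rolled-back persist never leaves a poisoned cache. Schematically:
# Schematic only, reusing the mechanism shown earlier in this changeset:
txn.call_after(cache.invalidate, (key,))  # queued on after_callbacks
# ...rest of the transaction body...
# commit succeeds -> after_callbacks fire -> cache.invalidate((key,))
# commit raises   -> exception_callbacks fire instead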
# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#

View file

@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]:
# TODO: Pagination
keyvalues = {"group_id": group_id}
keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination
def _get_rooms_in_group_txn(txn):
def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category
"""
def _get_rooms_for_summary_txn(txn):
keyvalues = {"group_id": group_id}
def _get_rooms_for_summary_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
async def get_group_categories(self, group_id):
async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
async def get_group_category(self, group_id, category_id):
async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
async def get_group_roles(self, group_id):
async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
async def get_group_role(self, group_id, role_id):
async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
async def get_users_for_summary_by_role(self, group_id, include_private=False):
async def get_users_for_summary_by_role(
self, group_id: str, include_private: bool = False
) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request
Returns:
([users], [roles])
"""
def _get_users_for_summary_txn(txn):
keyvalues = {"group_id": group_id}
def _get_users_for_summary_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], JsonDict]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
async def get_users_membership_info_in_group(self, group_id, user_id):
async def get_users_membership_info_in_group(
self, group_id: str, user_id: str
) -> JsonDict:
"""Get a dict describing the membership of a user in a group.
Example if joined:
@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc
"""
def _get_users_membership_in_group_txn(txn):
def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
async def get_attestations_need_renewals(self, valid_until_ms):
async def get_attestations_need_renewals(
self, valid_until_ms: int
) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn):
def _get_attestations_need_renewals_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ?
@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
async def get_remote_attestation(self, group_id, user_id):
async def get_remote_attestation(
self, group_id: str, user_id: str
) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
async def get_groups_changes_for_user(
self, user_id: str, from_token: int, to_token: int
) -> List[JsonDict]:
has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token
)
if not has_changed:
return []
def _get_groups_changes_for_user_txn(txn):
def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed:
return [], current_id, False
def _get_all_groups_changes_txn(txn):
def _get_all_groups_changes_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
updates = cast(
List[Tuple[int, tuple]],
[
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
],
)
limited = False
upto_token = current_id
@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
room_id: str,
category_id: str,
order: int,
category_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn(
self,
txn,
txn: LoggingTransaction,
group_id: str,
room_id: str,
category_id: str,
order: int,
category_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
(order,) = txn.fetchone()
(order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id,
"room_id": room_id,
},
values=to_update,
updatevalues=to_update,
)
else:
if is_public is None:
@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_summary(
self, group_id: str, room_id: str, category_id: str
self, group_id: str, room_id: str, category_id: Optional[str]
) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/update room category for group"""
insertion_values = {}
update_values = {"category_id": category_id} # This cannot be empty
insertion_values: JsonDict = {}
update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/remove user role"""
insertion_values = {}
update_values = {"role_id": role_id} # This cannot be empty
insertion_values: JsonDict = {}
update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
user_id: str,
role_id: str,
order: int,
role_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn(
self,
txn,
txn: LoggingTransaction,
group_id: str,
user_id: str,
role_id: str,
order: int,
role_id: Optional[str],
order: Optional[int],
is_public: Optional[bool],
):
) -> None:
"""Add (or update) user's entry in summary.
Args:
@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
(order,) = txn.fetchone()
(order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id,
"user_id": user_id,
},
values=to_update,
updatevalues=to_update,
)
else:
if is_public is None:
@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_user_from_summary(
self, group_id: str, user_id: str, role_id: str
self, group_id: str, user_id: str, role_id: Optional[str]
) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server
"""
def _add_user_to_group_txn(txn):
def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
table="group_users",
@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn):
def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_users",
@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn):
def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {}
def _register_user_group_membership_txn(txn, next_id):
def _register_user_group_membership_txn(
txn: LoggingTransaction, next_id: int
) -> int:
# TODO: Upsert?
self.db_pool.simple_delete_txn(
txn,
@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
),
},
)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join.
@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
async with self._group_updates_id_gen.get_next() as next_id:
async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res
async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
self,
group_id: str,
user_id: str,
name: str,
avatar_url: str,
short_description: str,
long_description: str,
) -> None:
await self.db_pool.simple_insert(
table="groups",
@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
async def update_group_profile(self, group_id, profile):
async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal",
)
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete.
"""
def _delete_group_txn(txn):
def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [
"groups",
"group_users",

View file

@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media

View file

@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users
"""
def _count_users(txn):
def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users
sql = """
SELECT COUNT(*)
@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
"""
txn.execute(sql)
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
def _count_users_by_service(txn):
def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
txn.execute(sql)
result = txn.fetchall()
result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction(
@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
@wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self):
async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
def _reap_users(txn, reserved_users):
def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
reserved_users (tuple): reserved users to preserve
@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self._invalidate_all_cache_and_stream(
self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
txn:
threepids: List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
def upsert_monthly_active_user_txn(self, txn, user_id):
def upsert_monthly_active_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> None:
"""Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it
@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
async def populate_monthly_active_users(self, user_id):
async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally
adds the user to the monthly active tables
@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id)
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest:
return
is_trial = await self.is_trial_user(user_id)

View file

@ -24,10 +24,9 @@ from typing import (
Optional,
Set,
Tuple,
cast,
)
from twisted.internet import defer
from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
@ -38,7 +37,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
return txn.fetchall()
return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not rows:
return []
content = {}
content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f
)
results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_linearized_receipts_for_all_rooms", f
)
results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
if last_id == current_id:
return defer.succeed([])
return []
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
)
limited = False
upper_bound = current_id
@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
async with self._receipts_id_gen.get_next() as stream_id:
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,

View file

@ -22,6 +22,7 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
DatabasePool,
@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
self.config = hs.config
self.config: HomeServerConfig = hs.config
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is

View file

@ -27,7 +27,6 @@ from typing import (
)
import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
async def _get_applicable_edits(
async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def _get_threads_participated(
async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
-    async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
-    ) -> Optional[BundledAggregations]:
-        """Generate bundled aggregations for an event.
-
-        Note that this does not use a cache, but depends on cached methods.
-
-        Args:
-            event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            The bundled aggregations for an event, if bundled aggregations are
-            enabled and the event can have bundled aggregations.
-        """
-        # Do not bundle aggregations for an event which represents an edit or an
-        # annotation. It does not make sense for them to have related events.
-        relates_to = event.content.get("m.relates_to")
-        if isinstance(relates_to, (dict, frozendict)):
-            relation_type = relates_to.get("rel_type")
-            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                return None
-
-        event_id = event.event_id
-        room_id = event.room_id
-
-        # The bundled aggregations to include, a mapping of relation type to a
-        # type-specific value. Some types include the direct return type here
-        # while others need more processing during serialization.
-        aggregations = BundledAggregations()
-
-        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
-
-        references = await self.get_relations_for_event(
-            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
-        )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
-
-        # Store the bundled aggregations in the event metadata for later use.
-        return aggregations
-
-    async def get_bundled_aggregations(
-        self, events: Iterable[EventBase], user_id: str
-    ) -> Dict[str, BundledAggregations]:
-        """Generate bundled aggregations for events.
-
-        Args:
-            events: The iterable of events to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
-        """
-        # De-duplicate events by ID to handle the same event requested multiple times.
-        #
-        # State events do not get bundled aggregations.
-        events_by_id = {
-            event.event_id: event for event in events if not event.is_state()
-        }
-
-        # event ID -> bundled aggregation in non-serialized form.
-        results: Dict[str, BundledAggregations] = {}
-
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result:
-                results[event.event_id] = event_result
-
-        # Fetch any edits (but not for redacted events).
-        edits = await self._get_applicable_edits(
-            [
-                event_id
-                for event_id, event in events_by_id.items()
-                if not event.internal_metadata.is_redacted()
-            ]
-        )
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
-
-        # Fetch thread summaries.
-        summaries = await self._get_thread_summaries(events_by_id.keys())
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
-        participated = await self._get_threads_participated(summaries.keys(), user_id)
-        for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event, edit = summary
-                results.setdefault(
-                    event_id, BundledAggregations()
-                ).thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    latest_edit=edit,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=participated[event_id],
-                )
-
-        return results
class RelationsStore(RelationsWorkerStore):
pass
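The bundled-aggregation logic deleted above moved out of the storage layer, which is why the two thread helpers lose their leading underscores. A rough sketch of how a caller can combine the now-public methods, assuming `store` is a handle to the main datastore and the wrapper name is invented for illustration:

    async def collect_thread_summaries(store, event_ids, user_id):
        # {event_id: Optional[(reply_count, latest_reply, latest_reply_edit)]}
        summaries = await store.get_thread_summaries(event_ids)
        # Participation is only needed for events that actually have a summary.
        participated = await store.get_threads_participated(summaries.keys(), user_id)
        return {
            event_id: (summary, participated[event_id])
            for event_id, summary in summaries.items()
            if summary is not None
        }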

View file

@@ -34,6 +34,7 @@ import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
async def store_room(
self,

View file

@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+    """Returned by `get_membership_from_event_ids`"""
+
+    user_id: str
+    membership: str
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
desc="_get_joined_profiles_from_event_ids",
)
return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
+    @cached(max_entries=5000)
+    async def _get_membership_from_event_id(
+        self, member_event_id: str
+    ) -> Optional[EventIdMembership]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+    )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs."""
-
-        return await self.db_pool.simple_select_many_batch(
+    ) -> Dict[str, Optional[EventIdMembership]]:
+        """Get user_id and membership of a set of event IDs.
+
+        Returns:
+            Mapping from event ID to `EventIdMembership` if the event is a
+            membership event, otherwise the value is None.
+        """
+        rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
+        return {
+            row["event_id"]: EventIdMembership(
+                membership=row["membership"], user_id=row["user_id"]
+            )
+            for row in rows
+        }
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
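Callers of `get_membership_from_event_ids` switch from `row["membership"]` lookups to attribute access on `EventIdMembership`, and must now handle `None` for events that were not membership events. A minimal consuming sketch, with `store` and `event_ids` as placeholders:

    from synapse.api.constants import Membership

    async def any_joins(store, event_ids) -> bool:
        # Mapping of event ID -> Optional[EventIdMembership].
        members = await store.get_membership_from_event_ids(event_ids)
        return any(
            m is not None and m.membership == Membership.JOIN
            for m in members.values()
        )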

View file

@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
import attr
@@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
-            args = (
+            args1 = (
(
entry.event_id,
entry.room_id,
@@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
-            args = (
+            args2 = (
(
entry.event_id,
entry.room_id,
@@ -102,7 +102,7 @@
)
for entry in entries
)
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args2)
else:
# This should be unreachable.
@@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term)
-        args = []
+        args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
-        limit,
+        limit: int,
pagination_token: Optional[str] = None,
) -> JsonDict:
"""Performs a full text search over events with given keys.
@@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term)
-        args = []
+        args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
if pagination_token:
try:
origin_server_ts, stream = pagination_token.split(",")
origin_server_ts = int(origin_server_ts)
stream = int(stream)
origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts_str)
stream = int(stream_str)
except Exception:
raise SynapseError(400, "Invalid pagination token")
@@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
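The pagination token is two comma-separated integers; splitting into separately named strings lets mypy track the `int` conversions. A standalone sketch of the parse (the function name is illustrative, not part of the diff):

    from typing import Tuple

    def parse_pagination_token(token: str) -> Tuple[int, int]:
        # e.g. "1648000000000,42" -> (origin_server_ts, stream); a malformed
        # token raises ValueError, which the store turns into a 400 SynapseError.
        origin_server_ts_str, stream_str = token.split(",")
        return int(origin_server_ts_str), int(stream_str)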

View file

@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We delegate to the cached version
return await self.get_current_state_ids(room_id)
-        def _get_filtered_current_state_ids_txn(txn):
+        def _get_filtered_current_state_ids_txn(
+            txn: LoggingTransaction,
+        ) -> StateMap[str]:
results = {}
sql = """
SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
-            return
+            return None
event = await self.get_event(event_id, allow_none=True)
if not event:
-            return
+            return None
return event.content.get("canonical_alias")
@@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
-    async def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
@@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
):
super().__init__(database, db_conn, hs)
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self._background_remove_left_rooms,
)
-    async def _background_remove_left_rooms(self, progress, batch_size):
+    async def _background_remove_left_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
"""Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no
longer joined to.
@@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
-        def _background_remove_left_rooms_txn(txn):
+        def _background_remove_left_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
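These hunks apply one convention repeatedly: inner transaction callbacks are annotated with `LoggingTransaction`, and background-update handlers take a JSON `progress` checkpoint plus a `batch_size` and return the number of items processed. A heavily simplified shape sketch, not Synapse's actual handler:

    from typing import Any, Dict

    async def example_background_update(progress: Dict[str, Any], batch_size: int) -> int:
        last_room_id = progress.get("last_room_id", "")
        # ... process up to batch_size rooms with room_id > last_room_id,
        # then persist {"last_room_id": <new checkpoint>} for the next run ...
        processed = 0
        return processed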

View file

@@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
):
super().__init__(database, db_conn, hs)
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled

View file

@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from typing_extensions import TypedDict
from synapse.api.errors import StoreError
if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.types import (
JsonDict,
UserProfile,
get_domain_from_id,
get_localpart_from_id,
)
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) -> None:
super().__init__(database, db_conn, hs)
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
+class SearchResult(TypedDict):
+    limited: bool
+    results: List[UserProfile]
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
-    async def get_shared_rooms_for_users(
+    async def get_mutual_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
"""
@@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
-        def _get_shared_rooms_for_users_txn(
+        def _get_mutual_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
@@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return rows
rows = await self.db_pool.runInteraction(
"get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
"get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
)
return {row["room_id"] for row in rows}
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
"""Searches for users in directory
Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
-        results = await self.db_pool.execute(
-            "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
-        )
+        results = cast(
+            List[UserProfile],
+            await self.db_pool.execute(
+                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+            ),
+        )
limited = len(results) > limit
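With `search_user_dir` returning the `SearchResult` TypedDict, consumers get typed row access. A hedged usage sketch (`store` is a placeholder for the user-directory store):

    async def print_matches(store, requester: str, term: str) -> None:
        result = await store.search_user_dir(requester, term, 10)
        for profile in result["results"]:
            # UserProfile keys: user_id, display_name, avatar_url.
            print(profile["user_id"], profile["display_name"])
        if result["limited"]:
            print("(more matches existed than the limit allowed)")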

View file

@@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
-        import psycopg2  # type: ignore
+        import psycopg2
return PostgresEngine(psycopg2, database_config)

View file

@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+        self.config = database_config
@property
def single_threaded(self) -> bool:
return False
+    def get_db_locale(self, txn):
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+        return collation, ctype
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+        allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % (rows[0][0],)
)
-        txn.execute(
-            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-        )
-        collation, ctype = txn.fetchone()
+        collation, ctype = self.get_db_locale(txn)
if collation != "C":
logger.warning(
"Database has incorrect collation of %r. Should be 'C'\n"
"See docs/postgres.md for more information.",
"Database has incorrect collation of %r. Should be 'C'",
collation,
)
if not allow_unsafe_locale:
raise IncorrectDatabaseSetup(
"Database has incorrect collation of %r. Should be 'C'\n"
"See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
collation,
)
if ctype != "C":
logger.warning(
"Database has incorrect ctype of %r. Should be 'C'\n"
"See docs/postgres.md for more information.",
ctype,
)
if not allow_unsafe_locale:
logger.warning(
"Database has incorrect ctype of %r. Should be 'C'",
ctype,
)
raise IncorrectDatabaseSetup(
"Database has incorrect ctype of %r. Should be 'C'\n"
"See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
ctype,
)
def check_new_database(self, txn):
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
-        txn.execute(
-            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-        )
-        collation, ctype = txn.fetchone()
+        collation, ctype = self.get_db_locale(txn)
errors = []
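Net effect of this file's changes: the locale query is shared via `get_db_locale`, and a non-'C' collation or ctype is now fatal unless the database config opts out. A condensed sketch of just that decision, using RuntimeError as a stand-in for Synapse's IncorrectDatabaseSetup:

    def enforce_locale(collation: str, ctype: str, database_config: dict) -> None:
        allow_unsafe_locale = database_config.get("allow_unsafe_locale", False)
        if (collation != "C" or ctype != "C") and not allow_unsafe_locale:
            # Overriding with allow_unsafe_locale is possible but risks data
            # corruption; see docs/postgres.md.
            raise RuntimeError("Database has incorrect locale; should be 'C'.")

    enforce_locale("C", "C", {})                                       # passes
    enforce_locale("en_US.UTF-8", "C", {"allow_unsafe_locale": True})  # passes, unsafely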

View file

@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins.
if events_to_check:
-            rows = await self.main_store.get_membership_from_event_ids(events_to_check)
-            is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+            members = await self.main_store.get_membership_from_event_ids(
+                events_to_check
+            )
+            is_still_joined = any(
+                member and member.membership == Membership.JOIN
+                for member in members.values()
+            )
if is_still_joined:
return True
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
-        rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+        members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
         potentially_left_users.update(
-            row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+            member.user_id
+            for member in members.values()
+            if member and member.membership == Membership.JOIN
)
return False

View file

@@ -34,6 +34,7 @@ from typing import (
import attr
from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
+from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
@@ -791,3 +796,9 @@ class UserInfo:
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
+class UserProfile(TypedDict):
+    user_id: str
+    display_name: Optional[str]
+    avatar_url: Optional[str]
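Both additions are plain typing constructs; an illustrative (made-up) value shows the intent:

    from synapse.types import JsonMapping, UserProfile  # available after this change

    profile: UserProfile = {
        "user_id": "@alice:example.com",
        "display_name": "Alice",
        "avatar_url": None,
    }
    # JsonMapping expresses read-only intent without casting to JsonDict.
    frozen_view: JsonMapping = profile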

View file

@@ -128,6 +128,19 @@ def _incorrect_version(
)
+def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
+    if extra:
+        return (
+            f"Synapse {VERSION} needs {requirement} for {extra}, "
+            f"but can't determine {requirement.name}'s version"
+        )
+    else:
+        return (
+            f"Synapse {VERSION} needs {requirement}, "
+            f"but can't determine {requirement.name}'s version"
+        )
def check_requirements(extra: Optional[str] = None) -> None:
"""Check Synapse's dependencies are present and correctly versioned.
@@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
deps_unfulfilled.append(requirement.name)
errors.append(_not_installed(requirement, extra))
else:
+            if dist.version is None:
+                # This shouldn't happen---it suggests a borked virtualenv. (See #12223)
+                # Try to give a vaguely helpful error message anyway.
+                # Type-ignore: the annotations don't reflect reality: see
+                #   https://github.com/python/typeshed/issues/7513
+                #   https://bugs.python.org/issue47060
+                deps_unfulfilled.append(requirement.name)  # type: ignore[unreachable]
+                errors.append(_no_reported_version(requirement, extra))
+
             # We specify prereleases=True to allow prereleases such as RCs.
-            if not requirement.specifier.contains(dist.version, prereleases=True):
+            elif not requirement.specifier.contains(dist.version, prereleases=True):
deps_unfulfilled.append(requirement.name)
errors.append(_incorrect_version(requirement, dist.version, extra))
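The `elif` keeps the specifier check from running against a `None` version. The prerelease behaviour it preserves comes from the `packaging` library, shown standalone with an illustrative requirement string:

    from packaging.requirements import Requirement

    req = Requirement("frozendict>=1.0")
    # Prereleases are excluded by default...
    assert not req.specifier.contains("2.1.2rc1")
    # ...but with prereleases=True, release candidates satisfy the check.
    assert req.specifier.contains("2.1.2rc1", prereleases=True)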

View file

@@ -14,12 +14,7 @@
import logging
from typing import Dict, FrozenSet, List, Optional
-from synapse.api.constants import (
-    AccountDataTypes,
-    EventTypes,
-    HistoryVisibility,
-    Membership,
-)
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
state_filter=StateFilter.from_types(types),
)
-    ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
-        user_id, AccountDataTypes.IGNORED_USER_LIST
-    )
-
-    ignore_list: FrozenSet[str] = frozenset()
-    if ignore_dict_content:
-        ignored_users_dict = ignore_dict_content.get("ignored_users", {})
-        if isinstance(ignored_users_dict, dict):
-            ignore_list = frozenset(ignored_users_dict.keys())
+    # Get the users who are ignored by the requesting user.
+    ignore_list = await storage.main.ignored_users(user_id)
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
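The replaced block was inline parsing of the `m.ignored_user_list` account data; `ignored_users` on the store returns the frozen set directly. A hedged sketch of the simplified call (placeholder names throughout):

    async def drop_ignored(storage, user_id: str, events):
        # frozenset of MXIDs the requesting user has ignored.
        ignore_list = await storage.main.ignored_users(user_id)
        return [e for e in events if e.sender not in ignore_list]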