Mirror of https://git.anonymousland.org/anonymousland/synapse.git (synced 2025-09-21 07:24:34 -04:00)

Commit c75eed92c9: Merge remote-tracking branch 'upstream/release-v1.42'

157 changed files with 7799 additions and 3658 deletions
@ -47,7 +47,7 @@ try:
except ImportError:
pass

__version__ = "1.41.1"
__version__ = "1.42.0rc1"

if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
@ -79,6 +79,7 @@ class LoginType:
TERMS = "m.login.terms"
SSO = "m.login.sso"
DUMMY = "m.login.dummy"
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"

# This is used in the `type` parameter for /register when called by
@ -147,6 +147,14 @@ class SynapseError(CodeMessageException):
return cs_error(self.msg, self.errcode)

class InvalidAPICallError(SynapseError):
"""You called an existing API endpoint, but fed that endpoint
invalid or incomplete data."""

def __init__(self, msg: str):
super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)

class ProxiedRequestError(SynapseError):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
@ -37,6 +37,7 @@ from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.logging.context import PreserveLoggingContext
@ -370,6 +371,7 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)

# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)
@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
ProfileRestServlet,
)
from synapse.rest.client.push_rule import PushRuleRestServlet
from synapse.rest.client.register import RegisterRestServlet
from synapse.rest.client.register import (
RegisterRestServlet,
RegistrationTokenValidityRestServlet,
)
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.client.voip import VoipRestServlet
@ -115,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import (
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@ -250,6 +254,7 @@ class GenericWorkerSlavedStore(
SearchStore,
TransactionWorkerStore,
LockStore,
SessionStore,
BaseSlavedStore,
):
pass
@ -279,6 +284,7 @@ class GenericWorkerServer(HomeServer):
resource = JsonResource(self, canonical_json=False)

RegisterRestServlet(self).register(resource)
RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
@ -39,5 +39,8 @@ class ExperimentalConfig(Config):
# MSC3244 (room version capabilities)
self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)

# MSC3283 (set displayname, avatar_url and change 3pid capabilities)
self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)

# MSC3266 (room summary api)
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
@ -79,6 +79,11 @@ class RatelimitConfig(Config):
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))

self.rc_registration_token_validity = RateLimitConfig(
config.get("rc_registration_token_validity", {}),
defaults={"per_second": 0.1, "burst_count": 5},
)

rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@ -143,6 +148,8 @@ class RatelimitConfig(Config):
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
# - one for checking the validity of registration tokens that ratelimits
# requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@ -171,6 +178,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
#rc_registration_token_validity:
# per_second: 0.1
# burst_count: 5
#
#rc_login:
# address:
# per_second: 0.17
@ -33,6 +33,9 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_requires_token = config.get(
"registration_requires_token", False
)
self.registration_shared_secret = config.get("registration_shared_secret")

self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@ -140,6 +143,9 @@ class RegistrationConfig(Config):
"mechanism by removing the `access_token_lifetime` option."
)

# The fallback template used for authenticating using a registration token
self.registration_token_template = self.read_template("registration_token.html")

# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
@ -199,6 +205,15 @@ class RegistrationConfig(Config):
#
#enable_3pid_lookup: true

# Require users to submit a token during registration.
# Tokens can be managed using the admin API:
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
# Note that `enable_registration` must be set to `true`.
# Disabling this option will not delete any tokens previously generated.
# Defaults to false. Uncomment the following to require tokens:
#
#registration_requires_token: true

# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
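The `registration_requires_token` option above is managed through the registration token admin API linked in the sample config. A rough sketch of creating a token from a script follows; the endpoint path, the `uses_allowed` field, and the placeholder URL and access token are assumptions to verify against the admin API docs for your Synapse version.

# Hedged sketch: create a registration token via the admin API referenced in
# the sample config above. BASE_URL and ADMIN_TOKEN are placeholders.
import requests

BASE_URL = "https://homeserver.example"  # placeholder homeserver URL
ADMIN_TOKEN = "syt_admin_access_token"   # placeholder admin access token

resp = requests.post(
    f"{BASE_URL}/_synapse/admin/v1/registration_tokens/new",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={"uses_allowed": 5},  # optionally limit how many registrations may use it
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # expected to include the generated "token" plus its usage counters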
@ -248,6 +248,7 @@ class ServerConfig(Config):
self.use_presence = config.get("use_presence", True)

# Custom presence router module
# This is the legacy way of configuring it (the config should now be put in the modules section)
self.presence_router_module_class = None
self.presence_router_config = None
presence_router_config = presence_config.get("presence_router")
@ -870,20 +871,6 @@ class ServerConfig(Config):
#
#enabled: false

# Presence routers are third-party modules that can specify additional logic
# to where presence updates from users are routed.
#
presence_router:
# The custom module's class. Uncomment to use a custom presence router module.
#
#module: "my_custom_router.PresenceRouter"

# Configuration options of the custom module. Refer to your module's
# documentation for available options.
#
#config:
# example_option: 'something'

# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
@ -11,45 +11,115 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Union,
)

from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable

if TYPE_CHECKING:
from synapse.server import HomeServer

GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
GET_INTERESTED_USERS_CALLBACK = Callable[
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
]

logger = logging.getLogger(__name__)

def load_legacy_presence_router(hs: "HomeServer"):
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.presence_router_module_class is None:
return

module = hs.config.presence_router_module_class
config = hs.config.presence_router_config
api = hs.get_module_api()

presence_router = module(config=config, module_api=api)

# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
presence_router_methods = {
"get_users_for_states",
"get_interested_users",
}

# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None

def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))

return run

# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}

api.register_presence_router_callbacks(**hooks)

class PresenceRouter:
"""
A module that the homeserver will call upon to help route user presence updates to
additional destinations. If a custom presence router is configured, calls will be
passed to that instead.
additional destinations.
"""

ALL_USERS = "ALL"

def __init__(self, hs: "HomeServer"):
self.custom_presence_router = None
# Initially there are no callbacks
self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []

# Check whether a custom presence router module has been configured
if hs.config.presence_router_module_class:
# Initialise the module
self.custom_presence_router = hs.config.presence_router_module_class(
config=hs.config.presence_router_config, module_api=hs.get_module_api()
def register_presence_router_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
):
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
if paired_methods.count(None) == 1:
raise RuntimeError(
"PresenceRouter modules must register neither or both of the paired callbacks: "
"[get_users_for_states, get_interested_users]"
)

# Ensure the module has implemented the required methods
required_methods = ["get_users_for_states", "get_interested_users"]
for method_name in required_methods:
if not hasattr(self.custom_presence_router, method_name):
raise Exception(
"PresenceRouter module '%s' must implement all required methods: %s"
% (
hs.config.presence_router_module_class.__name__,
", ".join(required_methods),
)
)
# Append the methods provided to the lists of callbacks
if get_users_for_states is not None:
self._get_users_for_states_callbacks.append(get_users_for_states)

if get_interested_users is not None:
self._get_interested_users_callbacks.append(get_interested_users)

async def get_users_for_states(
self,
@ -66,14 +136,40 @@ class PresenceRouter:
A dictionary of user_id -> set of UserPresenceState, indicating which
presence updates each user should receive.
"""
if self.custom_presence_router is not None:
# Ask the custom module
return await self.custom_presence_router.get_users_for_states(
state_updates=state_updates
)

# Don't include any extra destinations for presence updates
return {}
# Bail out early if we don't have any callbacks to run.
if len(self._get_users_for_states_callbacks) == 0:
# Don't include any extra destinations for presence updates
return {}

users_for_states = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
result = await callback(state_updates)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue

if not isinstance(result, Dict):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected Dict",
callback,
result,
)
continue

for key, new_entries in result.items():
if not isinstance(new_entries, Set):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected Set",
callback,
new_entries,
)
break
users_for_states.setdefault(key, set()).update(new_entries)

return users_for_states

async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
"""
@ -92,12 +188,36 @@ class PresenceRouter:
A set of user IDs to return presence updates for, or ALL_USERS to return all
known updates.
"""
if self.custom_presence_router is not None:
# Ask the custom module for interested users
return await self.custom_presence_router.get_interested_users(
user_id=user_id
)

# A custom presence router is not defined.
# Don't report any additional interested users
return set()
# Bail out early if we don't have any callbacks to run.
if len(self._get_interested_users_callbacks) == 0:
# Don't report any additional interested users
return set()

interested_users = set()
# run all the callbacks for get_interested_users and combine the results
for callback in self._get_interested_users_callbacks:
try:
result = await callback(user_id)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue

# If one of the callbacks returns ALL_USERS then we can stop calling all
# of the other callbacks, since the set of interested_users is already as
# large as it can possibly be
if result == PresenceRouter.ALL_USERS:
return PresenceRouter.ALL_USERS

if not isinstance(result, Set):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected set",
callback,
result,
)
continue

# Add the new interested users to the set
interested_users.update(result)

return interested_users
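For context, the callbacks above are what a third-party module registers through the module API (`api.register_presence_router_callbacks(...)` in the loader shown earlier). A minimal sketch of such a module follows; the constructor signature and the example user IDs are assumptions for illustration.

# Hedged sketch of a presence router module using the callbacks defined above.
# The example user IDs are placeholders.
from typing import Dict, Iterable, Set, Union

from synapse.api.presence import UserPresenceState
from synapse.module_api import ModuleApi


class ExamplePresenceRouter:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Register both paired callbacks, as PresenceRouter requires.
        api.register_presence_router_callbacks(
            get_users_for_states=self.get_users_for_states,
            get_interested_users=self.get_interested_users,
        )

    async def get_users_for_states(
        self, state_updates: Iterable[UserPresenceState]
    ) -> Dict[str, Set[UserPresenceState]]:
        # Route every presence update to a fixed local user (placeholder).
        return {"@observer:example.org": set(state_updates)}

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        # The observer wants presence for everyone; others get nothing extra.
        if user_id == "@observer:example.org":
            return "ALL"  # the PresenceRouter.ALL_USERS sentinel
        return set()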
@ -32,6 +32,9 @@ from . import EventBase
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")

CANONICALJSON_MAX_INT = (2 ** 53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT

def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we

@ -505,7 +508,7 @@ def validate_canonicaljson(value: Any):
* NaN, Infinity, -Infinity
"""
if isinstance(value, int):
if value <= -(2 ** 53) or 2 ** 53 <= value:
if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value:
raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)

elif isinstance(value, float):
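As a standalone illustration of the bounds introduced above (not part of the diff), canonical JSON limits integers to the range that survives an IEEE-754 double:

# Standalone illustration of the CANONICALJSON_* bounds defined above.
CANONICALJSON_MAX_INT = (2 ** 53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT


def in_canonicaljson_range(value: int) -> bool:
    """Return True if the integer is within the IEEE-754 safe integer range."""
    return CANONICALJSON_MIN_INT <= value <= CANONICALJSON_MAX_INT


assert in_canonicaljson_range(2 ** 53 - 1)
assert not in_canonicaljson_range(2 ** 53)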
@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
from typing import Union

import jsonschema

from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.utils import validate_canonicaljson
from synapse.events.utils import (
CANONICALJSON_MAX_INT,
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID

@ -93,6 +99,29 @@ class EventValidator:
400, "Can't create an ACL event that denies the local server"
)

if event.type == EventTypes.PowerLevels:
try:
jsonschema.validate(
instance=event.content,
schema=POWER_LEVELS_SCHEMA,
cls=plValidator,
)
except jsonschema.ValidationError as e:
if e.path:
# example: "users_default": '0' is not of type 'integer'
message = '"' + e.path[-1] + '": ' + e.message # noqa: B306
# jsonschema.ValidationError.message is a valid attribute
else:
# example: '0' is not of type 'integer'
message = e.message # noqa: B306
# jsonschema.ValidationError.message is a valid attribute

raise SynapseError(
code=400,
msg=message,
errcode=Codes.BAD_JSON,
)

def _validate_retention(self, event: EventBase):
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
@ -195,3 +224,47 @@ class EventValidator:
def _ensure_state_event(self, event):
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))

POWER_LEVELS_SCHEMA = {
"type": "object",
"properties": {
"ban": {"$ref": "#/definitions/int"},
"events": {"$ref": "#/definitions/objectOfInts"},
"events_default": {"$ref": "#/definitions/int"},
"invite": {"$ref": "#/definitions/int"},
"kick": {"$ref": "#/definitions/int"},
"notifications": {"$ref": "#/definitions/objectOfInts"},
"redact": {"$ref": "#/definitions/int"},
"state_default": {"$ref": "#/definitions/int"},
"users": {"$ref": "#/definitions/objectOfInts"},
"users_default": {"$ref": "#/definitions/int"},
},
"definitions": {
"int": {
"type": "integer",
"minimum": CANONICALJSON_MIN_INT,
"maximum": CANONICALJSON_MAX_INT,
},
"objectOfInts": {
"type": "object",
"additionalProperties": {"$ref": "#/definitions/int"},
},
},
}

def _create_power_level_validator():
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)

# by default jsonschema does not consider a frozendict to be an object so
# we need to use a custom type checker
# https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types
type_checker = validator.TYPE_CHECKER.redefine(
"object", lambda checker, thing: isinstance(thing, collections.abc.Mapping)
)

return jsonschema.validators.extend(validator, type_checker=type_checker)

plValidator = _create_power_level_validator()
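A standalone sketch of how a schema in the spirit of POWER_LEVELS_SCHEMA rejects non-integer or out-of-range values; the trimmed schema and the event content below are made up for illustration.

# Hedged sketch: validating a power_levels-style dict with jsonschema.
import jsonschema

MAX_INT = (2 ** 53) - 1
SCHEMA = {
    "type": "object",
    "properties": {
        "ban": {"$ref": "#/definitions/int"},
        "users": {"$ref": "#/definitions/objectOfInts"},
    },
    "definitions": {
        "int": {"type": "integer", "minimum": -MAX_INT, "maximum": MAX_INT},
        "objectOfInts": {
            "type": "object",
            "additionalProperties": {"$ref": "#/definitions/int"},
        },
    },
}

try:
    jsonschema.validate(
        instance={"ban": 50, "users": {"@alice:example.org": "100"}},  # "100" is a string
        schema=SCHEMA,
    )
except jsonschema.ValidationError as e:
    print(e.message)  # reports that '100' is not of type 'integer'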
@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
UnsupportedRoomVersionError,
)

@ -110,6 +111,23 @@ class FederationClient(FederationBase):
reset_expiry_on_get=False,
)

# A cache for fetching the room hierarchy over federation.
#
# Some stale data over federation is OK, but must be refreshed
# periodically since the local server is in the room.
#
# It is a map of (room ID, suggested-only) -> the response of
# get_room_hierarchy.
self._get_room_hierarchy_cache: ExpiringCache[
Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
] = ExpiringCache(
cache_name="get_room_hierarchy_cache",
clock=self._clock,
max_len=1000,
expiry_ms=5 * 60 * 1000,
reset_expiry_on_get=False,
)

def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()

@ -558,7 +576,11 @@ class FederationClient(FederationBase):
try:
return await callback(destination)
except InvalidResponseError as e:
except (
RequestSendFailed,
InvalidResponseError,
NotRetryingDestination,
) as e:
logger.warning("Failed to %s via %s: %s", description, destination, e)
except UnsupportedRoomVersionError:
raise
@ -1319,6 +1341,10 @@ class FederationClient(FederationBase):
remote servers
"""

cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only))
if cached_result:
return cached_result

async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:

@ -1365,58 +1391,63 @@ class FederationClient(FederationBase):
return room, children, inaccessible_children

try:
return await self._try_destination_list(
result = await self._try_destination_list(
"fetch room hierarchy",
destinations,
send_request,
failover_on_unknown_endpoint=True,
)
except SynapseError as e:
# If an unexpected error occurred, re-raise it.
if e.code != 502:
raise

# Fallback to the old federation API and translate the results if
# no servers implement the new API.
#
# The algorithm below is a bit inefficient as it only attempts to
# get information for the requested room, but the legacy API may
# parse information for the requested room, but the legacy API may
# return additional layers.
if e.code == 502:
legacy_result = await self.get_space_summary(
destinations,
room_id,
suggested_only,
max_rooms_per_space=None,
exclude_rooms=[],
)

# Find the requested room in the response (and remove it).
for _i, room in enumerate(legacy_result.rooms):
if room.get("room_id") == room_id:
break
else:
# The requested room was not returned, nothing we can do.
raise
requested_room = legacy_result.rooms.pop(_i)

# Find any children events of the requested room.
children_events = []
children_room_ids = set()
for event in legacy_result.events:
if event.room_id == room_id:
children_events.append(event.data)
children_room_ids.add(event.state_key)
# And add them under the requested room.
requested_room["children_state"] = children_events

# Find the children rooms.
children = []
for room in legacy_result.rooms:
if room.get("room_id") in children_room_ids:
children.append(room)

# It isn't clear from the response whether some of the rooms are
# not accessible.
return requested_room, children, ()
result = (requested_room, children, ())

raise

# Cache the result to avoid fetching data over federation every time.
self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
return result

@attr.s(frozen=True, slots=True, auto_attribs=True)
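The new `_get_room_hierarchy_cache` above follows a plain check/fetch/store pattern keyed by (room ID, suggested-only). ExpiringCache is Synapse's own utility; the snippet below is a simplified, standalone stand-in showing the same idea, with placeholder values.

# Simplified stand-in for the expiring (room_id, suggested_only) cache above.
import time
from typing import Any, Dict, Hashable, Optional, Tuple


class TTLCache:
    def __init__(self, expiry_ms: int = 5 * 60 * 1000, max_len: int = 1000):
        self._expiry_s = expiry_ms / 1000
        self._max_len = max_len
        self._entries: Dict[Hashable, Tuple[float, Any]] = {}

    def get(self, key: Hashable) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        inserted_at, value = entry
        if time.monotonic() - inserted_at > self._expiry_s:
            del self._entries[key]
            return None
        return value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        if len(self._entries) >= self._max_len:
            # Evict the oldest entry; the real cache is smarter about eviction.
            oldest = min(self._entries, key=lambda k: self._entries[k][0])
            del self._entries[oldest]
        self._entries[key] = (time.monotonic(), value)


hierarchy_cache = TTLCache()
key = ("!room:example.org", True)  # (room ID, suggested-only), as in the diff
if hierarchy_cache.get(key) is None:
    hierarchy_cache[key] = {"room": {}, "children": [], "inaccessible": []}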
@ -110,6 +110,7 @@ class FederationServer(FederationBase):
super().__init__(hs)

self.handler = hs.get_federation_handler()
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()

@ -787,7 +788,9 @@ class FederationServer(FederationBase):
event = await self._check_sigs_and_hash(room_version, event)

return await self.handler.on_send_membership_event(origin, event)
return await self._federation_event_handler.on_send_membership_event(
origin, event
)

async def on_event_auth(
self, origin: str, room_id: str, event_id: str

@ -1005,9 +1008,7 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
await self.handler.on_receive_pdu(
origin, event, sent_to_us_directly=True
)
await self._federation_event_handler.on_receive_pdu(origin, event)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
@ -627,23 +627,28 @@ class AuthHandler(BaseHandler):
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
) -> bool:
) -> None:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.

Raises:
LoginError if the stagetype is unknown or the session is missing.
LoginError is raised by check_auth if authentication fails.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)

result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
raise LoginError(
400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM
)
return True
return False
if "session" not in authdict:
raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM)

# If authentication fails a LoginError is raised. Otherwise, store
# the successful result.
result = await self.checkers[stagetype].check_auth(authdict, clientip)
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)

def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""

@ -1459,6 +1464,10 @@ class AuthHandler(BaseHandler):
)

await self.store.user_delete_threepid(user_id, medium, address)
if medium == "email":
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id="m.email", pushkey=address, user_id=user_id
)
return result

async def hash(self, password: str) -> str:

@ -1727,7 +1736,6 @@ class AuthHandler(BaseHandler):
@attr.s(slots=True)
class MacaroonGenerator:

hs = attr.ib()

def generate_guest_access_token(self, user_id: str) -> str:
File diff suppressed because it is too large.

synapse/handlers/federation_event.py: new file, 1825 lines (diff suppressed because it is too large).
@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler):
limit = 10

async def handle_room(event: RoomsForUser):
d = {
d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC):
# otherwise would not do).
await self.set_state(UserID.from_string(user_id), state, force_notify=True)

async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
raise NotImplementedError(
"Attempting to check presence on a non-presence worker."
)

class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
@ -56,6 +56,22 @@ login_counter = Counter(
)

def init_counters_for_auth_provider(auth_provider_id: str) -> None:
"""Ensure the prometheus counters for the given auth provider are initialised

This fixes a problem where the counters are not reported for a given auth provider
until the user first logs in/registers.
"""
for is_guest in (True, False):
login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
for shadow_banned in (True, False):
registration_counter.labels(
guest=is_guest,
shadow_banned=shadow_banned,
auth_provider=auth_provider_id,
)

class LoginDict(TypedDict):
device_id: str
access_token: str

@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime

init_counters_for_auth_provider("")

async def check_username(
self,
localpart: str,
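The fix above works because calling `.labels(...)` on a prometheus_client Counter creates the labelled child series at zero, so it is exported before the first login ever happens. A standalone sketch with placeholder metric names:

# Standalone sketch of pre-initialising labelled Prometheus counters so they
# are exported as 0 before the first login; metric names are placeholders.
from prometheus_client import Counter, generate_latest

example_login_counter = Counter(
    "example_login_total", "Example login counter", ["guest", "auth_provider"]
)


def init_example_counters(auth_provider_id: str) -> None:
    # .labels(...) instantiates each child time series with a value of 0.
    for is_guest in (True, False):
        example_login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)


init_example_counters("oidc")
print(generate_latest().decode())  # the example_login_total series now appear at 0.0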
@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import (
JsonDict,
Requester,

@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()

self.member_linearizer = Linearizer(name="member")
self.member_linearizer: Linearizer = Linearizer(name="member")

self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()

@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content.pop("displayname", None)
content.pop("avatar_url", None)

if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400,
f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})",
errcode=Codes.BAD_JSON,
)

if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN:
raise SynapseError(
400,
f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})",
errcode=Codes.BAD_JSON,
)

effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@ -28,12 +28,11 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string

if TYPE_CHECKING:
from synapse.server import HomeServer

@ -76,6 +75,9 @@ class _PaginationSession:

class RoomSummaryHandler:
# A unique key used for pagination sessions for the room hierarchy endpoint.
_PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"

# The time a pagination session remains valid for.
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000

@ -87,12 +89,6 @@ class RoomSummaryHandler:
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()

# A map of query information to the current pagination state.
#
# TODO Allow for multiple workers to share this data.
# TODO Expire pagination tokens.
self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}

# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[

@ -102,21 +98,6 @@ class RoomSummaryHandler:
"get_room_hierarchy",
)

def _expire_pagination_sessions(self):
"""Expire pagination session which are old."""
expire_before = (
self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
)
to_expire = []

for key, value in self._pagination_sessions.items():
if value.creation_time_ms < expire_before:
to_expire.append(key)

for key in to_expire:
logger.debug("Expiring pagination session id %s", key)
del self._pagination_sessions[key]

async def get_space_summary(
self,
requester: str,

@ -327,18 +308,29 @@ class RoomSummaryHandler:
# If this is continuing a previous session, pull the persisted data.
if from_token:
self._expire_pagination_sessions()
try:
pagination_session = await self._store.get_session(
session_type=self._PAGINATION_SESSION_TYPE,
session_id=from_token,
)
except StoreError:
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)

pagination_key = _PaginationKey(
requested_room_id, suggested_only, max_depth, from_token
)
if pagination_key not in self._pagination_sessions:
# If the requester, room ID, suggested-only, or max depth were modified
# the session is invalid.
if (
requester != pagination_session["requester"]
or requested_room_id != pagination_session["room_id"]
or suggested_only != pagination_session["suggested_only"]
or max_depth != pagination_session["max_depth"]
):
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)

# Load the previous state.
pagination_session = self._pagination_sessions[pagination_key]
room_queue = pagination_session.room_queue
processed_rooms = pagination_session.processed_rooms
room_queue = [
_RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
]
processed_rooms = set(pagination_session["processed_rooms"])
else:
# The queue of rooms to process, the next room is last on the stack.
room_queue = [_RoomQueueEntry(requested_room_id, ())]

@ -456,13 +448,21 @@ class RoomSummaryHandler:
# If there's additional data, generate a pagination token (and persist state).
if room_queue:
next_batch = random_string(24)
result["next_batch"] = next_batch
pagination_key = _PaginationKey(
requested_room_id, suggested_only, max_depth, next_batch
)
self._pagination_sessions[pagination_key] = _PaginationSession(
self._clock.time_msec(), room_queue, processed_rooms
result["next_batch"] = await self._store.create_session(
session_type=self._PAGINATION_SESSION_TYPE,
value={
# Information which must be identical across pagination.
"requester": requester,
"room_id": requested_room_id,
"suggested_only": suggested_only,
"max_depth": max_depth,
# The stored state.
"room_queue": [
attr.astuple(room_entry) for room_entry in room_queue
],
"processed_rooms": list(processed_rooms),
},
expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
)

return result
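The change above moves room-hierarchy pagination state out of an in-process dict and into the session store, so a token handed to the client can be resumed by any worker and expires on its own. A rough standalone sketch of the same create/resume-and-validate flow, using a plain dict as a stand-in for the real session store:

# Rough standalone sketch of token-based pagination resume, mirroring the
# validation above; the dict is a stand-in for Synapse's session store.
import secrets
from typing import Any, Dict

_sessions: Dict[str, Dict[str, Any]] = {}


def create_session(value: Dict[str, Any]) -> str:
    token = secrets.token_urlsafe(18)
    _sessions[token] = value
    return token


def resume_session(token: str, requester: str, room_id: str) -> Dict[str, Any]:
    session = _sessions.get(token)
    if session is None:
        raise ValueError("Unknown pagination token")
    # The query parameters must match the ones the token was created for.
    if session["requester"] != requester or session["room_id"] != room_id:
        raise ValueError("Unknown pagination token")
    return session


tok = create_session(
    {"requester": "@alice:example.org", "room_id": "!r:example.org",
     "room_queue": [], "processed_rooms": []}
)
print(resume_session(tok, "@alice:example.org", "!r:example.org"))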
@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect

@ -213,6 +214,7 @@ class SsoHandler:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
init_counters_for_auth_provider(p_id)

def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
"""Get the configured identity providers"""
@ -1,5 +1,4 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018, 2019 New Vector Ltd
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@ -31,6 +30,8 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
@ -86,20 +87,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
|
|||
SyncRequestKey = Tuple[Any, ...]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class SyncConfig:
|
||||
user = attr.ib(type=UserID)
|
||||
filter_collection = attr.ib(type=FilterCollection)
|
||||
is_guest = attr.ib(type=bool)
|
||||
request_key = attr.ib(type=SyncRequestKey)
|
||||
device_id = attr.ib(type=Optional[str])
|
||||
user: UserID
|
||||
filter_collection: FilterCollection
|
||||
is_guest: bool
|
||||
request_key: SyncRequestKey
|
||||
device_id: Optional[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class TimelineBatch:
|
||||
prev_batch = attr.ib(type=StreamToken)
|
||||
events = attr.ib(type=List[EventBase])
|
||||
limited = attr.ib(type=bool)
|
||||
prev_batch: StreamToken
|
||||
events: List[EventBase]
|
||||
limited: bool
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -113,16 +114,16 @@ class TimelineBatch:
|
|||
# if there are updates for it, which we check after the instance has been created.
|
||||
# This should not be a big deal because we update the notification counts afterwards as
|
||||
# well anyway.
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class JoinedSyncResult:
|
||||
room_id = attr.ib(type=str)
|
||||
timeline = attr.ib(type=TimelineBatch)
|
||||
state = attr.ib(type=StateMap[EventBase])
|
||||
ephemeral = attr.ib(type=List[JsonDict])
|
||||
account_data = attr.ib(type=List[JsonDict])
|
||||
unread_notifications = attr.ib(type=JsonDict)
|
||||
summary = attr.ib(type=Optional[JsonDict])
|
||||
unread_count = attr.ib(type=int)
|
||||
room_id: str
|
||||
timeline: TimelineBatch
|
||||
state: StateMap[EventBase]
|
||||
ephemeral: List[JsonDict]
|
||||
account_data: List[JsonDict]
|
||||
unread_notifications: JsonDict
|
||||
summary: Optional[JsonDict]
|
||||
unread_count: int
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -138,12 +139,12 @@ class JoinedSyncResult:
|
|||
)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ArchivedSyncResult:
|
||||
room_id = attr.ib(type=str)
|
||||
timeline = attr.ib(type=TimelineBatch)
|
||||
state = attr.ib(type=StateMap[EventBase])
|
||||
account_data = attr.ib(type=List[JsonDict])
|
||||
room_id: str
|
||||
timeline: TimelineBatch
|
||||
state: StateMap[EventBase]
|
||||
account_data: List[JsonDict]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -152,37 +153,37 @@ class ArchivedSyncResult:
|
|||
return bool(self.timeline or self.state or self.account_data)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class InvitedSyncResult:
|
||||
room_id = attr.ib(type=str)
|
||||
invite = attr.ib(type=EventBase)
|
||||
room_id: str
|
||||
invite: EventBase
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Invited rooms should always be reported to the client"""
|
||||
return True
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class KnockedSyncResult:
|
||||
room_id = attr.ib(type=str)
|
||||
knock = attr.ib(type=EventBase)
|
||||
room_id: str
|
||||
knock: EventBase
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Knocked rooms should always be reported to the client"""
|
||||
return True
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class GroupsSyncResult:
|
||||
join = attr.ib(type=JsonDict)
|
||||
invite = attr.ib(type=JsonDict)
|
||||
leave = attr.ib(type=JsonDict)
|
||||
join: JsonDict
|
||||
invite: JsonDict
|
||||
leave: JsonDict
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.join or self.invite or self.leave)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DeviceLists:
|
||||
"""
|
||||
Attributes:
|
||||
|
@ -190,27 +191,27 @@ class DeviceLists:
|
|||
left: List of user_ids whose devices we no longer track
|
||||
"""
|
||||
|
||||
changed = attr.ib(type=Collection[str])
|
||||
left = attr.ib(type=Collection[str])
|
||||
changed: Collection[str]
|
||||
left: Collection[str]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.changed or self.left)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class _RoomChanges:
|
||||
"""The set of room entries to include in the sync, plus the set of joined
|
||||
and left room IDs since last sync.
|
||||
"""
|
||||
|
||||
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
|
||||
invited = attr.ib(type=List[InvitedSyncResult])
|
||||
knocked = attr.ib(type=List[KnockedSyncResult])
|
||||
newly_joined_rooms = attr.ib(type=List[str])
|
||||
newly_left_rooms = attr.ib(type=List[str])
|
||||
room_entries: List["RoomSyncResultBuilder"]
|
||||
invited: List[InvitedSyncResult]
|
||||
knocked: List[KnockedSyncResult]
|
||||
newly_joined_rooms: List[str]
|
||||
newly_left_rooms: List[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class SyncResult:
|
||||
"""
|
||||
Attributes:
|
||||
|
@ -230,18 +231,18 @@ class SyncResult:
|
|||
groups: Group updates, if any
|
||||
"""
|
||||
|
||||
next_batch = attr.ib(type=StreamToken)
|
||||
presence = attr.ib(type=List[JsonDict])
|
||||
account_data = attr.ib(type=List[JsonDict])
|
||||
joined = attr.ib(type=List[JoinedSyncResult])
|
||||
invited = attr.ib(type=List[InvitedSyncResult])
|
||||
knocked = attr.ib(type=List[KnockedSyncResult])
|
||||
archived = attr.ib(type=List[ArchivedSyncResult])
|
||||
to_device = attr.ib(type=List[JsonDict])
|
||||
device_lists = attr.ib(type=DeviceLists)
|
||||
device_one_time_keys_count = attr.ib(type=JsonDict)
|
||||
device_unused_fallback_key_types = attr.ib(type=List[str])
|
||||
groups = attr.ib(type=Optional[GroupsSyncResult])
|
||||
next_batch: StreamToken
|
||||
presence: List[UserPresenceState]
|
||||
account_data: List[JsonDict]
|
||||
joined: List[JoinedSyncResult]
|
||||
invited: List[InvitedSyncResult]
|
||||
knocked: List[KnockedSyncResult]
|
||||
archived: List[ArchivedSyncResult]
|
||||
to_device: List[JsonDict]
|
||||
device_lists: DeviceLists
|
||||
device_one_time_keys_count: JsonDict
|
||||
device_unused_fallback_key_types: List[str]
|
||||
groups: Optional[GroupsSyncResult]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -701,7 +702,7 @@ class SyncHandler:
|
|||
name_id = state_ids.get((EventTypes.Name, ""))
|
||||
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
|
||||
|
||||
summary = {}
|
||||
summary: JsonDict = {}
|
||||
empty_ms = MemberSummary([], 0)
|
||||
|
||||
# TODO: only send these when they change.
|
||||
|
@ -1842,6 +1843,9 @@ class SyncHandler:
|
|||
knocked = []
|
||||
|
||||
for event in room_list:
|
||||
if event.room_version_id not in KNOWN_ROOM_VERSIONS:
|
||||
continue
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
room_entries.append(
|
||||
RoomSyncResultBuilder(
|
||||
|
@ -2075,21 +2079,23 @@ class SyncHandler:
|
|||
# If the membership's stream ordering is after the given stream
|
||||
# ordering, we need to go and work out if the user was in the room
|
||||
# before.
|
||||
for room_id, event_pos in joined_rooms:
|
||||
if not event_pos.persisted_after(room_key):
|
||||
joined_room_ids.add(room_id)
|
||||
for joined_room in joined_rooms:
|
||||
if not joined_room.event_pos.persisted_after(room_key):
|
||||
joined_room_ids.add(joined_room.room_id)
|
||||
continue
|
||||
|
||||
logger.info("User joined room after current token: %s", room_id)
|
||||
logger.info("User joined room after current token: %s", joined_room.room_id)
|
||||
|
||||
extrems = (
|
||||
await self.store.get_forward_extremities_for_room_at_stream_ordering(
|
||||
room_id, event_pos.stream
|
||||
joined_room.room_id, joined_room.event_pos.stream
|
||||
)
|
||||
)
|
||||
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
|
||||
users_in_room = await self.state.get_current_users_in_room(
|
||||
joined_room.room_id, extrems
|
||||
)
|
||||
if user_id in users_in_room:
|
||||
joined_room_ids.add(room_id)
|
||||
joined_room_ids.add(joined_room.room_id)
|
||||
|
||||
return frozenset(joined_room_ids)
|
||||
|
||||
|
@ -2159,7 +2165,7 @@ def _calculate_state(
|
|||
return {event_id_to_key[e]: e for e in state_ids}
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class SyncResultBuilder:
|
||||
"""Used to help build up a new SyncResult for a user
|
||||
|
||||
|
@ -2171,33 +2177,33 @@ class SyncResultBuilder:
|
|||
joined_room_ids: List of rooms the user is joined to
|
||||
|
||||
# The following mirror the fields in a sync response
|
||||
presence (list)
|
||||
account_data (list)
|
||||
joined (list[JoinedSyncResult])
|
||||
invited (list[InvitedSyncResult])
|
||||
knocked (list[KnockedSyncResult])
|
||||
archived (list[ArchivedSyncResult])
|
||||
groups (GroupsSyncResult|None)
|
||||
to_device (list)
|
||||
presence
|
||||
account_data
|
||||
joined
|
||||
invited
|
||||
knocked
|
||||
archived
|
||||
groups
|
||||
to_device
|
||||
"""
|
||||
|
||||
sync_config = attr.ib(type=SyncConfig)
|
||||
full_state = attr.ib(type=bool)
|
||||
since_token = attr.ib(type=Optional[StreamToken])
|
||||
now_token = attr.ib(type=StreamToken)
|
||||
joined_room_ids = attr.ib(type=FrozenSet[str])
|
||||
sync_config: SyncConfig
|
||||
full_state: bool
|
||||
since_token: Optional[StreamToken]
|
||||
now_token: StreamToken
|
||||
joined_room_ids: FrozenSet[str]
|
||||
|
||||
presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
|
||||
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
|
||||
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
|
||||
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
|
||||
knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
|
||||
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
|
||||
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
|
||||
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
|
||||
presence: List[UserPresenceState] = attr.Factory(list)
|
||||
account_data: List[JsonDict] = attr.Factory(list)
|
||||
joined: List[JoinedSyncResult] = attr.Factory(list)
|
||||
invited: List[InvitedSyncResult] = attr.Factory(list)
|
||||
knocked: List[KnockedSyncResult] = attr.Factory(list)
|
||||
archived: List[ArchivedSyncResult] = attr.Factory(list)
|
||||
groups: Optional[GroupsSyncResult] = None
|
||||
to_device: List[JsonDict] = attr.Factory(list)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class RoomSyncResultBuilder:
|
||||
"""Stores information needed to create either a `JoinedSyncResult` or
|
||||
`ArchivedSyncResult`.
|
||||
|
@ -2213,10 +2219,10 @@ class RoomSyncResultBuilder:
|
|||
upto_token: Latest point to return events from.
|
||||
"""
|
||||
|
||||
room_id = attr.ib(type=str)
|
||||
rtype = attr.ib(type=str)
|
||||
events = attr.ib(type=Optional[List[EventBase]])
|
||||
newly_joined = attr.ib(type=bool)
|
||||
full_state = attr.ib(type=bool)
|
||||
since_token = attr.ib(type=Optional[StreamToken])
|
||||
upto_token = attr.ib(type=StreamToken)
|
||||
room_id: str
|
||||
rtype: str
|
||||
events: Optional[List[EventBase]]
|
||||
newly_joined: bool
|
||||
full_state: bool
|
||||
since_token: Optional[StreamToken]
|
||||
upto_token: StreamToken
|
||||
|
|
|
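The sync.py hunks above are a mechanical conversion from `attr.ib(type=...)` fields to `auto_attribs` annotations; both spellings define the same class. A standalone illustration with simplified placeholder field types:

# Both classes below are equivalent; the diff above converts the former style
# to the latter throughout sync.py. Field types here are simplified placeholders.
from typing import List

import attr


@attr.s(slots=True, frozen=True)
class TimelineBatchOld:
    prev_batch = attr.ib(type=str)
    events = attr.ib(type=List[str])
    limited = attr.ib(type=bool)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatchNew:
    prev_batch: str
    events: List[str]
    limited: bool


old = TimelineBatchOld(prev_batch="s1", events=[], limited=False)
new = TimelineBatchNew(prev_batch="s1", events=[], limited=False)
assert [f.name for f in attr.fields(TimelineBatchOld)] == [
    f.name for f in attr.fields(TimelineBatchNew)
]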
@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"

# used during registration to store the registration token used (if required) so that:
# - we can prevent a token being used twice by one session
# - we can 'use up' the token after registration has successfully completed
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
@ -49,7 +49,7 @@ class UserInteractiveAuthChecker:
|
|||
clientip: The IP address of the client.
|
||||
|
||||
Raises:
|
||||
SynapseError if authentication failed
|
||||
LoginError if authentication failed.
|
||||
|
||||
Returns:
|
||||
The result of authentication (to pass back to the client?)
|
||||
|
@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
|||
)
|
||||
if resp_body["success"]:
|
||||
return True
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(
|
||||
401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
|
||||
class _BaseThreepidAuthChecker:
|
||||
|
@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker:
|
|||
raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
|
||||
|
||||
if not threepid:
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(
|
||||
401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
if threepid["medium"] != medium:
|
||||
raise LoginError(
|
||||
|
@@ -237,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict)


class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.REGISTRATION_TOKEN

def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self._enabled = bool(hs.config.registration_requires_token)
self.store = hs.get_datastore()

def is_enabled(self) -> bool:
return self._enabled

async def check_auth(self, authdict: dict, clientip: str) -> Any:
if "token" not in authdict:
raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
if not isinstance(authdict["token"], str):
raise LoginError(
400, "Registration token must be a string", Codes.INVALID_PARAM
)
if "session" not in authdict:
raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)

# Get these here to avoid cyclic dependencies
from synapse.handlers.ui_auth import UIAuthSessionDataConstants

auth_handler = self.hs.get_auth_handler()

session = authdict["session"]
token = authdict["token"]

# If the LoginType.REGISTRATION_TOKEN stage has already been completed,
# return early to avoid incrementing `pending` again.
stored_token = await auth_handler.get_session_data(
session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
)
if stored_token:
if token != stored_token:
raise LoginError(
400, "Registration token has changed", Codes.INVALID_PARAM
)
else:
return token

if await self.store.registration_token_is_valid(token):
# Increment pending counter, so that if token has limited uses it
# can't be used up by someone else in the meantime.
await self.store.set_registration_token_pending(token)
# Store the token in the UIA session, so that once registration
# is complete `completed` can be incremented.
await auth_handler.set_session_data(
session,
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
token,
)
# The token will be stored as the result of the authentication stage
# in ui_auth_sessions_credentials. This allows the pending counter
# for tokens to be decremented when expired sessions are deleted.
return token
else:
raise LoginError(
401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
)


INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""

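To illustrate the stage this checker implements, here is a hedged sketch of the auth dict a registering client would submit; the session ID, token and account details are placeholder values, not taken from the diff.

# Minimal sketch of a /register request body that completes the
# org.matrix.msc3231.login.registration_token UIA stage.  All values below
# are placeholders.
import json

register_body = {
    "username": "alice",
    "password": "correct horse battery staple",
    "auth": {
        "type": "org.matrix.msc3231.login.registration_token",
        "session": "GkBtBnyRkWXzOdzlUBlDPeCF",  # UIA session returned by the server
        "token": "abcd1234",  # registration token handed to the user
    },
}

print(json.dumps(register_body, indent=2))
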
@@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.server import Request

from synapse.http.server import DirectServeJsonResource

if TYPE_CHECKING:
from synapse.server import HomeServer


class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources

@@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""

def __init__(self, hs, handler):
def __init__(self, hs: "HomeServer", handler):
"""Initialise AdditionalResource

The ``handler`` should return a deferred which completes when it has

@@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
``request.write()``, and call ``request.finish()``.

Args:
hs (synapse.server.HomeServer): homeserver
hs: homeserver
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
super().__init__()
self._handler = handler

def _async_render(self, request):
def _async_render(self, request: Request):
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)

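As a hedged illustration of the handler contract described in the docstring above, a hypothetical handler that writes its own response and finishes the request; the names are examples, not part of the diff.

# Hypothetical handler suitable for wrapping in AdditionalResource: it writes
# the response body itself and finishes the request, as the docstring asks.
from twisted.web.server import Request


async def hello_handler(request: Request) -> None:
    request.setHeader(b"Content-Type", b"application/json")
    request.write(b'{"hello": "world"}')
    request.finish()

Wrapping it as AdditionalResource(hs, hello_handler) would then serve it at whatever path the resource is registered under.
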
@@ -16,7 +16,7 @@
import logging
import random
import time
from typing import List
from typing import Callable, Dict, List

import attr


@@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable

logger = logging.getLogger(__name__)

SERVER_CACHE = {}
SERVER_CACHE: Dict[bytes, List["Server"]] = {}


@attr.s(slots=True, frozen=True)
@attr.s(auto_attribs=True, slots=True, frozen=True)
class Server:
"""
Our record of an individual server which can be tried to reach a destination.

Attributes:
host (bytes): target hostname
port (int):
priority (int):
weight (int):
expires (int): when the cache should expire this record - in *seconds* since
host: target hostname
port:
priority:
weight:
expires: when the cache should expire this record - in *seconds* since
the epoch
"""

host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
weight = attr.ib(default=0)
expires = attr.ib(default=0)
host: bytes
port: int
priority: int = 0
weight: int = 0
expires: int = 0


def _sort_server_list(server_list):
def _sort_server_list(server_list: List[Server]) -> List[Server]:
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
priority_map = {}
priority_map: Dict[int, List[Server]] = {}

for server in server_list:
priority_map.setdefault(server.priority, []).append(server)

@@ -103,11 +103,16 @@ class SrvResolver:

Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
cache: cache object
get_time: clock implementation. Should return seconds since the epoch
"""

def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
def __init__(
self,
dns_client=client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time

@@ -116,7 +121,7 @@ class SrvResolver:
"""Look up a SRV record

Args:
service_name (bytes): record to look up
service_name: record to look up

Returns:
a list of the SRV records, or an empty list if none found

@@ -158,7 +163,7 @@ class SrvResolver:
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
raise ConnectError("Service %s unavailable" % service_name)
raise ConnectError(f"Service {service_name!r} unavailable")

servers = []

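The docstring above describes the usual SRV selection rule; the following is a hedged, standalone re-implementation of that rule (group by priority, then a weighted shuffle within each group), not the code from the diff.

# Illustrative re-implementation of the priority/weight ordering described in
# _sort_server_list; names and record values here are examples only.
import random
from typing import Dict, List, NamedTuple


class SrvRecord(NamedTuple):
    host: bytes
    port: int
    priority: int = 0
    weight: int = 0


def order_srv_records(records: List[SrvRecord]) -> List[SrvRecord]:
    by_priority: Dict[int, List[SrvRecord]] = {}
    for record in records:
        by_priority.setdefault(record.priority, []).append(record)

    ordered: List[SrvRecord] = []
    for priority in sorted(by_priority):
        group = by_priority[priority]
        # Weighted shuffle: repeatedly pick a record with probability
        # proportional to its weight (plus one, so zero-weight entries
        # can still be chosen).
        while group:
            weights = [record.weight + 1 for record in group]
            chosen = random.choices(group, weights=weights, k=1)[0]
            group.remove(chosen)
            ordered.append(chosen)
    return ordered


print(order_srv_records([SrvRecord(b"a", 8448, 0, 10), SrvRecord(b"b", 8448, 10, 5)]))
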
@@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
raise ValueError(f"Invalid URI {uri!r}")

parsed_uri = URI.fromBytes(uri)
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
request_path = parsed_uri.originForm

should_skip_proxy = False

@@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
)
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
pool_key = "http-proxy"
endpoint = self.http_proxy_endpoint
request_path = uri
elif (

@@ -32,6 +32,7 @@ from twisted.internet import defer
from twisted.web.resource import IResource

from synapse.events import EventBase
from synapse.events.presence_router import PresenceRouter
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,

@@ -57,6 +58,8 @@ This package defines the 'stable' API which can be used by extension modules whi
are loaded into Synapse.
"""

PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS

__all__ = [
"errors",
"make_deferred_yieldable",

@@ -70,6 +73,7 @@ __all__ = [
"DirectServeHtmlResource",
"DirectServeJsonResource",
"ModuleApi",
"PRESENCE_ALL_USERS",
]

logger = logging.getLogger(__name__)

@@ -112,6 +116,7 @@ class ModuleApi:
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._presence_router = hs.get_presence_router()

#################################################################################
# The following methods should only be called during the module's initialisation.

@@ -131,6 +136,11 @@ class ModuleApi:
"""Registers callbacks for third party event rules capabilities."""
return self._third_party_event_rules.register_third_party_rules_callbacks

@property
def register_presence_router_callbacks(self):
"""Registers callbacks for presence router capabilities."""
return self._presence_router.register_presence_router_callbacks

def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.

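A hedged sketch of how a module might use the register_presence_router_callbacks property added here; the callback keyword names follow the presence-router module interface and are assumptions, as is the module class itself.

# Hypothetical module showing the registration call exposed in this hunk.  The
# callback keyword names (get_users_for_states, get_interested_users) are
# assumptions based on the presence-router interface, not taken from the diff.
from typing import Dict, Set, Union

from synapse.module_api import PRESENCE_ALL_USERS, ModuleApi


class ExamplePresenceRouter:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        api.register_presence_router_callbacks(
            get_users_for_states=self.get_users_for_states,
            get_interested_users=self.get_interested_users,
        )

    async def get_users_for_states(self, state_updates) -> Dict[str, Set]:
        # Route every presence update to a fixed observer account.
        return {"@observer:example.org": set(state_updates)}

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        # The observer wants presence for all users; everyone else gets the default.
        if user_id == "@observer:example.org":
            return PRESENCE_ALL_USERS
        return set()
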
@@ -48,7 +48,8 @@ logger = logging.getLogger(__name__)
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.

REQUIREMENTS = [
"jsonschema>=2.5.1",
# we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
"jsonschema>=3.0.0",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.4.0",

@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
self.federation_handler = hs.get_federation_handler()
self.federation_event_handler = hs.get_federation_event_handler()

@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):

@@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):

logger.info("Got %d events from federation", len(event_and_contexts))

max_stream_id = await self.federation_handler.persist_events_and_notify(
max_stream_id = await self.federation_event_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
)

@@ -16,6 +16,9 @@ function captchaDone() {
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.

synapse/res/templates/registration_token.html (new file, 23 lines)
@@ -0,0 +1,23 @@
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
</head>
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Please enter a registration token.
</p>
<input type="hidden" name="session" value="{{ session }}" />
<input type="text" name="token" />
<input type="submit" value="Authenticate" />
</div>
</form>
</body>
</html>

@@ -8,6 +8,9 @@
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Please click the button below if you agree to the
<a href="{{ terms_url }}">privacy policy of this homeserver.</a>

@@ -36,7 +36,11 @@ from synapse.rest.admin.event_reports import (
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet,
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,

@@ -47,7 +51,6 @@ from synapse.rest.admin.rooms import (
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet

@@ -220,8 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)

@@ -241,6 +242,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
ListRegistrationTokensRestServlet(hs).register(http_server)
NewRegistrationTokenRestServlet(hs).register(http_server)
RegistrationTokenRestServlet(hs).register(http_server)

# Some servlets only get registered for the main process.
if hs.config.worker_app is None:
SendServerNoticeServlet(hs).register(http_server)


def register_servlets_for_client_rest_resource(

@@ -253,7 +261,6 @@ def register_servlets_for_client_rest_resource(
PurgeHistoryRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)

@@ -1,58 +0,0 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple

from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer


class PurgeRoomServlet(RestServlet):
"""Servlet which will remove all trace of a room from the database

POST /_synapse/admin/v1/purge_room
{
"room_id": "!room:id"
}

returns:

{}
"""

PATTERNS = admin_patterns("/purge_room$")

def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("room_id",))

await self.pagination_handler.purge_room(body["room_id"])

return 200, {}

synapse/rest/admin/registration_tokens.py (new file, 321 lines)
@ -0,0 +1,321 @@
|
|||
# Copyright 2021 Callum Brown
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import string
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_boolean,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListRegistrationTokensRestServlet(RestServlet):
|
||||
"""List registration tokens.
|
||||
|
||||
To list all tokens:
|
||||
|
||||
GET /_synapse/admin/v1/registration_tokens
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"registration_tokens": [
|
||||
{
|
||||
"token": "abcd",
|
||||
"uses_allowed": 3,
|
||||
"pending": 0,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
},
|
||||
{
|
||||
"token": "wxyz",
|
||||
"uses_allowed": null,
|
||||
"pending": 0,
|
||||
"completed": 9,
|
||||
"expiry_time": 1625394937000
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
The optional query parameter `valid` can be used to filter the response.
|
||||
If it is `true`, only valid tokens are returned. If it is `false`, only
|
||||
tokens that have expired or have had all uses exhausted are returned.
|
||||
If it is omitted, all tokens are returned regardless of validity.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/registration_tokens$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
valid = parse_boolean(request, "valid")
|
||||
token_list = await self.store.get_registration_tokens(valid)
|
||||
return 200, {"registration_tokens": token_list}
|
||||
|
||||
|
||||
class NewRegistrationTokenRestServlet(RestServlet):
|
||||
"""Create a new registration token.
|
||||
|
||||
For example, to create a token specifying some fields:
|
||||
|
||||
POST /_synapse/admin/v1/registration_tokens/new
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1
|
||||
}
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 1,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": null
|
||||
}
|
||||
|
||||
Defaults are used for any fields not specified.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/registration_tokens/new$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
# A string of all the characters allowed to be in a registration_token
|
||||
self.allowed_chars = string.ascii_letters + string.digits + "-_"
|
||||
self.allowed_chars_set = set(self.allowed_chars)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if "token" in body:
|
||||
token = body["token"]
|
||||
if not isinstance(token, str):
|
||||
raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
|
||||
if not (0 < len(token) <= 64):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"token must not be empty and must not be longer than 64 characters",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
if not set(token).issubset(self.allowed_chars_set):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"token must consist only of characters matched by the regex [A-Za-z0-9-_]",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
else:
|
||||
# Get length of token to generate (default is 16)
|
||||
length = body.get("length", 16)
|
||||
if not isinstance(length, int):
|
||||
raise SynapseError(
|
||||
400, "length must be an integer", Codes.INVALID_PARAM
|
||||
)
|
||||
if not (0 < length <= 64):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"length must be greater than zero and not greater than 64",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
# Generate token
|
||||
token = await self.store.generate_registration_token(
|
||||
length, self.allowed_chars
|
||||
)
|
||||
|
||||
uses_allowed = body.get("uses_allowed", None)
|
||||
if not (
|
||||
uses_allowed is None
|
||||
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
||||
):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"uses_allowed must be a non-negative integer or null",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
expiry_time = body.get("expiry_time", None)
|
||||
if not isinstance(expiry_time, (int, type(None))):
|
||||
raise SynapseError(
|
||||
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
|
||||
)
|
||||
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
||||
raise SynapseError(
|
||||
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
created = await self.store.create_registration_token(
|
||||
token, uses_allowed, expiry_time
|
||||
)
|
||||
if not created:
|
||||
raise SynapseError(
|
||||
400, f"Token already exists: {token}", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
resp = {
|
||||
"token": token,
|
||||
"uses_allowed": uses_allowed,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": expiry_time,
|
||||
}
|
||||
return 200, resp
|
||||
|
||||
|
||||
class RegistrationTokenRestServlet(RestServlet):
|
||||
"""Retrieve, update, or delete the given token.
|
||||
|
||||
For example,
|
||||
|
||||
to retrieve a token:
|
||||
|
||||
GET /_synapse/admin/v1/registration_tokens/abcd
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "abcd",
|
||||
"uses_allowed": 3,
|
||||
"pending": 0,
|
||||
"completed": 1,
|
||||
"expiry_time": null
|
||||
}
|
||||
|
||||
|
||||
to update a token:
|
||||
|
||||
PUT /_synapse/admin/v1/registration_tokens/defg
|
||||
|
||||
{
|
||||
"uses_allowed": 5,
|
||||
"expiry_time": 4781243146000
|
||||
}
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"token": "defg",
|
||||
"uses_allowed": 5,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": 4781243146000
|
||||
}
|
||||
|
||||
|
||||
to delete a token:
|
||||
|
||||
DELETE /_synapse/admin/v1/registration_tokens/wxyz
|
||||
|
||||
200 OK
|
||||
|
||||
{}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
|
||||
"""Retrieve a registration token."""
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
token_info = await self.store.get_one_registration_token(token)
|
||||
|
||||
# If no result return a 404
|
||||
if token_info is None:
|
||||
raise NotFoundError(f"No such registration token: {token}")
|
||||
|
||||
return 200, token_info
|
||||
|
||||
async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
|
||||
"""Update a registration token."""
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
body = parse_json_object_from_request(request)
|
||||
new_attributes = {}
|
||||
|
||||
# Only add uses_allowed to new_attributes if it is present and valid
|
||||
if "uses_allowed" in body:
|
||||
uses_allowed = body["uses_allowed"]
|
||||
if not (
|
||||
uses_allowed is None
|
||||
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
|
||||
):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"uses_allowed must be a non-negative integer or null",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
new_attributes["uses_allowed"] = uses_allowed
|
||||
|
||||
if "expiry_time" in body:
|
||||
expiry_time = body["expiry_time"]
|
||||
if not isinstance(expiry_time, (int, type(None))):
|
||||
raise SynapseError(
|
||||
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
|
||||
)
|
||||
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
|
||||
raise SynapseError(
|
||||
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
|
||||
)
|
||||
new_attributes["expiry_time"] = expiry_time
|
||||
|
||||
if len(new_attributes) == 0:
|
||||
# Nothing to update, get token info to return
|
||||
token_info = await self.store.get_one_registration_token(token)
|
||||
else:
|
||||
token_info = await self.store.update_registration_token(
|
||||
token, new_attributes
|
||||
)
|
||||
|
||||
# If no result return a 404
|
||||
if token_info is None:
|
||||
raise NotFoundError(f"No such registration token: {token}")
|
||||
|
||||
return 200, token_info
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, token: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
"""Delete a registration token."""
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if await self.store.delete_registration_token(token):
|
||||
return 200, {}
|
||||
|
||||
raise NotFoundError(f"No such registration token: {token}")
|
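Putting the endpoints documented above together, a hedged example of driving them over HTTP with the requests library; the homeserver URL and admin access token are placeholders.

# Hypothetical client-side usage of the admin registration-token endpoints
# documented above; the base URL and the admin access token are placeholders.
import requests

BASE = "https://homeserver.example.org/_synapse/admin/v1"
HEADERS = {"Authorization": "Bearer <admin access token>"}

# Create a token that can be used once.
resp = requests.post(
    f"{BASE}/registration_tokens/new",
    headers=HEADERS,
    json={"token": "defg", "uses_allowed": 1},
)
print(resp.json())  # e.g. {"token": "defg", "uses_allowed": 1, "pending": 0, ...}

# List only the tokens that are still valid.
print(
    requests.get(
        f"{BASE}/registration_tokens", headers=HEADERS, params={"valid": "true"}
    ).json()
)

# Extend the token, then delete it.
requests.put(f"{BASE}/registration_tokens/defg", headers=HEADERS, json={"uses_allowed": 5})
requests.delete(f"{BASE}/registration_tokens/defg", headers=HEADERS)
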
|
@ -46,41 +46,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShutdownRoomRestServlet(RestServlet):
|
||||
"""Shuts down a room by removing all local users from the room and blocking
|
||||
all future invites and joins to the room. Any local aliases will be repointed
|
||||
to a new room created by `new_room_user_id` and kicked users will be auto
|
||||
joined to the new room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(content, ["new_room_user_id"])
|
||||
|
||||
ret = await self.room_shutdown_handler.shutdown_room(
|
||||
room_id=room_id,
|
||||
new_room_user_id=content["new_room_user_id"],
|
||||
new_room_name=content.get("room_name"),
|
||||
message=content.get("message"),
|
||||
requester_user_id=requester.user.to_string(),
|
||||
block=True,
|
||||
)
|
||||
|
||||
return (200, ret)
|
||||
|
||||
|
||||
class DeleteRoomRestServlet(RestServlet):
|
||||
"""Delete a room from server.
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
|
@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.server_notices_manager = hs.get_server_notices_manager()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
|
||||
def register(self, json_resource: HttpServer):
|
||||
|
@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet):
|
|||
# We grab the server notices manager here as its initialisation has a check for worker processes,
|
||||
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
|
||||
# admin api).
|
||||
if not self.hs.get_server_notices_manager().is_enabled():
|
||||
if not self.server_notices_manager.is_enabled():
|
||||
raise SynapseError(400, "Server notices are not enabled on this server")
|
||||
|
||||
user_id = body["user_id"]
|
||||
UserID.from_string(user_id)
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
target_user = UserID.from_string(body["user_id"])
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "Server notices can only be sent to local users")
|
||||
|
||||
event = await self.hs.get_server_notices_manager().send_notice(
|
||||
user_id=body["user_id"],
|
||||
if not await self.admin_handler.get_user(target_user):
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
event = await self.server_notices_manager.send_notice(
|
||||
user_id=target_user.to_string(),
|
||||
type=event_type,
|
||||
state_key=state_key,
|
||||
event_content=body["content"],
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
return 200, {"event_id": event.event_id}
|
||||
|
|
|
@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
|
|||
if not isinstance(deactivate, bool):
|
||||
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
|
||||
|
||||
# convert into List[Tuple[str, str]]
|
||||
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
|
||||
if external_ids is not None:
|
||||
new_external_ids = []
|
||||
for external_id in external_ids:
|
||||
new_external_ids.append(
|
||||
(external_id["auth_provider"], external_id["external_id"])
|
||||
)
|
||||
new_external_ids = {
|
||||
(external_id["auth_provider"], external_id["external_id"])
|
||||
for external_id in external_ids
|
||||
}
|
||||
|
||||
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
|
||||
if threepids is not None:
|
||||
new_threepids = {
|
||||
(threepid["medium"], threepid["address"]) for threepid in threepids
|
||||
}
|
||||
|
||||
if user: # modify user
|
||||
if "displayname" in body:
|
||||
|
@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
|
|||
)
|
||||
|
||||
if threepids is not None:
|
||||
# remove old threepids from user
|
||||
old_threepids = await self.store.user_get_threepids(user_id)
|
||||
for threepid in old_threepids:
|
||||
# get changed threepids (added and removed)
|
||||
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
|
||||
cur_threepids = {
|
||||
(threepid["medium"], threepid["address"])
|
||||
for threepid in await self.store.user_get_threepids(user_id)
|
||||
}
|
||||
add_threepids = new_threepids - cur_threepids
|
||||
del_threepids = cur_threepids - new_threepids
|
||||
|
||||
# remove old threepids
|
||||
for medium, address in del_threepids:
|
||||
try:
|
||||
await self.auth_handler.delete_threepid(
|
||||
user_id, threepid["medium"], threepid["address"], None
|
||||
user_id, medium, address, None
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to remove threepids")
|
||||
raise SynapseError(500, "Failed to remove threepids")
|
||||
|
||||
# add new threepids to user
|
||||
# add new threepids
|
||||
current_time = self.hs.get_clock().time_msec()
|
||||
for threepid in threepids:
|
||||
for medium, address in add_threepids:
|
||||
await self.auth_handler.add_threepid(
|
||||
user_id, threepid["medium"], threepid["address"], current_time
|
||||
user_id, medium, address, current_time
|
||||
)
|
||||
|
||||
if external_ids is not None:
|
||||
# get changed external_ids (added and removed)
|
||||
cur_external_ids = await self.store.get_external_ids_by_user(user_id)
|
||||
add_external_ids = set(new_external_ids) - set(cur_external_ids)
|
||||
del_external_ids = set(cur_external_ids) - set(new_external_ids)
|
||||
cur_external_ids = set(
|
||||
await self.store.get_external_ids_by_user(user_id)
|
||||
)
|
||||
add_external_ids = new_external_ids - cur_external_ids
|
||||
del_external_ids = cur_external_ids - new_external_ids
|
||||
|
||||
# remove old external_ids
|
||||
for auth_provider, external_id in del_external_ids:
|
||||
|
@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
|
|||
|
||||
if threepids is not None:
|
||||
current_time = self.hs.get_clock().time_msec()
|
||||
for threepid in threepids:
|
||||
for medium, address in new_threepids:
|
||||
await self.auth_handler.add_threepid(
|
||||
user_id, threepid["medium"], threepid["address"], current_time
|
||||
user_id, medium, address, current_time
|
||||
)
|
||||
if (
|
||||
self.hs.config.email_enable_notifs
|
||||
|
@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
|
|||
kind="email",
|
||||
app_id="m.email",
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
device_display_name=address,
|
||||
pushkey=address,
|
||||
lang=None, # We don't know a user's language here
|
||||
data={},
|
||||
)
|
||||
|
|
|
@ -13,24 +13,27 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.http.servlet import RestServlet
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer, respond_with_html
|
||||
from synapse.http.servlet import RestServlet, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountValidityRenewServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/account_validity/renew$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.hs = hs
|
||||
|
@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
hs.config.account_validity.account_validity_invalid_token_template
|
||||
)
|
||||
|
||||
async def on_GET(self, request):
|
||||
if b"token" not in request.args:
|
||||
raise SynapseError(400, "Missing renewal token")
|
||||
renewal_token = request.args[b"token"][0]
|
||||
async def on_GET(self, request: Request) -> None:
|
||||
renewal_token = parse_string(request, "token", required=True)
|
||||
|
||||
(
|
||||
token_valid,
|
||||
token_stale,
|
||||
expiration_ts,
|
||||
) = await self.account_activity_handler.renew_account(
|
||||
renewal_token.decode("utf8")
|
||||
)
|
||||
) = await self.account_activity_handler.renew_account(renewal_token)
|
||||
|
||||
if token_valid:
|
||||
status_code = 200
|
||||
|
@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
class AccountValiditySendMailServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/account_validity/send_mail$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.hs = hs
|
||||
|
@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
|
|||
hs.config.account_validity.account_validity_renew_by_email_enabled
|
||||
)
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
||||
user_id = requester.user.to_string()
|
||||
await self.account_activity_handler.send_renewal_email_to_user(user_id)
|
||||
|
@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
||||
AccountValiditySendMailServlet(hs).register(http_server)
|
||||
|
|
|
@ -15,11 +15,14 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import LoginError, SynapseError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.http.server import HttpServer, respond_with_html
|
||||
from synapse.http.servlet import RestServlet, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
|
@ -46,9 +49,10 @@ class AuthRestServlet(RestServlet):
|
|||
self.registration_handler = hs.get_registration_handler()
|
||||
self.recaptcha_template = hs.config.recaptcha_template
|
||||
self.terms_template = hs.config.terms_template
|
||||
self.registration_token_template = hs.config.registration_token_template
|
||||
self.success_template = hs.config.fallback_success_template
|
||||
|
||||
async def on_GET(self, request, stagetype):
|
||||
async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
|
||||
session = parse_string(request, "session")
|
||||
if not session:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
@ -74,6 +78,12 @@ class AuthRestServlet(RestServlet):
|
|||
# re-authenticate with their SSO provider.
|
||||
html = await self.auth_handler.start_sso_ui_auth(request, session)
|
||||
|
||||
elif stagetype == LoginType.REGISTRATION_TOKEN:
|
||||
html = self.registration_token_template.render(
|
||||
session=session,
|
||||
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
|
||||
)
|
||||
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
|
@ -81,7 +91,7 @@ class AuthRestServlet(RestServlet):
|
|||
respond_with_html(request, 200, html)
|
||||
return None
|
||||
|
||||
async def on_POST(self, request, stagetype):
|
||||
async def on_POST(self, request: Request, stagetype: str) -> None:
|
||||
|
||||
session = parse_string(request, "session")
|
||||
if not session:
|
||||
|
@ -95,29 +105,32 @@ class AuthRestServlet(RestServlet):
|
|||
|
||||
authdict = {"response": response, "session": session}
|
||||
|
||||
success = await self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA, authdict, request.getClientIP()
|
||||
)
|
||||
|
||||
if success:
|
||||
html = self.success_template.render()
|
||||
else:
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA, authdict, request.getClientIP()
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
html = self.recaptcha_template.render(
|
||||
session=session,
|
||||
myurl="%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
|
||||
sitekey=self.hs.config.recaptcha_public_key,
|
||||
error=e.msg,
|
||||
)
|
||||
else:
|
||||
# No LoginError was raised, so authentication was successful
|
||||
html = self.success_template.render()
|
||||
|
||||
elif stagetype == LoginType.TERMS:
|
||||
authdict = {"session": session}
|
||||
|
||||
success = await self.auth_handler.add_oob_auth(
|
||||
LoginType.TERMS, authdict, request.getClientIP()
|
||||
)
|
||||
|
||||
if success:
|
||||
html = self.success_template.render()
|
||||
else:
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.TERMS, authdict, request.getClientIP()
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
html = self.terms_template.render(
|
||||
session=session,
|
||||
terms_url="%s_matrix/consent?v=%s"
|
||||
|
@ -127,10 +140,33 @@ class AuthRestServlet(RestServlet):
|
|||
),
|
||||
myurl="%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.TERMS),
|
||||
error=e.msg,
|
||||
)
|
||||
else:
|
||||
# No LoginError was raised, so authentication was successful
|
||||
html = self.success_template.render()
|
||||
|
||||
elif stagetype == LoginType.SSO:
|
||||
# The SSO fallback workflow should not post here,
|
||||
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
|
||||
|
||||
elif stagetype == LoginType.REGISTRATION_TOKEN:
|
||||
token = parse_string(request, "token", required=True)
|
||||
authdict = {"session": session, "token": token}
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
|
||||
)
|
||||
except LoginError as e:
|
||||
html = self.registration_token_template.render(
|
||||
session=session,
|
||||
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
|
||||
error=e.msg,
|
||||
)
|
||||
else:
|
||||
html = self.success_template.render()
|
||||
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
|
@ -139,5 +175,5 @@ class AuthRestServlet(RestServlet):
|
|||
return None
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
AuthRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict

@@ -61,8 +62,19 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES

if self.config.experimental.msc3283_enabled:
response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
"enabled": self.config.enable_set_displayname
}
response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
"enabled": self.config.enable_set_avatar_url
}
response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
"enabled": self.config.enable_3pid_changes
}

return 200, response


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)

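For reference, a small sketch of the extra entries this hunk adds to the /capabilities response when the MSC3283 flag and all three settings are enabled; the values mirror the code above rather than a live server.

# Shape of the org.matrix.msc3283.* capability entries (illustrative values).
msc3283_capabilities = {
    "org.matrix.msc3283.set_displayname": {"enabled": True},
    "org.matrix.msc3283.set_avatar_url": {"enabled": True},
    "org.matrix.msc3283.3pid_changes": {"enabled": True},
}
print(msc3283_capabilities)
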
@ -14,34 +14,36 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/devices$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
devices = await self.device_handler.get_devices_by_user(
|
||||
requester.user.to_string()
|
||||
|
@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/delete_devices")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
|
@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
class DeviceRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
async def on_GET(self, request, device_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, device_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
device = await self.device_handler.get_device(
|
||||
requester.user.to_string(), device_id
|
||||
|
@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet):
|
|||
return 200, device
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_DELETE(self, request, device_id):
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, device_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
|
@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet):
|
|||
await self.device_handler.delete_device(requester.user.to_string(), device_id)
|
||||
return 200, {}
|
||||
|
||||
async def on_PUT(self, request, device_id):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, device_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
dehydrated_device = await self.device_handler.get_dehydrated_device(
|
||||
requester.user.to_string()
|
||||
|
@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet):
|
|||
else:
|
||||
raise errors.NotFoundError("No dehydrated device available")
|
||||
|
||||
async def on_PUT(self, request: SynapseRequest):
|
||||
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
submission = parse_json_object_from_request(request)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
|
@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
|
|||
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
submission = parse_json_object_from_request(request)
|
||||
|
@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
|
|||
return (200, result)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeviceRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -12,8 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
|
@ -22,14 +24,19 @@ from synapse.api.errors import (
|
|||
NotFoundError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import RoomAlias
|
||||
from synapse.types import JsonDict, RoomAlias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ClientDirectoryServer(hs).register(http_server)
|
||||
ClientDirectoryListServer(hs).register(http_server)
|
||||
ClientAppserviceDirectoryListServer(hs).register(http_server)
|
||||
|
@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
|
|||
class ClientDirectoryServer(RestServlet):
|
||||
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.directory_handler = hs.get_directory_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, room_alias):
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
|
||||
room_alias_obj = RoomAlias.from_string(room_alias)
|
||||
|
||||
res = await self.directory_handler.get_association(room_alias)
|
||||
res = await self.directory_handler.get_association(room_alias_obj)
|
||||
|
||||
return 200, res
|
||||
|
||||
async def on_PUT(self, request, room_alias):
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_alias: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
room_alias_obj = RoomAlias.from_string(room_alias)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
if "room_id" not in content:
|
||||
|
@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
|
|||
)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
logger.debug("Got room name: %s", room_alias.to_string())
|
||||
logger.debug("Got room name: %s", room_alias_obj.to_string())
|
||||
|
||||
room_id = content["room_id"]
|
||||
servers = content["servers"] if "servers" in content else None
|
||||
|
@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
|
|||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
await self.directory_handler.create_association(
|
||||
requester, room_alias, room_id, servers
|
||||
requester, room_alias_obj, room_id, servers
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
async def on_DELETE(self, request, room_alias):
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, room_alias: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
room_alias_obj = RoomAlias.from_string(room_alias)
|
||||
|
||||
try:
|
||||
service = self.auth.get_appservice_by_req(request)
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
await self.directory_handler.delete_appservice_association(
|
||||
service, room_alias
|
||||
service, room_alias_obj
|
||||
)
|
||||
logger.info(
|
||||
"Application service at %s deleted alias %s",
|
||||
service.url,
|
||||
room_alias.to_string(),
|
||||
room_alias_obj.to_string(),
|
||||
)
|
||||
return 200, {}
|
||||
except InvalidClientCredentialsError:
|
||||
|
@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
|
|||
requester = await self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
||||
await self.directory_handler.delete_association(requester, room_alias)
|
||||
await self.directory_handler.delete_association(requester, room_alias_obj)
|
||||
|
||||
logger.info(
|
||||
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
|
||||
"User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
|
|||
class ClientDirectoryListServer(RestServlet):
|
||||
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.directory_handler = hs.get_directory_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
room = await self.store.get_room(room_id)
|
||||
if room is None:
|
||||
raise NotFoundError("Unknown room")
|
||||
|
||||
return 200, {"visibility": "public" if room["is_public"] else "private"}
|
||||
|
||||
async def on_PUT(self, request, room_id):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
|
|||
|
||||
return 200, {}
|
||||
|
||||
async def on_DELETE(self, request, room_id):
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
await self.directory_handler.edit_published_room_list(
|
||||
|
@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
|
|||
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.directory_handler = hs.get_directory_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
def on_PUT(self, request, network_id, room_id):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, network_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
visibility = content.get("visibility", "public")
|
||||
return self._edit(request, network_id, room_id, visibility)
|
||||
return await self._edit(request, network_id, room_id, visibility)
|
||||
|
||||
def on_DELETE(self, request, network_id, room_id):
|
||||
return self._edit(request, network_id, room_id, "private")
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, network_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
return await self._edit(request, network_id, room_id, "private")
|
||||
|
||||
async def _edit(self, request, network_id, room_id, visibility):
|
||||
async def _edit(
|
||||
self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if not requester.app_service:
|
||||
raise AuthError(
|
||||
|
|
|
@@ -14,11 +14,18 @@
|
|||
|
||||
"""This module contains REST servlets to do with event streaming, /events."""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):
|
|||
|
||||
DEFAULT_LONGPOLL_TIME_MS = 30000
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.event_stream_handler = hs.get_event_stream_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
is_guest = requester.is_guest
|
||||
room_id = None
|
||||
args: Dict[bytes, List[bytes]] = request.args # type: ignore
|
||||
if is_guest:
|
||||
if b"room_id" not in request.args:
|
||||
if b"room_id" not in args:
|
||||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
if b"room_id" in request.args:
|
||||
room_id = request.args[b"room_id"][0].decode("ascii")
|
||||
room_id = parse_string(request, "room_id")
|
||||
|
||||
pagin_config = await PaginationConfig.from_request(self.store, request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if b"timeout" in request.args:
|
||||
if b"timeout" in args:
|
||||
try:
|
||||
timeout = int(request.args[b"timeout"][0])
|
||||
timeout = int(args[b"timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
as_client_event = b"raw" not in request.args
|
||||
as_client_event = b"raw" not in args
|
||||
|
||||
chunk = await self.event_stream_handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
|
@@ -70,25 +76,27 @@ class EventStreamRestServlet(RestServlet):
|
|||
class EventRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.clock = hs.get_clock()
|
||||
self.event_handler = hs.get_event_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
async def on_GET(self, request, event_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, event_id: str
|
||||
) -> Tuple[int, Union[str, JsonDict]]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
event = await self.event_handler.get_event(requester.user, None, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
event = await self._event_serializer.serialize_event(event, time_now)
|
||||
return 200, event
|
||||
result = await self._event_serializer.serialize_event(event, time_now)
|
||||
return 200, result
|
||||
else:
|
||||
return 404, "Event not found."
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
EventStreamRestServlet(hs).register(http_server)
|
||||
EventRestServlet(hs).register(http_server)
|
||||
|
|
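Note on the events.py hunks above: they replace hand-rolled request.args access with a typed args alias and parse_string. A rough standalone sketch of the same query-string handling, outside Synapse (the helper below is a simplified stand-in, not synapse.http.servlet.parse_string, and the sample values are invented):

    from typing import Dict, List, Optional

    def parse_string_from_args(args: Dict[bytes, List[bytes]], name: str) -> Optional[str]:
        # Return the first value of the named query parameter, decoded,
        # or None if the parameter is absent.
        values = args.get(name.encode("ascii"))
        return values[0].decode("ascii") if values else None

    args: Dict[bytes, List[bytes]] = {b"room_id": [b"!abc:example.org"], b"timeout": [b"5000"]}
    assert parse_string_from_args(args, "room_id") == "!abc:example.org"
    assert parse_string_from_args(args, "missing") is None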
|
@@ -13,26 +13,34 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import UserID
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
from ._base import client_patterns, set_timeline_upper_limit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetFilterRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.filtering = hs.get_filtering()
|
||||
|
||||
async def on_GET(self, request, user_id, filter_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str, filter_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
target_user = UserID.from_string(user_id)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
|
@@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet):
|
|||
raise AuthError(403, "Can only get filters for local users")
|
||||
|
||||
try:
|
||||
filter_id = int(filter_id)
|
||||
filter_id_int = int(filter_id)
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid filter_id")
|
||||
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user_localpart=target_user.localpart, filter_id=filter_id
|
||||
user_localpart=target_user.localpart, filter_id=filter_id_int
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code != 404:
|
||||
|
@@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet):
|
|||
class CreateFilterRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.filtering = hs.get_filtering()
|
||||
|
||||
async def on_POST(self, request, user_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
@@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet):
|
|||
return 200, {"filter_id": str(filter_id)}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
GetFilterRestServlet(hs).register(http_server)
|
||||
CreateFilterRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -26,6 +26,7 @@ from synapse.api.constants import (
|
|||
)
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
|
@@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet):
|
|||
return 200, result
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
GroupServlet(hs).register(http_server)
|
||||
GroupSummaryServlet(hs).register(http_server)
|
||||
GroupInvitedUsersServlet(hs).register(http_server)
|
||||
|
|
|
@@ -12,25 +12,33 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class InitialSyncRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/initialSync$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.initial_sync_handler = hs.get_initial_sync_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
as_client_event = b"raw" not in request.args
|
||||
args: Dict[bytes, List[bytes]] = request.args # type: ignore
|
||||
as_client_event = b"raw" not in args
|
||||
pagination_config = await PaginationConfig.from_request(self.store, request)
|
||||
include_archived = parse_boolean(request, "archived", default=False)
|
||||
content = await self.initial_sync_handler.snapshot_all_rooms(
|
||||
|
@@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet):
|
|||
return 200, content
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
InitialSyncRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -15,19 +15,25 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import InvalidAPICallError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.types import StreamToken
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@@ -59,18 +65,16 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@trace(opname="upload_keys")
|
||||
async def on_POST(self, request, device_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, device_id: Optional[str]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
@@ -148,21 +152,30 @@ class KeyQueryServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/keys/query$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
device_id = requester.device_id
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
device_keys = body.get("device_keys")
|
||||
if not isinstance(device_keys, dict):
|
||||
raise InvalidAPICallError("'device_keys' must be a JSON object")
|
||||
|
||||
def is_list_of_strings(values: Any) -> bool:
|
||||
return isinstance(values, list) and all(isinstance(v, str) for v in values)
|
||||
|
||||
if any(not is_list_of_strings(keys) for keys in device_keys.values()):
|
||||
raise InvalidAPICallError(
|
||||
"'device_keys' values must be a list of strings",
|
||||
)
|
||||
|
||||
result = await self.e2e_keys_handler.query_devices(
|
||||
body, timeout, user_id, device_id
|
||||
)
|
||||
|
@@ -181,17 +194,13 @@ class KeyChangesServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/keys/changes$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
from_token_string = parse_string(request, "from", required=True)
|
||||
|
@@ -231,12 +240,12 @@ class OneTimeKeyServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/keys/claim$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
@@ -255,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
@@ -267,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet):
|
|||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
@@ -315,16 +320,12 @@ class SignaturesUploadServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/keys/signatures/upload$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
@@ -335,7 +336,7 @@ class SignaturesUploadServlet(RestServlet):
|
|||
return 200, result
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
KeyUploadServlet(hs).register(http_server)
|
||||
KeyQueryServlet(hs).register(http_server)
|
||||
KeyChangesServlet(hs).register(http_server)
|
||||
|
|
|
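Note on the keys.py hunks above: the new check in KeyQueryServlet.on_POST requires the device_keys field of a /keys/query body to be a JSON object mapping user IDs to lists of device ID strings, otherwise the request is rejected with InvalidAPICallError. A minimal standalone sketch of that rule (the user and device IDs below are made up for illustration):

    from typing import Any

    def is_valid_device_keys(device_keys: Any) -> bool:
        # Mirrors the added validation: a dict whose values are lists of
        # strings (an empty list asks for all of that user's devices).
        if not isinstance(device_keys, dict):
            return False
        return all(
            isinstance(keys, list) and all(isinstance(v, str) for v in keys)
            for keys in device_keys.values()
        )

    assert is_valid_device_keys({"@alice:example.org": []})              # all devices
    assert is_valid_device_keys({"@alice:example.org": ["DEVICEID"]})    # one device
    assert not is_valid_device_keys(["@alice:example.org"])              # not a dict -> rejected
    assert not is_valid_device_keys({"@alice:example.org": "DEVICEID"})  # value not a list -> rejected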
@@ -19,6 +19,7 @@ from twisted.web.server import Request
|
|||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_json_object_from_request,
|
||||
|
@@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
KnockRoomAliasServlet(hs).register(http_server)
|
||||
|
|
|
@@ -1,4 +1,4 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@@ -14,7 +14,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
@@ -104,7 +104,13 @@ class LoginRestServlet(RestServlet):
|
|||
burst_count=self.hs.config.rc_login_account.burst_count,
|
||||
)
|
||||
|
||||
def on_GET(self, request: SynapseRequest):
|
||||
# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
|
||||
# The reason for this is to ensure that the auth_provider_ids are registered
|
||||
# with SsoHandler, which in turn ensures that the login/registration prometheus
|
||||
# counters are initialised for the auth_provider_ids.
|
||||
_load_sso_handlers(hs)
|
||||
|
||||
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
flows = []
|
||||
if self.jwt_enabled:
|
||||
flows.append({"type": LoginRestServlet.JWT_TYPE})
|
||||
|
@@ -151,7 +157,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
return 200, {"flows": flows}
|
||||
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
|
||||
if self._msc2918_enabled:
|
||||
|
@@ -211,7 +217,7 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission: JsonDict,
|
||||
appservice: ApplicationService,
|
||||
should_issue_refresh_token: bool = False,
|
||||
):
|
||||
) -> LoginResponse:
|
||||
identifier = login_submission.get("identifier")
|
||||
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||
|
||||
|
@@ -461,10 +467,7 @@ class RefreshTokenServlet(RestServlet):
|
|||
self._clock = hs.get_clock()
|
||||
self.access_token_lifetime = hs.config.access_token_lifetime
|
||||
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
refresh_submission = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(refresh_submission, ["refresh_token"])
|
||||
|
@@ -499,12 +502,7 @@ class SsoRedirectServlet(RestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
# make sure that the relevant handlers are instantiated, so that they
|
||||
# register themselves with the main SSOHandler.
|
||||
if hs.config.cas_enabled:
|
||||
hs.get_cas_handler()
|
||||
if hs.config.saml2_enabled:
|
||||
hs.get_saml_handler()
|
||||
if hs.config.oidc_enabled:
|
||||
hs.get_oidc_handler()
|
||||
_load_sso_handlers(hs)
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
|
||||
self._public_baseurl = hs.config.public_baseurl
|
||||
|
@@ -569,7 +567,7 @@ class SsoRedirectServlet(RestServlet):
|
|||
class CasTicketServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._cas_handler = hs.get_cas_handler()
|
||||
|
||||
|
@@ -591,10 +589,26 @@ class CasTicketServlet(RestServlet):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
if hs.config.access_token_lifetime is not None:
|
||||
RefreshTokenServlet(hs).register(http_server)
|
||||
SsoRedirectServlet(hs).register(http_server)
|
||||
if hs.config.cas_enabled:
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def _load_sso_handlers(hs: "HomeServer") -> None:
|
||||
"""Ensure that the SSO handlers are loaded, if they are enabled by configuration.
|
||||
|
||||
This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
|
||||
with the main SsoHandler.
|
||||
|
||||
It's safe to call this multiple times.
|
||||
"""
|
||||
if hs.config.cas.cas_enabled:
|
||||
hs.get_cas_handler()
|
||||
if hs.config.saml2.saml2_enabled:
|
||||
hs.get_saml_handler()
|
||||
if hs.config.oidc.oidc_enabled:
|
||||
hs.get_oidc_handler()
|
||||
|
|
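Note on the login.py hunks above: handler loading for CAS/SAML/OIDC is factored into a shared _load_sso_handlers(hs) helper that both LoginRestServlet and SsoRedirectServlet call from __init__, and its docstring says repeated calls are safe. A rough sketch of why that holds (FakeHomeServer and its cached getter are invented for illustration; the real HomeServer getters are similarly cached):

    from functools import lru_cache

    class FakeHomeServer:
        # Stand-in for the caching that real HomeServer getters perform:
        # the handler is constructed (and registers itself) at most once.
        @lru_cache(maxsize=None)
        def get_cas_handler(self) -> object:
            print("building the CAS handler")
            return object()

    def load_sso_handlers(hs: FakeHomeServer) -> None:
        # Safe to call from several servlets' __init__ methods.
        hs.get_cas_handler()

    hs = FakeHomeServer()
    load_sso_handlers(hs)
    load_sso_handlers(hs)  # prints only once; the same handler is reused
    assert hs.get_cas_handler() is hs.get_cas_handler()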
|
@@ -13,9 +13,16 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -23,13 +30,13 @@ logger = logging.getLogger(__name__)
|
|||
class LogoutRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/logout$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
||||
|
||||
if requester.device_id is None:
|
||||
|
@@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet):
|
|||
class LogoutAllRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/logout/all$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
@@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
LogoutRestServlet(hs).register(http_server)
|
||||
LogoutAllRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -13,26 +13,33 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.events.utils import format_event_for_client_v2_without_room_id
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationsServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/notifications$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
@@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet):
|
|||
return 200, {"notifications": returned_push_actions, "next_token": next_token}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
NotificationsServlet(hs).register(http_server)
|
||||
|
|
|
@@ -12,15 +12,21 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet):
|
|||
|
||||
EXPIRES_MS = 3600 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server_name
|
||||
|
||||
async def on_POST(self, request, user_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot request tokens for other users.")
|
||||
|
@@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
IdTokenServlet(hs).register(http_server)
|
||||
|
|
|
@@ -13,28 +13,32 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordPolicyServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/password_policy$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.policy = hs.config.password_policy
|
||||
self.enabled = hs.config.password_policy_enabled
|
||||
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
if not self.enabled or not self.policy:
|
||||
return (200, {})
|
||||
|
||||
|
@@ -53,5 +57,5 @@ class PasswordPolicyServlet(RestServlet):
|
|||
return (200, policy)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
PasswordPolicyServlet(hs).register(http_server)
|
||||
|
|
|
@@ -15,12 +15,18 @@
|
|||
""" This module contains REST servlets to do with presence: /presence/<paths>
|
||||
"""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -28,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||
class PresenceStatusRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
@@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet):
|
|||
|
||||
self._use_presence = hs.config.server.use_presence
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
|
@@ -53,13 +61,15 @@ class PresenceStatusRestServlet(RestServlet):
|
|||
raise AuthError(403, "You are not allowed to see their presence.")
|
||||
|
||||
state = await self.presence_handler.get_state(target_user=user)
|
||||
state = format_user_presence_state(
|
||||
result = format_user_presence_state(
|
||||
state, self.clock.time_msec(), include_user_id=False
|
||||
)
|
||||
|
||||
return 200, state
|
||||
return 200, result
|
||||
|
||||
async def on_PUT(self, request, user_id):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
|
@@ -91,5 +101,5 @@ class PresenceStatusRestServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
PresenceStatusRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -14,22 +14,31 @@
|
|||
|
||||
""" This module contains REST servlets to do with profile: /profile/<paths> """
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class ProfileDisplaynameRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester_user = None
|
||||
|
||||
if self.hs.config.require_auth_for_profile_requests:
|
||||
|
@@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||
|
||||
return 200, ret
|
||||
|
||||
async def on_PUT(self, request, user_id):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
@@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||
class ProfileAvatarURLRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester_user = None
|
||||
|
||||
if self.hs.config.require_auth_for_profile_requests:
|
||||
|
@@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||
|
||||
return 200, ret
|
||||
|
||||
async def on_PUT(self, request, user_id):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
@@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||
class ProfileRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester_user = None
|
||||
|
||||
if self.hs.config.require_auth_for_profile_requests:
|
||||
|
@@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet):
|
|||
return 200, ret
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ProfileDisplaynameRestServlet(hs).register(http_server)
|
||||
ProfileAvatarURLRestServlet(hs).register(http_server)
|
||||
ProfileRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -13,17 +13,23 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.http.server import respond_with_html_bytes
|
||||
from synapse.http.server import HttpServer, respond_with_html_bytes
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
|
|||
class PushersRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/pushers$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
|
@@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
|
|||
class PushersSetRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/pushers/set$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.pusher_pool = self.hs.get_pusherpool()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
|
@@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
PATTERNS = client_patterns("/pushers/remove$", v1=True)
|
||||
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.notifier = hs.get_notifier()
|
||||
self.auth = hs.get_auth()
|
||||
self.pusher_pool = self.hs.get_pusherpool()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> None:
|
||||
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
|
||||
user = requester.user
|
||||
|
||||
|
@@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
return None
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
PushersRestServlet(hs).register(http_server)
|
||||
PushersSetRestServlet(hs).register(http_server)
|
||||
PushersRemoveRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -13,27 +13,36 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReadMarkerRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.receipts_handler = hs.get_receipts_handler()
|
||||
self.read_marker_handler = hs.get_read_marker_handler()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
async def on_POST(self, request, room_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
@@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReadMarkerRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -12,7 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import hmac
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Union
|
||||
|
@@ -28,6 +27,7 @@ from synapse.api.errors import (
|
|||
ThreepidValidationError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.captcha import CaptchaConfig
|
||||
from synapse.config.consent import ConsentConfig
|
||||
|
@@ -59,18 +59,6 @@ from synapse.util.threepids import (
|
|||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
|
||||
# exist. It's a _really minor_ security flaw to use plain string comparison
|
||||
# because the timing attack is so obscured by all the other code here it's
|
||||
# unlikely to make much difference
|
||||
if hasattr(hmac, "compare_digest"):
|
||||
compare_digest = hmac.compare_digest
|
||||
else:
|
||||
|
||||
def compare_digest(a, b):
|
||||
return a == b
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@@ -379,6 +367,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
|||
return 200, {"available": True}
|
||||
|
||||
|
||||
class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
"""Check the validity of a registration token.
|
||||
|
||||
Example:
|
||||
|
||||
GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"valid": true
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
|
||||
releases=(),
|
||||
unstable=True,
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
|
||||
)
|
||||
|
||||
async def on_GET(self, request):
|
||||
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
|
||||
|
||||
if not self.hs.config.enable_registration:
|
||||
raise SynapseError(
|
||||
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
token = parse_string(request, "token", required=True)
|
||||
valid = await self.store.registration_token_is_valid(token)
|
||||
|
||||
return 200, {"valid": valid}
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/register$")
|
||||
|
||||
|
@@ -686,6 +723,22 @@ class RegisterRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
if registered:
|
||||
# Check if a token was used to authenticate registration
|
||||
registration_token = await self.auth_handler.get_session_data(
|
||||
session_id,
|
||||
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
|
||||
)
|
||||
if registration_token:
|
||||
# Increment the `completed` counter for the token
|
||||
await self.store.use_registration_token(registration_token)
|
||||
# Indicate that the token has been successfully used so that
|
||||
# pending is not decremented again when expiring old UIA sessions.
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session_id,
|
||||
LoginType.REGISTRATION_TOKEN,
|
||||
True,
|
||||
)
|
||||
|
||||
await self.registration_handler.post_registration_actions(
|
||||
user_id=registered_user_id,
|
||||
auth_result=auth_result,
|
||||
|
@@ -868,6 +921,11 @@ def _calculate_registration_flows(
|
|||
for flow in flows:
|
||||
flow.insert(0, LoginType.RECAPTCHA)
|
||||
|
||||
# Prepend registration token to all flows if we're requiring a token
|
||||
if config.registration_requires_token:
|
||||
for flow in flows:
|
||||
flow.insert(0, LoginType.REGISTRATION_TOKEN)
|
||||
|
||||
return flows
|
||||
|
||||
|
||||
|
@@ -876,4 +934,5 @@ def register_servlets(hs, http_server):
|
|||
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
UsernameAvailabilityRestServlet(hs).register(http_server)
|
||||
RegistrationSubmitTokenServlet(hs).register(http_server)
|
||||
RegistrationTokenValidityRestServlet(hs).register(http_server)
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
|
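Note on the register.py hunks above: they wire the MSC3231 registration-token stage into user-interactive auth. A small sketch of what the flow prepending in _calculate_registration_flows amounts to when registration_requires_token is set (the baseline flows below are hypothetical; the stage name is the one shown in the validity-check URL above):

    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"

    # Hypothetical baseline flows a server might already advertise.
    flows = [["m.login.email.identity"], ["m.login.dummy"]]

    # The new code prepends the token stage to every flow, so clients must
    # present a valid token before completing any of them.
    for flow in flows:
        flow.insert(0, REGISTRATION_TOKEN)

    assert flows == [
        [REGISTRATION_TOKEN, "m.login.email.identity"],
        [REGISTRATION_TOKEN, "m.login.dummy"],
    ]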
|
@@ -13,18 +13,25 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, ShadowBanError, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import stringutils
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
|
|||
}
|
||||
|
||||
Creates a new room and shuts down the old one. Returns the ID of the new room.
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
|
@@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
|
|||
"/rooms/(?P<room_id>[^/]*)/upgrade$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._hs = hs
|
||||
self._room_creation_handler = hs.get_room_creation_handler()
|
||||
self._auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, room_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
|
|||
return 200, ret
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
RoomUpgradeRestServlet(hs).register(http_server)
|
||||
|
|
|
@@ -12,13 +12,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import UserID
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
|
|||
releases=(), # This is an unstable feature
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_directory_active = hs.config.update_user_directory
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
|
||||
if not self.user_directory_active:
|
||||
raise SynapseError(
|
||||
|
@@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
|
|||
return 200, {"joined": list(rooms)}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
UserSharedRoomsServlet(hs).register(http_server)
|
||||
|
|
|
@@ -14,17 +14,26 @@
|
|||
import itertools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse.api.constants import Membership, PresenceState
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events.utils import (
|
||||
format_event_for_client_v2_without_room_id,
|
||||
format_event_raw,
|
||||
)
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.handlers.sync import KnockedSyncResult, SyncConfig
|
||||
from synapse.handlers.sync import (
|
||||
ArchivedSyncResult,
|
||||
InvitedSyncResult,
|
||||
JoinedSyncResult,
|
||||
KnockedSyncResult,
|
||||
SyncConfig,
|
||||
SyncResult,
|
||||
)
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
|
@@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
# We know that the requester has an access token since appservices
|
||||
# cannot use sync.
|
||||
response_content = await self.encode_response(
|
||||
time_now, sync_result, requester.access_token_id, filter_collection
|
||||
)
|
||||
|
@@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
|
|||
logger.debug("Event formatting complete")
|
||||
return 200, response_content
|
||||
|
||||
async def encode_response(self, time_now, sync_result, access_token_id, filter):
|
||||
async def encode_response(
|
||||
self,
|
||||
time_now: int,
|
||||
sync_result: SyncResult,
|
||||
access_token_id: Optional[int],
|
||||
filter: FilterCollection,
|
||||
) -> JsonDict:
|
||||
logger.debug("Formatting events in sync response")
|
||||
if filter.event_format == "client":
|
||||
event_formatter = format_event_for_client_v2_without_room_id
|
||||
|
@@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
logger.debug("building sync response dict")
|
||||
|
||||
response: dict = defaultdict(dict)
|
||||
response: JsonDict = defaultdict(dict)
|
||||
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
|
||||
|
||||
if sync_result.account_data:
|
||||
|
@@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
|
|||
if archived:
|
||||
response["rooms"][Membership.LEAVE] = archived
|
||||
|
||||
# By the time we get here groups is no longer optional.
|
||||
assert sync_result.groups is not None
|
||||
if sync_result.groups.join:
|
||||
response["groups"][Membership.JOIN] = sync_result.groups.join
|
||||
if sync_result.groups.invite:
|
||||
|
@@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
|
|||
return response
|
||||
|
||||
@staticmethod
|
||||
def encode_presence(events, time_now):
|
||||
def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
|
||||
return {
|
||||
"events": [
|
||||
{
|
||||
|
@@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
|
|||
}
|
||||
|
||||
async def encode_joined(
|
||||
self, rooms, time_now, token_id, event_fields, event_formatter
|
||||
):
|
||||
self,
|
||||
rooms: List[JoinedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_fields: List[str],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the joined rooms in a sync result
|
||||
|
||||
Args:
|
||||
rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
|
||||
results for rooms this user is joined to
|
||||
time_now(int): current time - used as a baseline for age
|
||||
calculations
|
||||
token_id(int): ID of the user's auth token - used for namespacing
|
||||
rooms: list of sync results for rooms this user is joined to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
event_fields(list<str>): List of event fields to include. If empty,
|
||||
event_fields: List of event fields to include. If empty,
|
||||
all fields will be returned.
|
||||
event_formatter (func[dict]): function to convert from federation format
|
||||
event_formatter: function to convert from federation format
|
||||
to client format
|
||||
Returns:
|
||||
dict[str, dict[str, object]]: the joined rooms list, in our
|
||||
response format
|
||||
The joined rooms list, in our response format
|
||||
"""
|
||||
joined = {}
|
||||
for room in rooms:
|
||||
|
@@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
return joined
|
||||
|
||||
async def encode_invited(self, rooms, time_now, token_id, event_formatter):
|
||||
async def encode_invited(
|
||||
self,
|
||||
rooms: List[InvitedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the invited rooms in a sync result
|
||||
|
||||
Args:
|
||||
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
|
||||
sync results for rooms this user is invited to
|
||||
time_now(int): current time - used as a baseline for age
|
||||
calculations
|
||||
token_id(int): ID of the user's auth token - used for namespacing
|
||||
rooms: list of sync results for rooms this user is invited to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
event_formatter (func[dict]): function to convert from federation format
|
||||
event_formatter: function to convert from federation format
|
||||
to client format
|
||||
|
||||
Returns:
|
||||
dict[str, dict[str, object]]: the invited rooms list, in our
|
||||
response format
|
||||
The invited rooms list, in our response format
|
||||
"""
|
||||
invited = {}
|
||||
for room in rooms:
|
||||
|
@@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
|
|||
self,
|
||||
rooms: List[KnockedSyncResult],
|
||||
time_now: int,
|
||||
token_id: int,
|
||||
token_id: Optional[int],
|
||||
event_formatter: Callable[[Dict], Dict],
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
|
@@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
|
|||
return knocked
|
||||
|
||||
async def encode_archived(
|
||||
self, rooms, time_now, token_id, event_fields, event_formatter
|
||||
):
|
||||
self,
|
||||
rooms: List[ArchivedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_fields: List[str],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the archived rooms in a sync result
|
||||
|
||||
Args:
|
||||
rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
|
||||
sync results for rooms this user is joined to
|
||||
time_now(int): current time - used as a baseline for age
|
||||
calculations
|
||||
token_id(int): ID of the user's auth token - used for namespacing
|
||||
rooms: list of sync results for rooms this user is joined to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
event_fields(list<str>): List of event fields to include. If empty,
|
||||
event_fields: List of event fields to include. If empty,
|
||||
all fields will be returned.
|
||||
event_formatter (func[dict]): function to convert from federation format
|
||||
to client format
|
||||
event_formatter: function to convert from federation format to client format
|
||||
Returns:
|
||||
dict[str, dict[str, object]]: The invited rooms list, in our
|
||||
response format
|
||||
The archived rooms list, in our response format
|
||||
"""
|
||||
joined = {}
|
||||
for room in rooms:
|
||||
|
@@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
|
|||
return joined
|
||||
|
||||
async def encode_room(
|
||||
self, room, time_now, token_id, joined, only_fields, event_formatter
|
||||
):
|
||||
self,
|
||||
room: Union[JoinedSyncResult, ArchivedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
joined: bool,
|
||||
only_fields: Optional[List[str]],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
|
||||
single room
|
||||
time_now (int): current time - used as a baseline for age
|
||||
calculations
|
||||
token_id (int): ID of the user's auth token - used for namespacing
|
||||
room: sync result for a single room
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
joined (bool): True if the user is joined to this room - will mean
|
||||
joined: True if the user is joined to this room - will mean
|
||||
we handle ephemeral events
|
||||
only_fields(list<str>): Optional. The list of event fields to include.
|
||||
event_formatter (func[dict]): function to convert from federation format
|
||||
only_fields: Optional. The list of event fields to include.
|
||||
event_formatter: function to convert from federation format
|
||||
to client format
|
||||
Returns:
|
||||
dict[str, object]: the room, encoded in our response format
|
||||
The room, encoded in our response format
|
||||
"""
|
||||
|
||||
def serialize(events):
|
||||
|
@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
account_data = room.account_data
|
||||
|
||||
result = {
|
||||
result: JsonDict = {
|
||||
"timeline": {
|
||||
"events": serialized_timeline,
|
||||
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
|
||||
|
@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
|
|||
}
|
||||
|
||||
if joined:
|
||||
assert isinstance(room, JoinedSyncResult)
|
||||
ephemeral_events = room.ephemeral
|
||||
result["ephemeral"] = {"events": ephemeral_events}
|
||||
result["unread_notifications"] = room.unread_notifications
|
||||
|
@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
|
|||
return result
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
SyncRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,12 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request, user_id, room_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot get tags for other users.")
|
||||
|
@ -54,12 +63,14 @@ class TagServlet(RestServlet):
|
|||
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.handler = hs.get_account_data_handler()
|
||||
|
||||
async def on_PUT(self, request, user_id, room_id, tag):
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, user_id: str, room_id: str, tag: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add tags for other users.")
|
||||
|
@ -70,7 +81,9 @@ class TagServlet(RestServlet):
|
|||
|
||||
return 200, {}
|
||||
|
||||
async def on_DELETE(self, request, user_id, room_id, tag):
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, user_id: str, room_id: str, tag: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add tags for other users.")
|
||||
|
@ -80,6 +93,6 @@ class TagServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
TagListServlet(hs).register(http_server)
|
||||
TagServlet(hs).register(http_server)
|
||||
|
|
|
@ -12,27 +12,33 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from synapse.api.constants import ThirdPartyEntityKind
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThirdPartyProtocolsServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/thirdparty/protocols")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
protocols = await self.appservice_handler.get_3pe_protocols()
|
||||
|
@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
|
|||
class ThirdPartyProtocolServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
async def on_GET(self, request, protocol):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, protocol: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
protocols = await self.appservice_handler.get_3pe_protocols(
|
||||
|
@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
|
|||
class ThirdPartyUserServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
async def on_GET(self, request, protocol):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, protocol: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
fields = request.args
|
||||
fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
|
||||
fields.pop(b"access_token", None)
|
||||
|
||||
results = await self.appservice_handler.query_3pe(
|
||||
|
@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
|
|||
class ThirdPartyLocationServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
async def on_GET(self, request, protocol):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, protocol: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
fields = request.args
|
||||
fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
|
||||
fields.pop(b"access_token", None)
|
||||
|
||||
results = await self.appservice_handler.query_3pe(
|
||||
|
@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
|
|||
return 200, results
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ThirdPartyProtocolsServlet(hs).register(http_server)
|
||||
ThirdPartyProtocolServlet(hs).register(http_server)
|
||||
ThirdPartyUserServlet(hs).register(http_server)
|
||||
|
|
|
@ -12,11 +12,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class TokenRefreshRestServlet(RestServlet):
|
||||
"""
|
||||
|
@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
|
|||
|
||||
PATTERNS = client_patterns("/tokenrefresh")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: Request) -> None:
|
||||
raise AuthError(403, "tokenrefresh is no longer supported.")
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
TokenRefreshRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,29 +13,32 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDirectorySearchRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/user_directory/search$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
"""Searches for users in directory
|
||||
|
||||
Returns:
|
||||
|
@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
|
|||
return 200, results
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
UserDirectorySearchRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -17,9 +17,17 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.constants import RoomCreationPreset
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
class VersionsRestServlet(RestServlet):
|
||||
PATTERNS = [re.compile("^/_matrix/client/versions$")]
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.config = hs.config
|
||||
|
||||
|
@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
|
|||
in self.config.encryption_enabled_by_default_for_room_presets
|
||||
)
|
||||
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
return (
|
||||
200,
|
||||
{
|
||||
|
@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
VersionsRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -15,20 +15,27 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class VoipRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(
|
||||
request, self.hs.config.turn_allow_guests
|
||||
)
|
||||
|
@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
VoipRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
|
@ -414,9 +414,9 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
|
||||
if desired_method == "crop":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
crop_info_list = []
|
||||
crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
|
||||
# Other thumbnails.
|
||||
crop_info_list2 = []
|
||||
crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info["thumbnail_method"] != "crop":
|
||||
|
@ -451,15 +451,19 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
info,
|
||||
)
|
||||
)
|
||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||
# the thumbnail info dictionary and pick the thumbnail that appears earlier
|
||||
# in the list of candidates.
|
||||
if crop_info_list:
|
||||
thumbnail_info = min(crop_info_list)[-1]
|
||||
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
|
||||
elif crop_info_list2:
|
||||
thumbnail_info = min(crop_info_list2)[-1]
|
||||
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
|
||||
elif desired_method == "scale":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
info_list = []
|
||||
info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = []
|
||||
# Other thumbnails.
|
||||
info_list2 = []
|
||||
info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = []
|
||||
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
|
@ -477,10 +481,14 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
info_list2.append(
|
||||
(size_quality, type_quality, length_quality, info)
|
||||
)
|
||||
# Pick the most appropriate thumbnail. Some values of `desired_width` and
|
||||
# `desired_height` may result in a tie, in which case we avoid comparing on
|
||||
# the thumbnail info dictionary and pick the thumbnail that appears earlier
|
||||
# in the list of candidates.
|
||||
if info_list:
|
||||
thumbnail_info = min(info_list)[-1]
|
||||
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
|
||||
elif info_list2:
|
||||
thumbnail_info = min(info_list2)[-1]
|
||||
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
|
||||
|
||||
if thumbnail_info:
|
||||
return FileInfo(
|
||||
|
|
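The tie-breaking change above keys min() on everything except the trailing info dict, because two candidates with identical quality fields would otherwise force Python 3 to compare dicts, which raises TypeError. A minimal standalone sketch of the idea; the candidate tuples below are hypothetical:

# Each candidate is (quality fields..., info_dict). Keying on t[:-1] compares only
# the quality fields; on a tie, min() keeps the earlier candidate instead of
# attempting an unorderable dict comparison.
candidates = [
    (0, 0, 100, {"thumbnail_type": "image/png"}),
    (0, 0, 100, {"thumbnail_type": "image/jpeg"}),
]
best_info = min(candidates, key=lambda t: t[:-1])[-1]
print(best_info["thumbnail_type"])  # image/png - the first candidate wins the tie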
|
@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
|
|||
from synapse.handlers.event_auth import EventAuthHandler
|
||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||
from synapse.handlers.federation import FederationHandler
|
||||
from synapse.handlers.federation_event import FederationEventHandler
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
|
||||
from synapse.handlers.identity import IdentityHandler
|
||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
|
@ -546,6 +547,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
def get_federation_handler(self) -> FederationHandler:
|
||||
return FederationHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_federation_event_handler(self) -> FederationEventHandler:
|
||||
return FederationEventHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_identity_handler(self) -> IdentityHandler:
|
||||
return IdentityHandler(self)
|
||||
|
|
|
@ -12,26 +12,23 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVER_NOTICE_ROOM_TAG = "m.server_notice"
|
||||
|
||||
|
||||
class ServerNoticesManager:
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._store = hs.get_datastore()
|
||||
self._config = hs.config
|
||||
self._account_data_handler = hs.get_account_data_handler()
|
||||
|
@ -58,6 +55,7 @@ class ServerNoticesManager:
|
|||
event_content: dict,
|
||||
type: str = EventTypes.Message,
|
||||
state_key: Optional[str] = None,
|
||||
txn_id: Optional[str] = None,
|
||||
) -> EventBase:
|
||||
"""Send a notice to the given user
|
||||
|
||||
|
@ -68,6 +66,7 @@ class ServerNoticesManager:
|
|||
event_content: content of event to send
|
||||
type: type of event
|
||||
is_state_event: Is the event a state event
|
||||
txn_id: The transaction ID.
|
||||
"""
|
||||
room_id = await self.get_or_create_notice_room_for_user(user_id)
|
||||
await self.maybe_invite_user_to_room(user_id, room_id)
|
||||
|
@ -90,7 +89,7 @@ class ServerNoticesManager:
|
|||
event_dict["state_key"] = state_key
|
||||
|
||||
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, ratelimit=False
|
||||
requester, event_dict, ratelimit=False, txn_id=txn_id
|
||||
)
|
||||
return event
|
||||
|
||||
|
|
|
@ -57,4 +57,8 @@ textarea, input {
|
|||
|
||||
background-color: #f8f8f8;
|
||||
border: 1px #ccc solid;
|
||||
}
|
||||
}
|
||||
|
||||
.error {
|
||||
color: red;
|
||||
}
|
||||
|
|
|
@ -63,6 +63,7 @@ from .relations import RelationsStore
|
|||
from .room import RoomStore
|
||||
from .roommember import RoomMemberStore
|
||||
from .search import SearchStore
|
||||
from .session import SessionStore
|
||||
from .signatures import SignatureStore
|
||||
from .state import StateStore
|
||||
from .stats import StatsStore
|
||||
|
@ -121,6 +122,7 @@ class DataStore(
|
|||
ServerMetricsStore,
|
||||
EventForwardExtremitiesStore,
|
||||
LockStore,
|
||||
SessionStore,
|
||||
):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
self.hs = hs
|
||||
|
|
|
@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# We now look up if we're already fetching some of the events in the DB,
|
||||
# if so we wait for those lookups to finish instead of pulling the same
|
||||
# events out of the DB multiple times.
|
||||
already_fetching: Dict[str, defer.Deferred] = {}
|
||||
#
|
||||
# Note: we might get the same `ObservableDeferred` back for multiple
|
||||
# events we're already fetching, so we deduplicate the deferreds to
|
||||
# avoid extraneous work (if we don't do this we can end up in a n^2 mode
|
||||
# when we wait on the same Deferred N times, then try and merge the
|
||||
# same dict into itself N times).
|
||||
already_fetching_ids: Set[str] = set()
|
||||
already_fetching_deferreds: Set[
|
||||
ObservableDeferred[Dict[str, _EventCacheEntry]]
|
||||
] = set()
|
||||
|
||||
for event_id in missing_events_ids:
|
||||
deferred = self._current_event_fetches.get(event_id)
|
||||
if deferred is not None:
|
||||
# We're already pulling the event out of the DB. Add the deferred
|
||||
# to the collection of deferreds to wait on.
|
||||
already_fetching[event_id] = deferred.observe()
|
||||
already_fetching_ids.add(event_id)
|
||||
already_fetching_deferreds.add(deferred)
|
||||
|
||||
missing_events_ids.difference_update(already_fetching)
|
||||
missing_events_ids.difference_update(already_fetching_ids)
|
||||
|
||||
if missing_events_ids:
|
||||
log_ctx = current_context()
|
||||
|
@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
with PreserveLoggingContext():
|
||||
fetching_deferred.callback(missing_events)
|
||||
|
||||
if already_fetching:
|
||||
if already_fetching_deferreds:
|
||||
# Wait for the other event requests to finish and add their results
|
||||
# to ours.
|
||||
results = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
already_fetching.values(),
|
||||
(d.observe() for d in already_fetching_deferreds),
|
||||
consumeErrors=True,
|
||||
)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
for result in results:
|
||||
event_entry_map.update(result)
|
||||
# We filter out events that we haven't asked for as we might get
|
||||
# a *lot* of superfluous events back, and there is no point
|
||||
# going through and inserting them all (which can take time).
|
||||
event_entry_map.update(
|
||||
(event_id, entry)
|
||||
for event_id, entry in result.items()
|
||||
if event_id in already_fetching_ids
|
||||
)
|
||||
|
||||
if not allow_rejected:
|
||||
event_entry_map = {
|
||||
|
|
|
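The change above deduplicates the in-flight fetches before waiting on them: several missing event IDs can point at the same ObservableDeferred, and waiting on it once per event ID (then merging the same dict into itself repeatedly) is what produced the n^2 behaviour. A rough standalone sketch of the same idea using asyncio futures rather than Twisted, purely for illustration:

import asyncio

async def fetch_events() -> dict:
    # Stand-in for a single database fetch covering several event IDs.
    await asyncio.sleep(0)
    return {"$a": "entry-a", "$b": "entry-b", "$c": "entry-c"}

async def main() -> None:
    shared = asyncio.ensure_future(fetch_events())
    # Hypothetical map of event_id -> in-flight fetch; note both IDs share one future.
    in_flight = {"$a": shared, "$b": shared}

    already_fetching_ids = set(in_flight)       # IDs we must not fetch again
    unique_fetches = set(in_flight.values())    # deduplicated: waited on once each

    event_entry_map: dict = {}
    for result in await asyncio.gather(*unique_fetches):
        # Keep only the events we actually asked for, as in the change above.
        event_entry_map.update(
            (event_id, entry)
            for event_id, entry in result.items()
            if event_id in already_fetching_ids
        )
    print(event_entry_map)  # {'$a': 'entry-a', '$b': 'entry-b'}

asyncio.run(main())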
@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
self._invalidate_cache_and_stream(
|
||||
txn, self.have_seen_event, (room_id, event_id)
|
||||
)
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
|
||||
logger.info("[purge] done")
|
||||
|
||||
|
|
|
@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
self._remove_stale_pushers,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"remove_deleted_email_pushers",
|
||||
self._remove_deleted_email_pushers,
|
||||
)
|
||||
|
||||
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
|
||||
"""JSON-decode the data in the rows returned from the `pushers` table
|
||||
|
||||
|
@ -388,6 +393,74 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
|
||||
return number_deleted
|
||||
|
||||
async def _remove_deleted_email_pushers(
|
||||
self, progress: dict, batch_size: int
|
||||
) -> int:
|
||||
"""A background update that deletes all pushers for deleted email addresses.
|
||||
|
||||
In previous versions of synapse, when users deleted their email address, it didn't
|
||||
also delete all the pushers for that email address. This background update removes
|
||||
those to prevent unwanted emails. This should only need to be run once (when users
|
||||
upgrade to v1.42.0).
|
||||
|
||||
Args:
|
||||
progress: dict used to store progress of this background update
|
||||
batch_size: the maximum number of rows to retrieve in a single select query
|
||||
|
||||
Returns:
|
||||
The number of deleted rows
|
||||
"""
|
||||
|
||||
last_pusher = progress.get("last_pusher", 0)
|
||||
|
||||
def _delete_pushers(txn) -> int:
|
||||
|
||||
sql = """
|
||||
SELECT p.id, p.user_name, p.app_id, p.pushkey
|
||||
FROM pushers AS p
|
||||
LEFT JOIN user_threepids AS t
|
||||
ON t.user_id = p.user_name
|
||||
AND t.medium = 'email'
|
||||
AND t.address = p.pushkey
|
||||
WHERE t.user_id is NULL
|
||||
AND p.app_id = 'm.email'
|
||||
AND p.id > ?
|
||||
ORDER BY p.id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_pusher, batch_size))
|
||||
rows = txn.fetchall()
|
||||
|
||||
last = None
|
||||
num_deleted = 0
|
||||
for row in rows:
|
||||
last = row[0]
|
||||
num_deleted += 1
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
"pushers",
|
||||
{"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
|
||||
)
|
||||
|
||||
if last is not None:
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "remove_deleted_email_pushers", {"last_pusher": last}
|
||||
)
|
||||
|
||||
return num_deleted
|
||||
|
||||
number_deleted = await self.db_pool.runInteraction(
|
||||
"_remove_deleted_email_pushers", _delete_pushers
|
||||
)
|
||||
|
||||
if number_deleted < batch_size:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"remove_deleted_email_pushers"
|
||||
)
|
||||
|
||||
return number_deleted
|
||||
|
||||
|
||||
class PusherStore(PusherWorkerStore):
|
||||
def get_pushers_stream_token(self) -> int:
|
||||
|
|
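The background update above relies on a LEFT JOIN anti-join: an email pusher is stale when no matching row exists in user_threepids. A standalone sqlite3 sketch of that query shape, with pared-down tables and made-up data:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE pushers (id INTEGER, user_name TEXT, app_id TEXT, pushkey TEXT);
    CREATE TABLE user_threepids (user_id TEXT, medium TEXT, address TEXT);
    INSERT INTO pushers VALUES (1, '@a:hs', 'm.email', 'a@example.com');
    INSERT INTO pushers VALUES (2, '@b:hs', 'm.email', 'gone@example.com');
    INSERT INTO user_threepids VALUES ('@a:hs', 'email', 'a@example.com');
    """
)
rows = conn.execute(
    """
    SELECT p.id FROM pushers AS p
    LEFT JOIN user_threepids AS t
        ON t.user_id = p.user_name AND t.medium = 'email' AND t.address = p.pushkey
    WHERE t.user_id IS NULL AND p.app_id = 'm.email'
    """
).fetchall()
print(rows)  # [(2,)] - only the pusher whose email address was removed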
|
@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
return user_id
|
||||
|
||||
def get_user_id_by_threepid_txn(self, txn, medium, address):
|
||||
def get_user_id_by_threepid_txn(
|
||||
self, txn, medium: str, address: str
|
||||
) -> Optional[str]:
|
||||
"""Returns user id from threepid
|
||||
|
||||
Args:
|
||||
txn (cursor):
|
||||
medium (str): threepid medium e.g. email
|
||||
address (str): threepid address e.g. me@example.com
|
||||
medium: threepid medium e.g. email
|
||||
address: threepid address e.g. me@example.com
|
||||
|
||||
Returns:
|
||||
str|None: user id or None if no user id/threepid mapping exists
|
||||
user id, or None if no user id/threepid mapping exists
|
||||
"""
|
||||
ret = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
|
@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
return ret["user_id"]
|
||||
return None
|
||||
|
||||
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
async def user_add_threepid(
|
||||
self,
|
||||
user_id: str,
|
||||
medium: str,
|
||||
address: str,
|
||||
validated_at: int,
|
||||
added_at: int,
|
||||
) -> None:
|
||||
await self.db_pool.simple_upsert(
|
||||
"user_threepids",
|
||||
{"medium": medium, "address": address},
|
||||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||
)
|
||||
|
||||
async def user_get_threepids(self, user_id):
|
||||
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_list(
|
||||
"user_threepids",
|
||||
{"user_id": user_id},
|
||||
|
@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
"user_get_threepids",
|
||||
)
|
||||
|
||||
async def user_delete_threepid(self, user_id, medium, address) -> None:
|
||||
async def user_delete_threepid(
|
||||
self, user_id: str, medium: str, address: str
|
||||
) -> None:
|
||||
await self.db_pool.simple_delete(
|
||||
"user_threepids",
|
||||
keyvalues={"user_id": user_id, "medium": medium, "address": address},
|
||||
|
@ -1157,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="update_access_token_last_validated",
|
||||
)
|
||||
|
||||
async def registration_token_is_valid(self, token: str) -> bool:
|
||||
"""Checks if a token can be used to authenticate a registration.
|
||||
|
||||
Args:
|
||||
token: The registration token to be checked
|
||||
Returns:
|
||||
True if the token is valid, False otherwise.
|
||||
"""
|
||||
res = await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
# Check if the token exists
|
||||
if res is None:
|
||||
return False
|
||||
|
||||
# Check if the token has expired
|
||||
now = self._clock.time_msec()
|
||||
if res["expiry_time"] and res["expiry_time"] < now:
|
||||
return False
|
||||
|
||||
# Check if the token has been used up
|
||||
if (
|
||||
res["uses_allowed"]
|
||||
and res["pending"] + res["completed"] >= res["uses_allowed"]
|
||||
):
|
||||
return False
|
||||
|
||||
# Otherwise, the token is valid
|
||||
return True
|
||||
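Read as a pure function, the check above boils down to three rules: the token must exist, must not be past its expiry_time, and must have uses left once pending and completed registrations are counted. A small sketch of that rule with a hypothetical row:

from typing import Any, Dict, Optional

def token_is_valid(row: Optional[Dict[str, Any]], now_ms: int) -> bool:
    if row is None:
        return False                      # token does not exist
    if row["expiry_time"] and row["expiry_time"] < now_ms:
        return False                      # token has expired
    if row["uses_allowed"] and row["pending"] + row["completed"] >= row["uses_allowed"]:
        return False                      # all allowed uses are spoken for
    return True

row = {"uses_allowed": 5, "pending": 1, "completed": 4, "expiry_time": None}
print(token_is_valid(row, now_ms=0))      # False - pending + completed reached the limit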
|
||||
async def set_registration_token_pending(self, token: str) -> None:
|
||||
"""Increment the pending registrations counter for a token.
|
||||
|
||||
Args:
|
||||
token: The registration token pending use
|
||||
"""
|
||||
|
||||
def _set_registration_token_pending_txn(txn):
|
||||
pending = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="pending",
|
||||
)
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"pending": pending + 1},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_registration_token_pending", _set_registration_token_pending_txn
|
||||
)
|
||||
|
||||
async def use_registration_token(self, token: str) -> None:
|
||||
"""Complete a use of the given registration token.
|
||||
|
||||
The `pending` counter will be decremented, and the `completed`
|
||||
counter will be incremented.
|
||||
|
||||
Args:
|
||||
token: The registration token to be 'used'
|
||||
"""
|
||||
|
||||
def _use_registration_token_txn(txn):
|
||||
# Normally, res is Optional[Dict[str, Any]].
|
||||
# Override type because the return type is only optional if
|
||||
# allow_none is True, and we don't want mypy throwing errors
|
||||
# about None not being indexable.
|
||||
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["pending", "completed"],
|
||||
) # type: ignore
|
||||
|
||||
# Decrement pending and increment completed
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={
|
||||
"completed": res["completed"] + 1,
|
||||
"pending": res["pending"] - 1,
|
||||
},
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"use_registration_token", _use_registration_token_txn
|
||||
)
|
||||
|
||||
async def get_registration_tokens(
|
||||
self, valid: Optional[bool] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List all registration tokens. Used by the admin API.
|
||||
|
||||
Args:
|
||||
valid: If True, only valid tokens are returned.
|
||||
If False, only invalid tokens are returned.
|
||||
Default is None: return all tokens regardless of validity.
|
||||
|
||||
Returns:
|
||||
A list of dicts, each containing details of a token.
|
||||
"""
|
||||
|
||||
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
|
||||
if valid is None:
|
||||
# Return all tokens regardless of validity
|
||||
txn.execute("SELECT * FROM registration_tokens")
|
||||
|
||||
elif valid:
|
||||
# Select valid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
|
||||
"AND (expiry_time > ? OR expiry_time IS NULL)"
|
||||
)
|
||||
txn.execute(sql, [now])
|
||||
|
||||
else:
|
||||
# Select invalid tokens only
|
||||
sql = (
|
||||
"SELECT * FROM registration_tokens WHERE "
|
||||
"uses_allowed <= pending + completed OR expiry_time <= ?"
|
||||
)
|
||||
txn.execute(sql, [now])
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"select_registration_tokens",
|
||||
select_registration_tokens_txn,
|
||||
self._clock.time_msec(),
|
||||
valid,
|
||||
)
|
||||
|
||||
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get info about the given registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to retrieve information about.
|
||||
|
||||
Returns:
|
||||
A dict, or None if token doesn't exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
|
||||
allow_none=True,
|
||||
desc="get_one_registration_token",
|
||||
)
|
||||
|
||||
async def generate_registration_token(
|
||||
self, length: int, chars: str
|
||||
) -> Optional[str]:
|
||||
"""Generate a random registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
length: The length of the token to generate.
|
||||
chars: A string of the characters allowed in the generated token.
|
||||
|
||||
Returns:
|
||||
The generated token.
|
||||
|
||||
Raises:
|
||||
SynapseError if a unique registration token could still not be
|
||||
generated after a few tries.
|
||||
"""
|
||||
# Make a few attempts at generating a unique token of the required
|
||||
# length before failing.
|
||||
for _i in range(3):
|
||||
# Generate token
|
||||
token = "".join(random.choices(chars, k=length))
|
||||
|
||||
# Check if the token already exists
|
||||
existing_token = await self.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcol="token",
|
||||
allow_none=True,
|
||||
desc="check_if_registration_token_exists",
|
||||
)
|
||||
|
||||
if existing_token is None:
|
||||
# The generated token doesn't exist yet, return it
|
||||
return token
|
||||
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Unable to generate a unique registration token. Try again with a greater length",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
async def create_registration_token(
|
||||
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
|
||||
) -> bool:
|
||||
"""Create a new registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to create.
|
||||
uses_allowed: The number of times the token can be used to complete
|
||||
a registration before it becomes invalid. A value of None indicates
|
||||
unlimited uses.
|
||||
expiry_time: The latest time the token is valid. Given as the
|
||||
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
|
||||
None indicates that the token does not expire.
|
||||
|
||||
Returns:
|
||||
Whether the row was inserted or not.
|
||||
"""
|
||||
|
||||
def _create_registration_token_txn(txn):
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=["token"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if row is not None:
|
||||
# Token already exists
|
||||
return False
|
||||
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
values={
|
||||
"token": token,
|
||||
"uses_allowed": uses_allowed,
|
||||
"pending": 0,
|
||||
"completed": 0,
|
||||
"expiry_time": expiry_time,
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"create_registration_token", _create_registration_token_txn
|
||||
)
|
||||
|
||||
async def update_registration_token(
|
||||
self, token: str, updatevalues: Dict[str, Optional[int]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update a registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to update.
|
||||
updatevalues: A dict with the fields to update. E.g.:
|
||||
`{"uses_allowed": 3}` to update just uses_allowed, or
|
||||
`{"uses_allowed": 3, "expiry_time": None}` to update both.
|
||||
This is passed straight to simple_update_one.
|
||||
|
||||
Returns:
|
||||
A dict with all info about the token, or None if token doesn't exist.
|
||||
"""
|
||||
|
||||
def _update_registration_token_txn(txn):
|
||||
try:
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues=updatevalues,
|
||||
)
|
||||
except StoreError:
|
||||
# Update failed because token does not exist
|
||||
return None
|
||||
|
||||
# Get all info about the token so it can be sent in the response
|
||||
return self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=[
|
||||
"token",
|
||||
"uses_allowed",
|
||||
"pending",
|
||||
"completed",
|
||||
"expiry_time",
|
||||
],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_registration_token", _update_registration_token_txn
|
||||
)
|
||||
|
||||
async def delete_registration_token(self, token: str) -> bool:
|
||||
"""Delete a registration token. Used by the admin API.
|
||||
|
||||
Args:
|
||||
token: The token to delete.
|
||||
|
||||
Returns:
|
||||
Whether the token was successfully deleted or not.
|
||||
"""
|
||||
try:
|
||||
await self.db_pool.simple_delete_one(
|
||||
"registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
desc="delete_registration_token",
|
||||
)
|
||||
except StoreError:
|
||||
# Deletion failed because token does not exist
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@cached()
|
||||
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||
"""
|
||||
|
|
|
@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
@cached()
|
||||
async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
|
||||
async def get_invited_rooms_for_local_user(
|
||||
self, user_id: str
|
||||
) -> List[RoomsForUser]:
|
||||
"""Get all the rooms the *local* user is invited to.
|
||||
|
||||
Args:
|
||||
|
@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
sql = """
|
||||
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
|
||||
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
|
||||
FROM local_current_membership AS c
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
INNER JOIN rooms AS r USING (room_id)
|
||||
WHERE
|
||||
user_id = ?
|
||||
AND %s
|
||||
|
@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
txn.execute(sql, (user_id, *args))
|
||||
results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
|
||||
results = [RoomsForUser(*r) for r in txn]
|
||||
|
||||
return results
|
||||
|
||||
|
@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
Returns:
|
||||
Returns the rooms the user is in currently, along with the stream
|
||||
ordering of the most recent join for that user and room.
|
||||
ordering of the most recent join for that user and room, along with
|
||||
the room version of the room.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_for_user_with_stream_ordering",
|
||||
|
@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
_get_users_server_still_shares_room_with_txn,
|
||||
)
|
||||
|
||||
async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
|
||||
async def get_rooms_for_user(
|
||||
self, user_id: str, on_invalidate=None
|
||||
) -> FrozenSet[str]:
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
|
|
145
synapse/storage/databases/main/session.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class SessionStore(SQLBaseStore):
|
||||
"""
|
||||
A store for generic session data.
|
||||
|
||||
Each type of session should provide a unique type (to separate sessions).
|
||||
|
||||
Sessions are automatically removed when they expire.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# Create a background job for culling expired sessions.
|
||||
if hs.config.run_background_tasks:
|
||||
self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
|
||||
|
||||
async def create_session(
|
||||
self, session_type: str, value: JsonDict, expiry_ms: int
|
||||
) -> str:
|
||||
"""
|
||||
Creates a new pagination session for the room hierarchy endpoint.
|
||||
|
||||
Args:
|
||||
session_type: The type for this session.
|
||||
value: The value to store.
|
||||
expiry_ms: How long before an item is evicted from the cache
|
||||
in milliseconds. Default is 0, indicating items never get
|
||||
evicted based on time.
|
||||
|
||||
Returns:
|
||||
The newly created session ID.
|
||||
|
||||
Raises:
|
||||
StoreError if a unique session ID cannot be generated.
|
||||
"""
|
||||
# autogen a session ID and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
session_id = stringutils.random_string(24)
|
||||
|
||||
try:
|
||||
await self.db_pool.simple_insert(
|
||||
table="sessions",
|
||||
values={
|
||||
"session_id": session_id,
|
||||
"session_type": session_type,
|
||||
"value": json_encoder.encode(value),
|
||||
"expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
|
||||
},
|
||||
desc="create_session",
|
||||
)
|
||||
|
||||
return session_id
|
||||
except self.db_pool.engine.module.IntegrityError:
|
||||
attempts += 1
|
||||
raise StoreError(500, "Couldn't generate a session ID.")
|
||||
|
||||
async def get_session(self, session_type: str, session_id: str) -> JsonDict:
|
||||
"""
|
||||
Retrieve data stored with create_session
|
||||
|
||||
Args:
|
||||
session_type: The type for this session.
|
||||
session_id: The session ID returned from create_session.
|
||||
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
|
||||
def _get_session(
|
||||
txn: LoggingTransaction, session_type: str, session_id: str, ts: int
|
||||
) -> JsonDict:
|
||||
# This includes the expiry time since items are only periodically
|
||||
# deleted, not upon expiry.
|
||||
select_sql = """
|
||||
SELECT value FROM sessions WHERE
|
||||
session_type = ? AND session_id = ? AND expiry_time_ms > ?
|
||||
"""
|
||||
txn.execute(select_sql, [session_type, session_id, ts])
|
||||
row = txn.fetchone()
|
||||
|
||||
if not row:
|
||||
raise StoreError(404, "No session")
|
||||
|
||||
return db_to_json(row[0])
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_session",
|
||||
_get_session,
|
||||
session_type,
|
||||
session_id,
|
||||
self._clock.time_msec(),
|
||||
)
|
||||
|
||||
@wrap_as_background_process("delete_expired_sessions")
|
||||
async def _delete_expired_sessions(self) -> None:
|
||||
"""Remove sessions with expiry dates that have passed."""
|
||||
|
||||
def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
|
||||
sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
|
||||
txn.execute(sql, (ts,))
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_expired_sessions",
|
||||
_delete_expired_sessions_txn,
|
||||
self._clock.time_msec(),
|
||||
)
|
|
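create_session above leans on the UNIQUE (session_type, session_id) constraint: it generates a random ID, attempts the insert, and retries on IntegrityError up to five times before giving up. A standalone sketch of that retry-until-unique pattern, with an in-memory set standing in for the table:

import secrets

existing: set = set()

def new_session_id(attempts: int = 5) -> str:
    for _ in range(attempts):
        session_id = secrets.token_urlsafe(18)  # ~24 characters, like random_string(24)
        if session_id not in existing:          # the real code lets the INSERT collide instead
            existing.add(session_id)
            return session_id
    raise RuntimeError("Couldn't generate a session ID.")

print(new_session_id())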
@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
|
@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
keyvalues={},
|
||||
)
|
||||
|
||||
# If a registration token was used, decrement the pending counter
|
||||
# before deleting the session.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions_credentials",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
|
||||
retcols=["result"],
|
||||
)
|
||||
|
||||
# Get the tokens used and how much pending needs to be decremented by.
|
||||
token_counts: Dict[str, int] = {}
|
||||
for r in rows:
|
||||
# If registration was successfully completed, the result of the
|
||||
# registration token stage for that session will be True.
|
||||
# If a token was used to authenticate, but registration was
|
||||
# never completed, the result will be the token used.
|
||||
token = db_to_json(r["result"])
|
||||
if isinstance(token, str):
|
||||
token_counts[token] = token_counts.get(token, 0) + 1
|
||||
|
||||
# Update the `pending` counters.
|
||||
if len(token_counts) > 0:
|
||||
token_rows = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="registration_tokens",
|
||||
column="token",
|
||||
iterable=list(token_counts.keys()),
|
||||
keyvalues={},
|
||||
retcols=["token", "pending"],
|
||||
)
|
||||
for token_row in token_rows:
|
||||
token = token_row["token"]
|
||||
new_pending = token_row["pending"] - token_counts[token]
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="registration_tokens",
|
||||
keyvalues={"token": token},
|
||||
updatevalues={"pending": new_pending},
|
||||
)
|
||||
|
||||
# Delete the corresponding completed credentials.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
|
|
|
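The bookkeeping above can be summarised as: for each expired UI-auth session whose registration-token stage result is still a token string (registration never completed), count one pending use against that token, then subtract the counts from the stored pending values. A tiny sketch with hypothetical stage results:

from collections import Counter
from typing import Dict, Sequence

def updated_pending(stage_results: Sequence[object], pending: Dict[str, int]) -> Dict[str, int]:
    # A boolean result means registration completed; a string is the unused token.
    counts = Counter(r for r in stage_results if isinstance(r, str))
    return {token: pending[token] - used for token, used in counts.items()}

print(updated_pending(["sekrit", "sekrit", True], {"sekrit": 3}))  # {'sekrit': 1}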
@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
return False
|
||||
|
||||
async def update_profile_in_user_dir(
|
||||
self, user_id: str, display_name: str, avatar_url: str
|
||||
self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
|
||||
) -> None:
|
||||
"""
|
||||
Update or add a user's profile in the user directory.
|
||||
|
|
|
@ -14,25 +14,41 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.types import PersistedEventPosition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RoomsForUser = namedtuple(
|
||||
"RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
|
||||
)
|
||||
|
||||
GetRoomsForUserWithStreamOrdering = namedtuple(
|
||||
"GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
|
||||
)
|
||||
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
|
||||
class RoomsForUser:
|
||||
room_id: str
|
||||
sender: str
|
||||
membership: str
|
||||
event_id: str
|
||||
stream_ordering: int
|
||||
room_version_id: str
|
||||
|
||||
|
||||
# We store this using a namedtuple so that we save about 3x space over using a
|
||||
# dict.
|
||||
ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
|
||||
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
|
||||
class GetRoomsForUserWithStreamOrdering:
|
||||
room_id: str
|
||||
event_pos: PersistedEventPosition
|
||||
|
||||
# "members" points to a truncated list of (user_id, event_id) tuples for users of
|
||||
# a given membership type, suitable for use in calculating heroes for a room.
|
||||
# "count" points to the total numberr of users of a given membership type.
|
||||
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
|
||||
|
||||
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
|
||||
class ProfileInfo:
|
||||
avatar_url: Optional[str]
|
||||
display_name: Optional[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
|
||||
class MemberSummary:
|
||||
# A truncated list of (user_id, event_id) tuples for users of a given
|
||||
# membership type, suitable for use in calculating heroes for a room.
|
||||
members: List[Tuple[str, str]]
|
||||
# The total number of users of a given membership type.
|
||||
count: int
|
||||
|
|
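The namedtuples above become frozen, slotted attrs classes with typed fields; construction stays positional, which is why the membership query can now build RoomsForUser(*r) straight from each database row. A standalone sketch of that pattern (the row values are invented):

import attr

@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class RoomsForUser:
    room_id: str
    sender: str
    membership: str
    event_id: str
    stream_ordering: int
    room_version_id: str

row = ("!room:example.org", "@alice:example.org", "join", "$event", 42, "6")
print(RoomsForUser(*row).room_version_id)  # 6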
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# When updating these values, please leave a short summary of the changes below.
|
||||
|
||||
SCHEMA_VERSION = 63
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2021 Callum Brown
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS registration_tokens(
|
||||
token TEXT NOT NULL, -- The token that can be used for authentication.
|
||||
uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
|
||||
pending INT NOT NULL, -- The number of in progress registrations using this token.
|
||||
completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
|
||||
expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
|
||||
UNIQUE (token)
|
||||
);
|
|
@ -0,0 +1,20 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
-- We may not have deleted all pushers for emails that are no longer linked
|
||||
-- to an account, so we set up a background job to delete them.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(6302, 'remove_deleted_email_pushers', '{}');
|
23
synapse/storage/schema/main/delta/63/03session_store.sql
Normal file
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions(
|
||||
session_type TEXT NOT NULL, -- The unique key for this type of session.
|
||||
session_id TEXT NOT NULL, -- The session ID passed to the client.
|
||||
value TEXT NOT NULL, -- A JSON dictionary to persist.
|
||||
expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds).
|
||||
UNIQUE (session_type, session_id)
|
||||
);
|