Merge remote-tracking branch 'upstream/release-v1.42'

This commit is contained in:
Tulir Asokan 2021-09-01 15:54:42 +03:00
commit c75eed92c9
157 changed files with 7799 additions and 3658 deletions

View file

@ -47,7 +47,7 @@ try:
except ImportError:
pass
__version__ = "1.41.1"
__version__ = "1.42.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -79,6 +79,7 @@ class LoginType:
TERMS = "m.login.terms"
SSO = "m.login.sso"
DUMMY = "m.login.dummy"
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
# This is used in the `type` parameter for /register when called by

View file

@ -147,6 +147,14 @@ class SynapseError(CodeMessageException):
return cs_error(self.msg, self.errcode)
class InvalidAPICallError(SynapseError):
"""You called an existing API endpoint, but fed that endpoint
invalid or incomplete data."""
def __init__(self, msg: str):
super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)
class ProxiedRequestError(SynapseError):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.

View file

@ -37,6 +37,7 @@ from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.logging.context import PreserveLoggingContext
@ -370,6 +371,7 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)

View file

@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
ProfileRestServlet,
)
from synapse.rest.client.push_rule import PushRuleRestServlet
from synapse.rest.client.register import RegisterRestServlet
from synapse.rest.client.register import (
RegisterRestServlet,
RegistrationTokenValidityRestServlet,
)
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.client.voip import VoipRestServlet
@ -115,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import (
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@ -250,6 +254,7 @@ class GenericWorkerSlavedStore(
SearchStore,
TransactionWorkerStore,
LockStore,
SessionStore,
BaseSlavedStore,
):
pass
@ -279,6 +284,7 @@ class GenericWorkerServer(HomeServer):
resource = JsonResource(self, canonical_json=False)
RegisterRestServlet(self).register(resource)
RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)

View file

@ -39,5 +39,8 @@ class ExperimentalConfig(Config):
# MSC3244 (room version capabilities)
self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
# MSC3283 (set displayname, avatar_url and change 3pid capabilities)
self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
# MSC3266 (room summary api)
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)

View file

@ -79,6 +79,11 @@ class RatelimitConfig(Config):
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
self.rc_registration_token_validity = RateLimitConfig(
config.get("rc_registration_token_validity", {}),
defaults={"per_second": 0.1, "burst_count": 5},
)
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@ -143,6 +148,8 @@ class RatelimitConfig(Config):
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
# - one for checking the validity of registration tokens that ratelimits
# requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@ -171,6 +178,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
#rc_registration_token_validity:
# per_second: 0.1
# burst_count: 5
#
#rc_login:
# address:
# per_second: 0.17

View file

@ -33,6 +33,9 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_requires_token = config.get(
"registration_requires_token", False
)
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@ -140,6 +143,9 @@ class RegistrationConfig(Config):
"mechanism by removing the `access_token_lifetime` option."
)
# The fallback template used for authenticating using a registration token
self.registration_token_template = self.read_template("registration_token.html")
# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
@ -199,6 +205,15 @@ class RegistrationConfig(Config):
#
#enable_3pid_lookup: true
# Require users to submit a token during registration.
# Tokens can be managed using the admin API:
# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
# Note that `enable_registration` must be set to `true`.
# Disabling this option will not delete any tokens previously generated.
# Defaults to false. Uncomment the following to require tokens:
#
#registration_requires_token: true
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#

View file

@ -248,6 +248,7 @@ class ServerConfig(Config):
self.use_presence = config.get("use_presence", True)
# Custom presence router module
# This is the legacy way of configuring it (the config should now be put in the modules section)
self.presence_router_module_class = None
self.presence_router_config = None
presence_router_config = presence_config.get("presence_router")
@ -870,20 +871,6 @@ class ServerConfig(Config):
#
#enabled: false
# Presence routers are third-party modules that can specify additional logic
# to where presence updates from users are routed.
#
presence_router:
# The custom module's class. Uncomment to use a custom presence router module.
#
#module: "my_custom_router.PresenceRouter"
# Configuration options of the custom module. Refer to your module's
# documentation for available options.
#
#config:
# example_option: 'something'
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation

View file

@ -11,45 +11,115 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Union,
)
from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.server import HomeServer
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
GET_INTERESTED_USERS_CALLBACK = Callable[
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
]
logger = logging.getLogger(__name__)
def load_legacy_presence_router(hs: "HomeServer"):
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.presence_router_module_class is None:
return
module = hs.config.presence_router_module_class
config = hs.config.presence_router_config
api = hs.get_module_api()
presence_router = module(config=config, module_api=api)
# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
presence_router_methods = {
"get_users_for_states",
"get_interested_users",
}
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
api.register_presence_router_callbacks(**hooks)
class PresenceRouter:
"""
A module that the homeserver will call upon to help route user presence updates to
additional destinations. If a custom presence router is configured, calls will be
passed to that instead.
additional destinations.
"""
ALL_USERS = "ALL"
def __init__(self, hs: "HomeServer"):
self.custom_presence_router = None
# Initially there are no callbacks
self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
# Check whether a custom presence router module has been configured
if hs.config.presence_router_module_class:
# Initialise the module
self.custom_presence_router = hs.config.presence_router_module_class(
config=hs.config.presence_router_config, module_api=hs.get_module_api()
def register_presence_router_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
):
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
if paired_methods.count(None) == 1:
raise RuntimeError(
"PresenceRouter modules must register neither or both of the paired callbacks: "
"[get_users_for_states, get_interested_users]"
)
# Ensure the module has implemented the required methods
required_methods = ["get_users_for_states", "get_interested_users"]
for method_name in required_methods:
if not hasattr(self.custom_presence_router, method_name):
raise Exception(
"PresenceRouter module '%s' must implement all required methods: %s"
% (
hs.config.presence_router_module_class.__name__,
", ".join(required_methods),
)
)
# Append the methods provided to the lists of callbacks
if get_users_for_states is not None:
self._get_users_for_states_callbacks.append(get_users_for_states)
if get_interested_users is not None:
self._get_interested_users_callbacks.append(get_interested_users)
async def get_users_for_states(
self,
@ -66,14 +136,40 @@ class PresenceRouter:
A dictionary of user_id -> set of UserPresenceState, indicating which
presence updates each user should receive.
"""
if self.custom_presence_router is not None:
# Ask the custom module
return await self.custom_presence_router.get_users_for_states(
state_updates=state_updates
)
# Don't include any extra destinations for presence updates
return {}
# Bail out early if we don't have any callbacks to run.
if len(self._get_users_for_states_callbacks) == 0:
# Don't include any extra destinations for presence updates
return {}
users_for_states = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
result = await callback(state_updates)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if not isinstance(result, Dict):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected Dict",
callback,
result,
)
continue
for key, new_entries in result.items():
if not isinstance(new_entries, Set):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected Set",
callback,
new_entries,
)
break
users_for_states.setdefault(key, set()).update(new_entries)
return users_for_states
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
"""
@ -92,12 +188,36 @@ class PresenceRouter:
A set of user IDs to return presence updates for, or ALL_USERS to return all
known updates.
"""
if self.custom_presence_router is not None:
# Ask the custom module for interested users
return await self.custom_presence_router.get_interested_users(
user_id=user_id
)
# A custom presence router is not defined.
# Don't report any additional interested users
return set()
# Bail out early if we don't have any callbacks to run.
if len(self._get_interested_users_callbacks) == 0:
# Don't report any additional interested users
return set()
interested_users = set()
# run all the callbacks for get_interested_users and combine the results
for callback in self._get_interested_users_callbacks:
try:
result = await callback(user_id)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
# If one of the callbacks returns ALL_USERS then we can stop calling all
# of the other callbacks, since the set of interested_users is already as
# large as it can possibly be
if result == PresenceRouter.ALL_USERS:
return PresenceRouter.ALL_USERS
if not isinstance(result, Set):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected set",
callback,
result,
)
continue
# Add the new interested users to the set
interested_users.update(result)
return interested_users

View file

@ -32,6 +32,9 @@ from . import EventBase
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
CANONICALJSON_MAX_INT = (2 ** 53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we
@ -505,7 +508,7 @@ def validate_canonicaljson(value: Any):
* NaN, Infinity, -Infinity
"""
if isinstance(value, int):
if value <= -(2 ** 53) or 2 ** 53 <= value:
if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value:
raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
elif isinstance(value, float):

View file

@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Union
import jsonschema
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.utils import validate_canonicaljson
from synapse.events.utils import (
CANONICALJSON_MAX_INT,
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID
@ -93,6 +99,29 @@ class EventValidator:
400, "Can't create an ACL event that denies the local server"
)
if event.type == EventTypes.PowerLevels:
try:
jsonschema.validate(
instance=event.content,
schema=POWER_LEVELS_SCHEMA,
cls=plValidator,
)
except jsonschema.ValidationError as e:
if e.path:
# example: "users_default": '0' is not of type 'integer'
message = '"' + e.path[-1] + '": ' + e.message # noqa: B306
# jsonschema.ValidationError.message is a valid attribute
else:
# example: '0' is not of type 'integer'
message = e.message # noqa: B306
# jsonschema.ValidationError.message is a valid attribute
raise SynapseError(
code=400,
msg=message,
errcode=Codes.BAD_JSON,
)
def _validate_retention(self, event: EventBase):
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
@ -195,3 +224,47 @@ class EventValidator:
def _ensure_state_event(self, event):
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))
POWER_LEVELS_SCHEMA = {
"type": "object",
"properties": {
"ban": {"$ref": "#/definitions/int"},
"events": {"$ref": "#/definitions/objectOfInts"},
"events_default": {"$ref": "#/definitions/int"},
"invite": {"$ref": "#/definitions/int"},
"kick": {"$ref": "#/definitions/int"},
"notifications": {"$ref": "#/definitions/objectOfInts"},
"redact": {"$ref": "#/definitions/int"},
"state_default": {"$ref": "#/definitions/int"},
"users": {"$ref": "#/definitions/objectOfInts"},
"users_default": {"$ref": "#/definitions/int"},
},
"definitions": {
"int": {
"type": "integer",
"minimum": CANONICALJSON_MIN_INT,
"maximum": CANONICALJSON_MAX_INT,
},
"objectOfInts": {
"type": "object",
"additionalProperties": {"$ref": "#/definitions/int"},
},
},
}
def _create_power_level_validator():
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
# by default jsonschema does not consider a frozendict to be an object so
# we need to use a custom type checker
# https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types
type_checker = validator.TYPE_CHECKER.redefine(
"object", lambda checker, thing: isinstance(thing, collections.abc.Mapping)
)
return jsonschema.validators.extend(validator, type_checker=type_checker)
plValidator = _create_power_level_validator()

View file

@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
UnsupportedRoomVersionError,
)
@ -110,6 +111,23 @@ class FederationClient(FederationBase):
reset_expiry_on_get=False,
)
# A cache for fetching the room hierarchy over federation.
#
# Some stale data over federation is OK, but must be refreshed
# periodically since the local server is in the room.
#
# It is a map of (room ID, suggested-only) -> the response of
# get_room_hierarchy.
self._get_room_hierarchy_cache: ExpiringCache[
Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
] = ExpiringCache(
cache_name="get_room_hierarchy_cache",
clock=self._clock,
max_len=1000,
expiry_ms=5 * 60 * 1000,
reset_expiry_on_get=False,
)
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
@ -558,7 +576,11 @@ class FederationClient(FederationBase):
try:
return await callback(destination)
except InvalidResponseError as e:
except (
RequestSendFailed,
InvalidResponseError,
NotRetryingDestination,
) as e:
logger.warning("Failed to %s via %s: %s", description, destination, e)
except UnsupportedRoomVersionError:
raise
@ -1319,6 +1341,10 @@ class FederationClient(FederationBase):
remote servers
"""
cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only))
if cached_result:
return cached_result
async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
@ -1365,58 +1391,63 @@ class FederationClient(FederationBase):
return room, children, inaccessible_children
try:
return await self._try_destination_list(
result = await self._try_destination_list(
"fetch room hierarchy",
destinations,
send_request,
failover_on_unknown_endpoint=True,
)
except SynapseError as e:
# If an unexpected error occurred, re-raise it.
if e.code != 502:
raise
# Fallback to the old federation API and translate the results if
# no servers implement the new API.
#
# The algorithm below is a bit inefficient as it only attempts to
# get information for the requested room, but the legacy API may
# parse information for the requested room, but the legacy API may
# return additional layers.
if e.code == 502:
legacy_result = await self.get_space_summary(
destinations,
room_id,
suggested_only,
max_rooms_per_space=None,
exclude_rooms=[],
)
legacy_result = await self.get_space_summary(
destinations,
room_id,
suggested_only,
max_rooms_per_space=None,
exclude_rooms=[],
)
# Find the requested room in the response (and remove it).
for _i, room in enumerate(legacy_result.rooms):
if room.get("room_id") == room_id:
break
else:
# The requested room was not returned, nothing we can do.
raise
requested_room = legacy_result.rooms.pop(_i)
# Find the requested room in the response (and remove it).
for _i, room in enumerate(legacy_result.rooms):
if room.get("room_id") == room_id:
break
else:
# The requested room was not returned, nothing we can do.
raise
requested_room = legacy_result.rooms.pop(_i)
# Find any children events of the requested room.
children_events = []
children_room_ids = set()
for event in legacy_result.events:
if event.room_id == room_id:
children_events.append(event.data)
children_room_ids.add(event.state_key)
# And add them under the requested room.
requested_room["children_state"] = children_events
# Find any children events of the requested room.
children_events = []
children_room_ids = set()
for event in legacy_result.events:
if event.room_id == room_id:
children_events.append(event.data)
children_room_ids.add(event.state_key)
# And add them under the requested room.
requested_room["children_state"] = children_events
# Find the children rooms.
children = []
for room in legacy_result.rooms:
if room.get("room_id") in children_room_ids:
children.append(room)
# Find the children rooms.
children = []
for room in legacy_result.rooms:
if room.get("room_id") in children_room_ids:
children.append(room)
# It isn't clear from the response whether some of the rooms are
# not accessible.
return requested_room, children, ()
# It isn't clear from the response whether some of the rooms are
# not accessible.
result = (requested_room, children, ())
raise
# Cache the result to avoid fetching data over federation every time.
self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
return result
@attr.s(frozen=True, slots=True, auto_attribs=True)

View file

@ -110,6 +110,7 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@ -787,7 +788,9 @@ class FederationServer(FederationBase):
event = await self._check_sigs_and_hash(room_version, event)
return await self.handler.on_send_membership_event(origin, event)
return await self._federation_event_handler.on_send_membership_event(
origin, event
)
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
@ -1005,9 +1008,7 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
await self.handler.on_receive_pdu(
origin, event, sent_to_us_directly=True
)
await self._federation_event_handler.on_receive_pdu(origin, event)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction

View file

@ -627,23 +627,28 @@ class AuthHandler(BaseHandler):
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
) -> bool:
) -> None:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
Raises:
LoginError if the stagetype is unknown or the session is missing.
LoginError is raised by check_auth if authentication fails.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
raise LoginError(
400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM
)
return True
return False
if "session" not in authdict:
raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM)
# If authentication fails a LoginError is raised. Otherwise, store
# the successful result.
result = await self.checkers[stagetype].check_auth(authdict, clientip)
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
@ -1459,6 +1464,10 @@ class AuthHandler(BaseHandler):
)
await self.store.user_delete_threepid(user_id, medium, address)
if medium == "email":
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id="m.email", pushkey=address, user_id=user_id
)
return result
async def hash(self, password: str) -> str:
@ -1727,7 +1736,6 @@ class AuthHandler(BaseHandler):
@attr.s(slots=True)
class MacaroonGenerator:
hs = attr.ib()
def generate_guest_access_token(self, user_id: str) -> str:

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler):
limit = 10
async def handle_room(event: RoomsForUser):
d = {
d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (

View file

@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC):
# otherwise would not do).
await self.set_state(UserID.from_string(user_id), state, force_notify=True)
async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
raise NotImplementedError(
"Attempting to check presence on a non-presence worker."
)
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""

View file

@ -56,6 +56,22 @@ login_counter = Counter(
)
def init_counters_for_auth_provider(auth_provider_id: str) -> None:
"""Ensure the prometheus counters for the given auth provider are initialised
This fixes a problem where the counters are not reported for a given auth provider
until the user first logs in/registers.
"""
for is_guest in (True, False):
login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
for shadow_banned in (True, False):
registration_counter.labels(
guest=is_guest,
shadow_banned=shadow_banned,
auth_provider=auth_provider_id,
)
class LoginDict(TypedDict):
device_id: str
access_token: str
@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime
init_counters_for_auth_provider("")
async def check_username(
self,
localpart: str,

View file

@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import (
JsonDict,
Requester,
@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
self.member_linearizer = Linearizer(name="member")
self.member_linearizer: Linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content.pop("displayname", None)
content.pop("avatar_url", None)
if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400,
f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})",
errcode=Codes.BAD_JSON,
)
if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN:
raise SynapseError(
400,
f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})",
errcode=Codes.BAD_JSON,
)
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"

View file

@ -28,12 +28,11 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -76,6 +75,9 @@ class _PaginationSession:
class RoomSummaryHandler:
# A unique key used for pagination sessions for the room hierarchy endpoint.
_PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"
# The time a pagination session remains valid for.
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
@ -87,12 +89,6 @@ class RoomSummaryHandler:
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
# A map of query information to the current pagination state.
#
# TODO Allow for multiple workers to share this data.
# TODO Expire pagination tokens.
self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[
@ -102,21 +98,6 @@ class RoomSummaryHandler:
"get_room_hierarchy",
)
def _expire_pagination_sessions(self):
"""Expire pagination session which are old."""
expire_before = (
self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
)
to_expire = []
for key, value in self._pagination_sessions.items():
if value.creation_time_ms < expire_before:
to_expire.append(key)
for key in to_expire:
logger.debug("Expiring pagination session id %s", key)
del self._pagination_sessions[key]
async def get_space_summary(
self,
requester: str,
@ -327,18 +308,29 @@ class RoomSummaryHandler:
# If this is continuing a previous session, pull the persisted data.
if from_token:
self._expire_pagination_sessions()
try:
pagination_session = await self._store.get_session(
session_type=self._PAGINATION_SESSION_TYPE,
session_id=from_token,
)
except StoreError:
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
pagination_key = _PaginationKey(
requested_room_id, suggested_only, max_depth, from_token
)
if pagination_key not in self._pagination_sessions:
# If the requester, room ID, suggested-only, or max depth were modified
# the session is invalid.
if (
requester != pagination_session["requester"]
or requested_room_id != pagination_session["room_id"]
or suggested_only != pagination_session["suggested_only"]
or max_depth != pagination_session["max_depth"]
):
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
# Load the previous state.
pagination_session = self._pagination_sessions[pagination_key]
room_queue = pagination_session.room_queue
processed_rooms = pagination_session.processed_rooms
room_queue = [
_RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
]
processed_rooms = set(pagination_session["processed_rooms"])
else:
# The queue of rooms to process, the next room is last on the stack.
room_queue = [_RoomQueueEntry(requested_room_id, ())]
@ -456,13 +448,21 @@ class RoomSummaryHandler:
# If there's additional data, generate a pagination token (and persist state).
if room_queue:
next_batch = random_string(24)
result["next_batch"] = next_batch
pagination_key = _PaginationKey(
requested_room_id, suggested_only, max_depth, next_batch
)
self._pagination_sessions[pagination_key] = _PaginationSession(
self._clock.time_msec(), room_queue, processed_rooms
result["next_batch"] = await self._store.create_session(
session_type=self._PAGINATION_SESSION_TYPE,
value={
# Information which must be identical across pagination.
"requester": requester,
"room_id": requested_room_id,
"suggested_only": suggested_only,
"max_depth": max_depth,
# The stored state.
"room_queue": [
attr.astuple(room_entry) for room_entry in room_queue
],
"processed_rooms": list(processed_rooms),
},
expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
)
return result

View file

@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@ -213,6 +214,7 @@ class SsoHandler:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
init_counters_for_auth_provider(p_id)
def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
"""Get the configured identity providers"""

View file

@ -1,5 +1,4 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018, 2019 New Vector Ltd
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -31,6 +30,8 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
@ -86,20 +87,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...]
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig:
user = attr.ib(type=UserID)
filter_collection = attr.ib(type=FilterCollection)
is_guest = attr.ib(type=bool)
request_key = attr.ib(type=SyncRequestKey)
device_id = attr.ib(type=Optional[str])
user: UserID
filter_collection: FilterCollection
is_guest: bool
request_key: SyncRequestKey
device_id: Optional[str]
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TimelineBatch:
prev_batch = attr.ib(type=StreamToken)
events = attr.ib(type=List[EventBase])
limited = attr.ib(type=bool)
prev_batch: StreamToken
events: List[EventBase]
limited: bool
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@ -113,16 +114,16 @@ class TimelineBatch:
# if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as
# well anyway.
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class JoinedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
state = attr.ib(type=StateMap[EventBase])
ephemeral = attr.ib(type=List[JsonDict])
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
room_id: str
timeline: TimelineBatch
state: StateMap[EventBase]
ephemeral: List[JsonDict]
account_data: List[JsonDict]
unread_notifications: JsonDict
summary: Optional[JsonDict]
unread_count: int
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@ -138,12 +139,12 @@ class JoinedSyncResult:
)
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ArchivedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
state = attr.ib(type=StateMap[EventBase])
account_data = attr.ib(type=List[JsonDict])
room_id: str
timeline: TimelineBatch
state: StateMap[EventBase]
account_data: List[JsonDict]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@ -152,37 +153,37 @@ class ArchivedSyncResult:
return bool(self.timeline or self.state or self.account_data)
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class InvitedSyncResult:
room_id = attr.ib(type=str)
invite = attr.ib(type=EventBase)
room_id: str
invite: EventBase
def __bool__(self) -> bool:
"""Invited rooms should always be reported to the client"""
return True
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class KnockedSyncResult:
room_id = attr.ib(type=str)
knock = attr.ib(type=EventBase)
room_id: str
knock: EventBase
def __bool__(self) -> bool:
"""Knocked rooms should always be reported to the client"""
return True
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class GroupsSyncResult:
join = attr.ib(type=JsonDict)
invite = attr.ib(type=JsonDict)
leave = attr.ib(type=JsonDict)
join: JsonDict
invite: JsonDict
leave: JsonDict
def __bool__(self) -> bool:
return bool(self.join or self.invite or self.leave)
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
@ -190,27 +191,27 @@ class DeviceLists:
left: List of user_ids whose devices we no longer track
"""
changed = attr.ib(type=Collection[str])
left = attr.ib(type=Collection[str])
changed: Collection[str]
left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
and left room IDs since last sync.
"""
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult])
knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str])
room_entries: List["RoomSyncResultBuilder"]
invited: List[InvitedSyncResult]
knocked: List[KnockedSyncResult]
newly_joined_rooms: List[str]
newly_left_rooms: List[str]
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncResult:
"""
Attributes:
@ -230,18 +231,18 @@ class SyncResult:
groups: Group updates, if any
"""
next_batch = attr.ib(type=StreamToken)
presence = attr.ib(type=List[JsonDict])
account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult])
knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
device_one_time_keys_count = attr.ib(type=JsonDict)
device_unused_fallback_key_types = attr.ib(type=List[str])
groups = attr.ib(type=Optional[GroupsSyncResult])
next_batch: StreamToken
presence: List[UserPresenceState]
account_data: List[JsonDict]
joined: List[JoinedSyncResult]
invited: List[InvitedSyncResult]
knocked: List[KnockedSyncResult]
archived: List[ArchivedSyncResult]
to_device: List[JsonDict]
device_lists: DeviceLists
device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str]
groups: Optional[GroupsSyncResult]
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@ -701,7 +702,7 @@ class SyncHandler:
name_id = state_ids.get((EventTypes.Name, ""))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
summary = {}
summary: JsonDict = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
@ -1842,6 +1843,9 @@ class SyncHandler:
knocked = []
for event in room_list:
if event.room_version_id not in KNOWN_ROOM_VERSIONS:
continue
if event.membership == Membership.JOIN:
room_entries.append(
RoomSyncResultBuilder(
@ -2075,21 +2079,23 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
for room_id, event_pos in joined_rooms:
if not event_pos.persisted_after(room_key):
joined_room_ids.add(room_id)
for joined_room in joined_rooms:
if not joined_room.event_pos.persisted_after(room_key):
joined_room_ids.add(joined_room.room_id)
continue
logger.info("User joined room after current token: %s", room_id)
logger.info("User joined room after current token: %s", joined_room.room_id)
extrems = (
await self.store.get_forward_extremities_for_room_at_stream_ordering(
room_id, event_pos.stream
joined_room.room_id, joined_room.event_pos.stream
)
)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
users_in_room = await self.state.get_current_users_in_room(
joined_room.room_id, extrems
)
if user_id in users_in_room:
joined_room_ids.add(room_id)
joined_room_ids.add(joined_room.room_id)
return frozenset(joined_room_ids)
@ -2159,7 +2165,7 @@ def _calculate_state(
return {event_id_to_key[e]: e for e in state_ids}
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class SyncResultBuilder:
"""Used to help build up a new SyncResult for a user
@ -2171,33 +2177,33 @@ class SyncResultBuilder:
joined_room_ids: List of rooms the user is joined to
# The following mirror the fields in a sync response
presence (list)
account_data (list)
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None)
to_device (list)
presence
account_data
joined
invited
knocked
archived
groups
to_device
"""
sync_config = attr.ib(type=SyncConfig)
full_state = attr.ib(type=bool)
since_token = attr.ib(type=Optional[StreamToken])
now_token = attr.ib(type=StreamToken)
joined_room_ids = attr.ib(type=FrozenSet[str])
sync_config: SyncConfig
full_state: bool
since_token: Optional[StreamToken]
now_token: StreamToken
joined_room_ids: FrozenSet[str]
presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
presence: List[UserPresenceState] = attr.Factory(list)
account_data: List[JsonDict] = attr.Factory(list)
joined: List[JoinedSyncResult] = attr.Factory(list)
invited: List[InvitedSyncResult] = attr.Factory(list)
knocked: List[KnockedSyncResult] = attr.Factory(list)
archived: List[ArchivedSyncResult] = attr.Factory(list)
groups: Optional[GroupsSyncResult] = None
to_device: List[JsonDict] = attr.Factory(list)
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class RoomSyncResultBuilder:
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
@ -2213,10 +2219,10 @@ class RoomSyncResultBuilder:
upto_token: Latest point to return events from.
"""
room_id = attr.ib(type=str)
rtype = attr.ib(type=str)
events = attr.ib(type=Optional[List[EventBase]])
newly_joined = attr.ib(type=bool)
full_state = attr.ib(type=bool)
since_token = attr.ib(type=Optional[StreamToken])
upto_token = attr.ib(type=StreamToken)
room_id: str
rtype: str
events: Optional[List[EventBase]]
newly_joined: bool
full_state: bool
since_token: Optional[StreamToken]
upto_token: StreamToken

View file

@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"
# used during registration to store the registration token used (if required) so that:
# - we can prevent a token being used twice by one session
# - we can 'use up' the token after registration has successfully completed
REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"

View file

@ -49,7 +49,7 @@ class UserInteractiveAuthChecker:
clientip: The IP address of the client.
Raises:
SynapseError if authentication failed
LoginError if authentication failed.
Returns:
The result of authentication (to pass back to the client?)
@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
)
if resp_body["success"]:
return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
raise LoginError(
401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
)
class _BaseThreepidAuthChecker:
@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker:
raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
raise LoginError(
401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
)
if threepid["medium"] != medium:
raise LoginError(
@ -237,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict)
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.REGISTRATION_TOKEN
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self._enabled = bool(hs.config.registration_requires_token)
self.store = hs.get_datastore()
def is_enabled(self) -> bool:
return self._enabled
async def check_auth(self, authdict: dict, clientip: str) -> Any:
if "token" not in authdict:
raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
if not isinstance(authdict["token"], str):
raise LoginError(
400, "Registration token must be a string", Codes.INVALID_PARAM
)
if "session" not in authdict:
raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
# Get these here to avoid cyclic dependencies
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
auth_handler = self.hs.get_auth_handler()
session = authdict["session"]
token = authdict["token"]
# If the LoginType.REGISTRATION_TOKEN stage has already been completed,
# return early to avoid incrementing `pending` again.
stored_token = await auth_handler.get_session_data(
session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
)
if stored_token:
if token != stored_token:
raise LoginError(
400, "Registration token has changed", Codes.INVALID_PARAM
)
else:
return token
if await self.store.registration_token_is_valid(token):
# Increment pending counter, so that if token has limited uses it
# can't be used up by someone else in the meantime.
await self.store.set_registration_token_pending(token)
# Store the token in the UIA session, so that once registration
# is complete `completed` can be incremented.
await auth_handler.set_session_data(
session,
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
token,
)
# The token will be stored as the result of the authentication stage
# in ui_auth_sessions_credentials. This allows the pending counter
# for tokens to be decremented when expired sessions are deleted.
return token
else:
raise LoginError(
401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
)
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""

View file

@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource
if TYPE_CHECKING:
from synapse.server import HomeServer
class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""
def __init__(self, hs, handler):
def __init__(self, hs: "HomeServer", handler):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
``request.write()``, and call ``request.finish()``.
Args:
hs (synapse.server.HomeServer): homeserver
hs: homeserver
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
super().__init__()
self._handler = handler
def _async_render(self, request):
def _async_render(self, request: Request):
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)

View file

@ -16,7 +16,7 @@
import logging
import random
import time
from typing import List
from typing import Callable, Dict, List
import attr
@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
SERVER_CACHE: Dict[bytes, List["Server"]] = {}
@attr.s(slots=True, frozen=True)
@attr.s(auto_attribs=True, slots=True, frozen=True)
class Server:
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
host (bytes): target hostname
port (int):
priority (int):
weight (int):
expires (int): when the cache should expire this record - in *seconds* since
host: target hostname
port:
priority:
weight:
expires: when the cache should expire this record - in *seconds* since
the epoch
"""
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
weight = attr.ib(default=0)
expires = attr.ib(default=0)
host: bytes
port: int
priority: int = 0
weight: int = 0
expires: int = 0
def _sort_server_list(server_list):
def _sort_server_list(server_list: List[Server]) -> List[Server]:
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
priority_map = {}
priority_map: Dict[int, List[Server]] = {}
for server in server_list:
priority_map.setdefault(server.priority, []).append(server)
@ -103,11 +103,16 @@ class SrvResolver:
Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
cache: cache object
get_time: clock implementation. Should return seconds since the epoch
"""
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
def __init__(
self,
dns_client=client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
@ -116,7 +121,7 @@ class SrvResolver:
"""Look up a SRV record
Args:
service_name (bytes): record to look up
service_name: record to look up
Returns:
a list of the SRV records, or an empty list if none found
@ -158,7 +163,7 @@ class SrvResolver:
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
raise ConnectError("Service %s unavailable" % service_name)
raise ConnectError(f"Service {service_name!r} unavailable")
servers = []

View file

@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
raise ValueError(f"Invalid URI {uri!r}")
parsed_uri = URI.fromBytes(uri)
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
request_path = parsed_uri.originForm
should_skip_proxy = False
@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
)
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
pool_key = "http-proxy"
endpoint = self.http_proxy_endpoint
request_path = uri
elif (

View file

@ -32,6 +32,7 @@ from twisted.internet import defer
from twisted.web.resource import IResource
from synapse.events import EventBase
from synapse.events.presence_router import PresenceRouter
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@ -57,6 +58,8 @@ This package defines the 'stable' API which can be used by extension modules whi
are loaded into Synapse.
"""
PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS
__all__ = [
"errors",
"make_deferred_yieldable",
@ -70,6 +73,7 @@ __all__ = [
"DirectServeHtmlResource",
"DirectServeJsonResource",
"ModuleApi",
"PRESENCE_ALL_USERS",
]
logger = logging.getLogger(__name__)
@ -112,6 +116,7 @@ class ModuleApi:
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._presence_router = hs.get_presence_router()
#################################################################################
# The following methods should only be called during the module's initialisation.
@ -131,6 +136,11 @@ class ModuleApi:
"""Registers callbacks for third party event rules capabilities."""
return self._third_party_event_rules.register_third_party_rules_callbacks
@property
def register_presence_router_callbacks(self):
"""Registers callbacks for presence router capabilities."""
return self._presence_router.register_presence_router_callbacks
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.

View file

@ -48,7 +48,8 @@ logger = logging.getLogger(__name__)
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
"jsonschema>=2.5.1",
# we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
"jsonschema>=3.0.0",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.4.0",

View file

@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
self.federation_handler = hs.get_federation_handler()
self.federation_event_handler = hs.get_federation_event_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
max_stream_id = await self.federation_event_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
)

View file

@ -16,6 +16,9 @@ function captchaDone() {
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.

View file

@ -0,0 +1,23 @@
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
</head>
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Please enter a registration token.
</p>
<input type="hidden" name="session" value="{{ session }}" />
<input type="text" name="token" />
<input type="submit" value="Authenticate" />
</div>
</form>
</body>
</html>

View file

@ -8,6 +8,9 @@
<body>
<form id="registrationForm" method="post" action="{{ myurl }}">
<div>
{% if error is defined %}
<p class="error"><strong>Error: {{ error }}</strong></p>
{% endif %}
<p>
Please click the button below if you agree to the
<a href="{{ terms_url }}">privacy policy of this homeserver.</a>

View file

@ -36,7 +36,11 @@ from synapse.rest.admin.event_reports import (
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet,
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
@ -47,7 +51,6 @@ from synapse.rest.admin.rooms import (
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
@ -220,8 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
@ -241,6 +242,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
ListRegistrationTokensRestServlet(hs).register(http_server)
NewRegistrationTokenRestServlet(hs).register(http_server)
RegistrationTokenRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
if hs.config.worker_app is None:
SendServerNoticeServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
@ -253,7 +261,6 @@ def register_servlets_for_client_rest_resource(
PurgeHistoryRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)

View file

@ -1,58 +0,0 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
"""Servlet which will remove all trace of a room from the database
POST /_synapse/admin/v1/purge_room
{
"room_id": "!room:id"
}
returns:
{}
"""
PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("room_id",))
await self.pagination_handler.purge_room(body["room_id"])
return 200, {}

View file

@ -0,0 +1,321 @@
# Copyright 2021 Callum Brown
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import string
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_boolean,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ListRegistrationTokensRestServlet(RestServlet):
"""List registration tokens.
To list all tokens:
GET /_synapse/admin/v1/registration_tokens
200 OK
{
"registration_tokens": [
{
"token": "abcd",
"uses_allowed": 3,
"pending": 0,
"completed": 1,
"expiry_time": null
},
{
"token": "wxyz",
"uses_allowed": null,
"pending": 0,
"completed": 9,
"expiry_time": 1625394937000
}
]
}
The optional query parameter `valid` can be used to filter the response.
If it is `true`, only valid tokens are returned. If it is `false`, only
tokens that have expired or have had all uses exhausted are returned.
If it is omitted, all tokens are returned regardless of validity.
"""
PATTERNS = admin_patterns("/registration_tokens$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
valid = parse_boolean(request, "valid")
token_list = await self.store.get_registration_tokens(valid)
return 200, {"registration_tokens": token_list}
class NewRegistrationTokenRestServlet(RestServlet):
"""Create a new registration token.
For example, to create a token specifying some fields:
POST /_synapse/admin/v1/registration_tokens/new
{
"token": "defg",
"uses_allowed": 1
}
200 OK
{
"token": "defg",
"uses_allowed": 1,
"pending": 0,
"completed": 0,
"expiry_time": null
}
Defaults are used for any fields not specified.
"""
PATTERNS = admin_patterns("/registration_tokens/new$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
# A string of all the characters allowed to be in a registration_token
self.allowed_chars = string.ascii_letters + string.digits + "-_"
self.allowed_chars_set = set(self.allowed_chars)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
if "token" in body:
token = body["token"]
if not isinstance(token, str):
raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
if not (0 < len(token) <= 64):
raise SynapseError(
400,
"token must not be empty and must not be longer than 64 characters",
Codes.INVALID_PARAM,
)
if not set(token).issubset(self.allowed_chars_set):
raise SynapseError(
400,
"token must consist only of characters matched by the regex [A-Za-z0-9-_]",
Codes.INVALID_PARAM,
)
else:
# Get length of token to generate (default is 16)
length = body.get("length", 16)
if not isinstance(length, int):
raise SynapseError(
400, "length must be an integer", Codes.INVALID_PARAM
)
if not (0 < length <= 64):
raise SynapseError(
400,
"length must be greater than zero and not greater than 64",
Codes.INVALID_PARAM,
)
# Generate token
token = await self.store.generate_registration_token(
length, self.allowed_chars
)
uses_allowed = body.get("uses_allowed", None)
if not (
uses_allowed is None
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
):
raise SynapseError(
400,
"uses_allowed must be a non-negative integer or null",
Codes.INVALID_PARAM,
)
expiry_time = body.get("expiry_time", None)
if not isinstance(expiry_time, (int, type(None))):
raise SynapseError(
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
)
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
raise SynapseError(
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
)
created = await self.store.create_registration_token(
token, uses_allowed, expiry_time
)
if not created:
raise SynapseError(
400, f"Token already exists: {token}", Codes.INVALID_PARAM
)
resp = {
"token": token,
"uses_allowed": uses_allowed,
"pending": 0,
"completed": 0,
"expiry_time": expiry_time,
}
return 200, resp
class RegistrationTokenRestServlet(RestServlet):
"""Retrieve, update, or delete the given token.
For example,
to retrieve a token:
GET /_synapse/admin/v1/registration_tokens/abcd
200 OK
{
"token": "abcd",
"uses_allowed": 3,
"pending": 0,
"completed": 1,
"expiry_time": null
}
to update a token:
PUT /_synapse/admin/v1/registration_tokens/defg
{
"uses_allowed": 5,
"expiry_time": 4781243146000
}
200 OK
{
"token": "defg",
"uses_allowed": 5,
"pending": 0,
"completed": 0,
"expiry_time": 4781243146000
}
to delete a token:
DELETE /_synapse/admin/v1/registration_tokens/wxyz
200 OK
{}
"""
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
"""Retrieve a registration token."""
await assert_requester_is_admin(self.auth, request)
token_info = await self.store.get_one_registration_token(token)
# If no result return a 404
if token_info is None:
raise NotFoundError(f"No such registration token: {token}")
return 200, token_info
async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
"""Update a registration token."""
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
new_attributes = {}
# Only add uses_allowed to new_attributes if it is present and valid
if "uses_allowed" in body:
uses_allowed = body["uses_allowed"]
if not (
uses_allowed is None
or (isinstance(uses_allowed, int) and uses_allowed >= 0)
):
raise SynapseError(
400,
"uses_allowed must be a non-negative integer or null",
Codes.INVALID_PARAM,
)
new_attributes["uses_allowed"] = uses_allowed
if "expiry_time" in body:
expiry_time = body["expiry_time"]
if not isinstance(expiry_time, (int, type(None))):
raise SynapseError(
400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
)
if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
raise SynapseError(
400, "expiry_time must not be in the past", Codes.INVALID_PARAM
)
new_attributes["expiry_time"] = expiry_time
if len(new_attributes) == 0:
# Nothing to update, get token info to return
token_info = await self.store.get_one_registration_token(token)
else:
token_info = await self.store.update_registration_token(
token, new_attributes
)
# If no result return a 404
if token_info is None:
raise NotFoundError(f"No such registration token: {token}")
return 200, token_info
async def on_DELETE(
self, request: SynapseRequest, token: str
) -> Tuple[int, JsonDict]:
"""Delete a registration token."""
await assert_requester_is_admin(self.auth, request)
if await self.store.delete_registration_token(token):
return 200, {}
raise NotFoundError(f"No such registration token: {token}")

View file

@ -46,41 +46,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
to a new room created by `new_room_user_id` and kicked users will be auto
joined to the new room.
"""
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content["new_room_user_id"],
new_room_name=content.get("room_name"),
message=content.get("message"),
requester_user_id=requester.user.to_string(),
block=True,
)
return (200, ret)
class DeleteRoomRestServlet(RestServlet):
"""Delete a room from server.

View file

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.server_notices_manager = hs.get_server_notices_manager()
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer):
@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet):
# We grab the server notices manager here as its initialisation has a check for worker processes,
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
# admin api).
if not self.hs.get_server_notices_manager().is_enabled():
if not self.server_notices_manager.is_enabled():
raise SynapseError(400, "Server notices are not enabled on this server")
user_id = body["user_id"]
UserID.from_string(user_id)
if not self.hs.is_mine_id(user_id):
target_user = UserID.from_string(body["user_id"])
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Server notices can only be sent to local users")
event = await self.hs.get_server_notices_manager().send_notice(
user_id=body["user_id"],
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
event = await self.server_notices_manager.send_notice(
user_id=target_user.to_string(),
type=event_type,
state_key=state_key,
event_content=body["content"],
txn_id=txn_id,
)
return 200, {"event_id": event.event_id}

View file

@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
# convert into List[Tuple[str, str]]
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None:
new_external_ids = []
for external_id in external_ids:
new_external_ids.append(
(external_id["auth_provider"], external_id["external_id"])
)
new_external_ids = {
(external_id["auth_provider"], external_id["external_id"])
for external_id in external_ids
}
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
if threepids is not None:
new_threepids = {
(threepid["medium"], threepid["address"]) for threepid in threepids
}
if user: # modify user
if "displayname" in body:
@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
)
if threepids is not None:
# remove old threepids from user
old_threepids = await self.store.user_get_threepids(user_id)
for threepid in old_threepids:
# get changed threepids (added and removed)
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
cur_threepids = {
(threepid["medium"], threepid["address"])
for threepid in await self.store.user_get_threepids(user_id)
}
add_threepids = new_threepids - cur_threepids
del_threepids = cur_threepids - new_threepids
# remove old threepids
for medium, address in del_threepids:
try:
await self.auth_handler.delete_threepid(
user_id, threepid["medium"], threepid["address"], None
user_id, medium, address, None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")
# add new threepids to user
# add new threepids
current_time = self.hs.get_clock().time_msec()
for threepid in threepids:
for medium, address in add_threepids:
await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time
user_id, medium, address, current_time
)
if external_ids is not None:
# get changed external_ids (added and removed)
cur_external_ids = await self.store.get_external_ids_by_user(user_id)
add_external_ids = set(new_external_ids) - set(cur_external_ids)
del_external_ids = set(cur_external_ids) - set(new_external_ids)
cur_external_ids = set(
await self.store.get_external_ids_by_user(user_id)
)
add_external_ids = new_external_ids - cur_external_ids
del_external_ids = cur_external_ids - new_external_ids
# remove old external_ids
for auth_provider, external_id in del_external_ids:
@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
if threepids is not None:
current_time = self.hs.get_clock().time_msec()
for threepid in threepids:
for medium, address in new_threepids:
await self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], current_time
user_id, medium, address, current_time
)
if (
self.hs.config.email_enable_notifs
@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
device_display_name=address,
pushkey=address,
lang=None, # We don't know a user's language here
data={},
)

View file

@ -13,24 +13,27 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet
from twisted.web.server import Request
from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
hs.config.account_validity.account_validity_invalid_token_template
)
async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
async def on_GET(self, request: Request) -> None:
renewal_token = parse_string(request, "token", required=True)
(
token_valid,
token_stale,
expiration_ts,
) = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
) = await self.account_activity_handler.renew_account(renewal_token)
if token_valid:
status_code = 200
@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
hs.config.account_validity.account_validity_renew_by_email_enabled
)
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)
@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
return 200, {}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)

View file

@ -15,11 +15,14 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import respond_with_html
from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from ._base import client_patterns
@ -46,9 +49,10 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
async def on_GET(self, request, stagetype):
async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@ -74,6 +78,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
elif stagetype == LoginType.REGISTRATION_TOKEN:
html = self.registration_token_template.render(
session=session,
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
)
else:
raise SynapseError(404, "Unknown auth stage type")
@ -81,7 +91,7 @@ class AuthRestServlet(RestServlet):
respond_with_html(request, 200, html)
return None
async def on_POST(self, request, stagetype):
async def on_POST(self, request: Request, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
@ -95,29 +105,32 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
html = self.success_template.render()
else:
try:
await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientIP()
)
except LoginError as e:
# Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.recaptcha_public_key,
error=e.msg,
)
else:
# No LoginError was raised, so authentication was successful
html = self.success_template.render()
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientIP()
)
if success:
html = self.success_template.render()
else:
try:
await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientIP()
)
except LoginError as e:
# Authentication failed, let user try again
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
@ -127,10 +140,33 @@ class AuthRestServlet(RestServlet):
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
error=e.msg,
)
else:
# No LoginError was raised, so authentication was successful
html = self.success_template.render()
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
elif stagetype == LoginType.REGISTRATION_TOKEN:
token = parse_string(request, "token", required=True)
authdict = {"session": session, "token": token}
try:
await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
)
except LoginError as e:
html = self.registration_token_template.render(
session=session,
myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
error=e.msg,
)
else:
html = self.success_template.render()
else:
raise SynapseError(404, "Unknown auth stage type")
@ -139,5 +175,5 @@ class AuthRestServlet(RestServlet):
return None
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server)

View file

@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
@ -61,8 +62,19 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES
if self.config.experimental.msc3283_enabled:
response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
"enabled": self.config.enable_set_displayname
}
response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
"enabled": self.config.enable_set_avatar_url
}
response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
"enabled": self.config.enable_3pid_changes
}
return 200, response
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)

View file

@ -14,34 +14,36 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
async def on_GET(self, request, device_id):
async def on_GET(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet):
return 200, device
@interactive_auth_handler
async def on_DELETE(self, request, device_id):
async def on_DELETE(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}
async def on_PUT(self, request, device_id):
async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
async def on_GET(self, request: SynapseRequest):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device(
requester.user.to_string()
@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet):
else:
raise errors.NotFoundError("No dehydrated device available")
async def on_PUT(self, request: SynapseRequest):
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request)
@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
async def on_POST(self, request: SynapseRequest):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request)
@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
return (200, result)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)

View file

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import (
AuthError,
@ -22,14 +24,19 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import RoomAlias
from synapse.types import JsonDict, RoomAlias
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)
res = await self.directory_handler.get_association(room_alias)
res = await self.directory_handler.get_association(room_alias_obj)
return 200, res
async def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
async def on_PUT(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
)
logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias.to_string())
logger.debug("Got room name: %s", room_alias_obj.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
requester, room_alias, room_id, servers
requester, room_alias_obj, room_id, servers
)
return 200, {}
async def on_DELETE(self, request, room_alias):
async def on_DELETE(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)
try:
service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association(
service, room_alias
service, room_alias_obj
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
room_alias.to_string(),
room_alias_obj.to_string(),
)
return 200, {}
except InvalidClientCredentialsError:
@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
user = requester.user
room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_association(requester, room_alias)
await self.directory_handler.delete_association(requester, room_alias_obj)
logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
"User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
)
return 200, {}
@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
async def on_PUT(self, request, room_id):
async def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
return 200, {}
async def on_DELETE(self, request, room_id):
async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.edit_published_room_list(
@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id):
async def on_PUT(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
return self._edit(request, network_id, room_id, visibility)
return await self._edit(request, network_id, room_id, visibility)
def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
async def on_DELETE(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
return await self._edit(request, network_id, room_id, "private")
async def _edit(self, request, network_id, room_id, visibility):
async def _edit(
self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(

View file

@ -14,11 +14,18 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
room_id = None
args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest:
if b"room_id" not in request.args:
if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param")
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii")
room_id = parse_string(request, "room_id")
pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args:
if b"timeout" in args:
try:
timeout = int(request.args[b"timeout"][0])
timeout = int(args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = b"raw" not in request.args
as_client_event = b"raw" not in args
chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
@ -70,25 +76,27 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request, event_id):
async def on_GET(
self, request: SynapseRequest, event_id: str
) -> Tuple[int, Union[str, JsonDict]]:
requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
result = await self._event_serializer.serialize_event(event, time_now)
return 200, result
else:
return 404, "Event not found."
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)

View file

@ -13,26 +13,34 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from ._base import client_patterns, set_timeline_upper_limit
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
async def on_GET(self, request, user_id, filter_id):
async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet):
raise AuthError(403, "Can only get filters for local users")
try:
filter_id = int(filter_id)
filter_id_int = int(filter_id)
except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
user_localpart=target_user.localpart, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet):
return 200, {"filter_id": str(filter_id)}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)

View file

@ -26,6 +26,7 @@ from synapse.api.constants import (
)
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)

View file

@ -12,25 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
as_client_event = b"raw" not in args
pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet):
return 200, content
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
InitialSyncRestServlet(hs).register(http_server)

View file

@ -15,19 +15,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.types import StreamToken
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -59,18 +65,16 @@ class KeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
async def on_POST(self, request, device_id):
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@ -148,21 +152,30 @@ class KeyQueryServlet(RestServlet):
PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
device_keys = body.get("device_keys")
if not isinstance(device_keys, dict):
raise InvalidAPICallError("'device_keys' must be a JSON object")
def is_list_of_strings(values: Any) -> bool:
return isinstance(values, list) and all(isinstance(v, str) for v in values)
if any(not is_list_of_strings(keys) for keys in device_keys.values()):
raise InvalidAPICallError(
"'device_keys' values must be a list of strings",
)
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
@ -181,17 +194,13 @@ class KeyChangesServlet(RestServlet):
PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from", required=True)
@ -231,12 +240,12 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
@ -255,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@ -267,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@ -315,16 +320,12 @@ class SignaturesUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/signatures/upload$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@ -335,7 +336,7 @@ class SignaturesUploadServlet(RestServlet):
return 200, result
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)

View file

@ -19,6 +19,7 @@ from twisted.web.server import Request
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet):
)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KnockRoomAliasServlet(hs).register(http_server)

View file

@ -1,4 +1,4 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
import logging
import re
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
@ -104,7 +104,13 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_account.burst_count,
)
def on_GET(self, request: SynapseRequest):
# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
# The reason for this is to ensure that the auth_provider_ids are registered
# with SsoHandler, which in turn ensures that the login/registration prometheus
# counters are initialised for the auth_provider_ids.
_load_sso_handlers(hs)
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@ -151,7 +157,7 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
@ -211,7 +217,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
):
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@ -461,10 +467,7 @@ class RefreshTokenServlet(RestServlet):
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime
async def on_POST(
self,
request: SynapseRequest,
):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"])
@ -499,12 +502,7 @@ class SsoRedirectServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
if hs.config.cas_enabled:
hs.get_cas_handler()
if hs.config.saml2_enabled:
hs.get_saml_handler()
if hs.config.oidc_enabled:
hs.get_oidc_handler()
_load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
@ -569,7 +567,7 @@ class SsoRedirectServlet(RestServlet):
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._cas_handler = hs.get_cas_handler()
@ -591,10 +589,26 @@ class CasTicketServlet(RestServlet):
)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
def _load_sso_handlers(hs: "HomeServer") -> None:
"""Ensure that the SSO handlers are loaded, if they are enabled by configuration.
This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
with the main SsoHandler.
It's safe to call this multiple times.
"""
if hs.config.cas.cas_enabled:
hs.get_cas_handler()
if hs.config.saml2.saml2_enabled:
hs.get_saml_handler()
if hs.config.oidc.oidc_enabled:
hs.get_oidc_handler()

View file

@ -13,9 +13,16 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -23,13 +30,13 @@ logger = logging.getLogger(__name__)
class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet):
class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet):
return 200, {}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server)

View file

@ -13,26 +13,33 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet):
return 200, {"notifications": returned_push_actions, "next_token": next_token}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
NotificationsServlet(hs).register(http_server)

View file

@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet):
)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
IdTokenServlet(hs).register(http_server)

View file

@ -13,28 +13,32 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
def on_GET(self, request):
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy:
return (200, {})
@ -53,5 +57,5 @@ class PasswordPolicyServlet(RestServlet):
return (200, policy)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PasswordPolicyServlet(hs).register(http_server)

View file

@ -15,12 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/<paths>
"""
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -28,7 +34,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet):
self._use_presence = hs.config.server.use_presence
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@ -53,13 +61,15 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(
result = format_user_presence_state(
state, self.clock.time_msec(), include_user_id=False
)
return 200, state
return 200, result
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@ -91,5 +101,5 @@ class PresenceStatusRestServlet(RestServlet):
return 200, {}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PresenceStatusRestServlet(hs).register(http_server)

View file

@ -14,22 +14,31 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet):
return 200, ret
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)

View file

@ -13,17 +13,23 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import respond_with_html_bytes
from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
return None
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)

View file

@ -13,27 +13,36 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
async def on_POST(self, request, room_id):
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.presence_handler.bump_presence_active_time(requester.user)
@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
return 200, {}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReadMarkerRestServlet(hs).register(http_server)

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hmac
import logging
import random
from typing import List, Union
@ -28,6 +27,7 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
@ -59,18 +59,6 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
# because the timing attack is so obscured by all the other code here it's
# unlikely to make much difference
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
def compare_digest(a, b):
return a == b
logger = logging.getLogger(__name__)
@ -379,6 +367,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
return 200, {"available": True}
class RegistrationTokenValidityRestServlet(RestServlet):
"""Check the validity of a registration token.
Example:
GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
200 OK
{
"valid": true
}
"""
PATTERNS = client_patterns(
f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
releases=(),
unstable=True,
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
self.ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
)
async def on_GET(self, request):
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
token = parse_string(request, "token", required=True)
valid = await self.store.registration_token_is_valid(token)
return 200, {"valid": valid}
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
@ -686,6 +723,22 @@ class RegisterRestServlet(RestServlet):
)
if registered:
# Check if a token was used to authenticate registration
registration_token = await self.auth_handler.get_session_data(
session_id,
UIAuthSessionDataConstants.REGISTRATION_TOKEN,
)
if registration_token:
# Increment the `completed` counter for the token
await self.store.use_registration_token(registration_token)
# Indicate that the token has been successfully used so that
# pending is not decremented again when expiring old UIA sessions.
await self.store.mark_ui_auth_stage_complete(
session_id,
LoginType.REGISTRATION_TOKEN,
True,
)
await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
@ -868,6 +921,11 @@ def _calculate_registration_flows(
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
# Prepend registration token to all flows if we're requiring a token
if config.registration_requires_token:
for flow in flows:
flow.insert(0, LoginType.REGISTRATION_TOKEN)
return flows
@ -876,4 +934,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)

View file

@ -13,18 +13,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import stringutils
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
}
Creates a new room and shuts down the old one. Returns the ID of the new room.
Args:
hs (synapse.server.HomeServer):
"""
PATTERNS = client_patterns(
@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/upgrade$"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
async def on_POST(self, request, room_id):
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
return 200, ret
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomUpgradeRestServlet(hs).register(http_server)

View file

@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
releases=(), # This is an unstable feature
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not self.user_directory_active:
raise SynapseError(
@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
return 200, {"joined": list(rooms)}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server)

View file

@ -14,17 +14,26 @@
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.handlers.sync import (
ArchivedSyncResult,
InvitedSyncResult,
JoinedSyncResult,
KnockedSyncResult,
SyncConfig,
SyncResult,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
return 200, {}
time_now = self.clock.time_msec()
# We know that the the requester has an access token since appservices
# cannot use sync.
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
async def encode_response(self, time_now, sync_result, access_token_id, filter):
async def encode_response(
self,
time_now: int,
sync_result: SyncResult,
access_token_id: Optional[int],
filter: FilterCollection,
) -> JsonDict:
logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
logger.debug("building sync response dict")
response: dict = defaultdict(dict)
response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.account_data:
@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
if archived:
response["rooms"][Membership.LEAVE] = archived
# By the time we get here groups is no longer optional.
assert sync_result.groups is not None
if sync_result.groups.join:
response["groups"][Membership.JOIN] = sync_result.groups.join
if sync_result.groups.invite:
@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
return response
@staticmethod
def encode_presence(events, time_now):
def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
return {
"events": [
{
@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
}
async def encode_joined(
self, rooms, time_now, token_id, event_fields, event_formatter
):
self,
rooms: List[JoinedSyncResult],
time_now: int,
token_id: Optional[int],
event_fields: List[str],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
"""
Encode the joined rooms in a sync result
Args:
rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
rooms: list of sync results for rooms this user is joined to
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
event_fields: List of event fields to include. If empty,
all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
event_formatter: function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
The joined rooms list, in our response format
"""
joined = {}
for room in rooms:
@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
return joined
async def encode_invited(self, rooms, time_now, token_id, event_formatter):
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
time_now: int,
token_id: Optional[int],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
"""
Encode the invited rooms in a sync result
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
sync results for rooms this user is invited to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
rooms: list of sync results for rooms this user is invited to
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
event_formatter (func[dict]): function to convert from federation format
event_formatter: function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: the invited rooms list, in our
response format
The invited rooms list, in our response format
"""
invited = {}
for room in rooms:
@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[KnockedSyncResult],
time_now: int,
token_id: int,
token_id: Optional[int],
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
return knocked
async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter
):
self,
rooms: List[ArchivedSyncResult],
time_now: int,
token_id: Optional[int],
event_fields: List[str],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
"""
Encode the archived rooms in a sync result
Args:
rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
sync results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
rooms: list of sync results for rooms this user is joined to
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
event_fields: List of event fields to include. If empty,
all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
event_formatter: function to convert from federation format to client format
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
The archived rooms list, in our response format
"""
joined = {}
for room in rooms:
@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
return joined
async def encode_room(
self, room, time_now, token_id, joined, only_fields, event_formatter
):
self,
room: Union[JoinedSyncResult, ArchivedSyncResult],
time_now: int,
token_id: Optional[int],
joined: bool,
only_fields: Optional[List[str]],
event_formatter: Callable[[JsonDict], JsonDict],
) -> JsonDict:
"""
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
single room
time_now (int): current time - used as a baseline for age
calculations
token_id (int): ID of the user's auth token - used for namespacing
room: sync result for a single room
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
joined (bool): True if the user is joined to this room - will mean
joined: True if the user is joined to this room - will mean
we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include.
event_formatter (func[dict]): function to convert from federation format
only_fields: Optional. The list of event fields to include.
event_formatter: function to convert from federation format
to client format
Returns:
dict[str, object]: the room, encoded in our response format
The room, encoded in our response format
"""
def serialize(events):
@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
account_data = room.account_data
result = {
result: JsonDict = {
"timeline": {
"events": serialized_timeline,
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
}
if joined:
assert isinstance(room, JoinedSyncResult)
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
return result
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View file

@ -13,12 +13,19 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id, room_id):
async def on_GET(
self, request: SynapseRequest, user_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
@ -54,12 +63,14 @@ class TagServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
async def on_PUT(
self, request: SynapseRequest, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@ -70,7 +81,9 @@ class TagServlet(RestServlet):
return 200, {}
async def on_DELETE(self, request, user_id, room_id, tag):
async def on_DELETE(
self, request: SynapseRequest, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@ -80,6 +93,6 @@ class TagServlet(RestServlet):
return 200, {}
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)

View file

@ -12,27 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols()
@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request, protocol):
async def on_GET(
self, request: SynapseRequest, protocol: str
) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols(
@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request, protocol):
async def on_GET(
self, request: SynapseRequest, protocol: str
) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request, protocol):
async def on_GET(
self, request: SynapseRequest, protocol: str
) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
return 200, results
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)

View file

@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.api.errors import AuthError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
class TokenRefreshRestServlet(RestServlet):
"""
@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
async def on_POST(self, request):
async def on_POST(self, request: Request) -> None:
raise AuthError(403, "tokenrefresh is no longer supported.")
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TokenRefreshRestServlet(hs).register(http_server)

View file

@ -13,29 +13,32 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Searches for users in directory
Returns:
@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserDirectorySearchRestServlet(hs).register(http_server)

View file

@ -17,9 +17,17 @@
import logging
import re
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
in self.config.encryption_enabled_by_default_for_room_presets
)
def on_GET(self, request):
def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
return (
200,
{
@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VersionsRestServlet(hs).register(http_server)

View file

@ -15,20 +15,27 @@
import base64
import hashlib
import hmac
from typing import TYPE_CHECKING, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
)
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VoipRestServlet(hs).register(http_server)

View file

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
@ -414,9 +414,9 @@ class ThumbnailResource(DirectServeJsonResource):
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
crop_info_list2 = []
crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
@ -451,15 +451,19 @@ class ThumbnailResource(DirectServeJsonResource):
info,
)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info dictionary and pick the thumbnail that appears earlier
# in the list of candidates.
if crop_info_list:
thumbnail_info = min(crop_info_list)[-1]
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2)[-1]
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
info_list2 = []
info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
@ -477,10 +481,14 @@ class ThumbnailResource(DirectServeJsonResource):
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info dictionary and pick the thumbnail that appears earlier
# in the list of candidates.
if info_list:
thumbnail_info = min(info_list)[-1]
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
thumbnail_info = min(info_list2)[-1]
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(

View file

@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
from synapse.handlers.federation_event import FederationEventHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
@ -546,6 +547,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_federation_handler(self) -> FederationHandler:
return FederationHandler(self)
@cache_in_self
def get_federation_event_handler(self) -> FederationEventHandler:
return FederationEventHandler(self)
@cache_in_self
def get_identity_handler(self) -> IdentityHandler:
return IdentityHandler(self)

View file

@ -12,26 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
SERVER_NOTICE_ROOM_TAG = "m.server_notice"
class ServerNoticesManager:
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._config = hs.config
self._account_data_handler = hs.get_account_data_handler()
@ -58,6 +55,7 @@ class ServerNoticesManager:
event_content: dict,
type: str = EventTypes.Message,
state_key: Optional[str] = None,
txn_id: Optional[str] = None,
) -> EventBase:
"""Send a notice to the given user
@ -68,6 +66,7 @@ class ServerNoticesManager:
event_content: content of event to send
type: type of event
is_state_event: Is the event a state event
txn_id: The transaction ID.
"""
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
@ -90,7 +89,7 @@ class ServerNoticesManager:
event_dict["state_key"] = state_key
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
requester, event_dict, ratelimit=False, txn_id=txn_id
)
return event

View file

@ -57,4 +57,8 @@ textarea, input {
background-color: #f8f8f8;
border: 1px #ccc solid;
}
}
.error {
color: red;
}

View file

@ -63,6 +63,7 @@ from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
@ -121,6 +122,7 @@ class DataStore(
ServerMetricsStore,
EventForwardExtremitiesStore,
LockStore,
SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs

View file

@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore):
# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
already_fetching: Dict[str, defer.Deferred] = {}
#
# Note: we might get the same `ObservableDeferred` back for multiple
# events we're already fetching, so we deduplicate the deferreds to
# avoid extraneous work (if we don't do this we can end up in a n^2 mode
# when we wait on the same Deferred N times, then try and merge the
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
ObservableDeferred[Dict[str, _EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
already_fetching[event_id] = deferred.observe()
already_fetching_ids.add(event_id)
already_fetching_deferreds.add(deferred)
missing_events_ids.difference_update(already_fetching)
missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
log_ctx = current_context()
@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
if already_fetching:
if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
already_fetching.values(),
(d.observe() for d in already_fetching_deferreds),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
for result in results:
event_entry_map.update(result)
# We filter out events that we haven't asked for as we might get
# a *lot* of superfluous events back, and there is no point
# going through and inserting them all (which can take time).
event_entry_map.update(
(event_id, entry)
for event_id, entry in result.items()
if event_id in already_fetching_ids
)
if not allow_rejected:
event_entry_map = {

View file

@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
self._invalidate_get_event_cache(event_id)
logger.info("[purge] done")

View file

@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_stale_pushers,
)
self.db_pool.updates.register_background_update_handler(
"remove_deleted_email_pushers",
self._remove_deleted_email_pushers,
)
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@ -388,6 +393,74 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
async def _remove_deleted_email_pushers(
self, progress: dict, batch_size: int
) -> int:
"""A background update that deletes all pushers for deleted email addresses.
In previous versions of synapse, when users deleted their email address, it didn't
also delete all the pushers for that email address. This background update removes
those to prevent unwanted emails. This should only need to be run once (when users
upgrade to v1.42.0
Args:
progress: dict used to store progress of this background update
batch_size: the maximum number of rows to retrieve in a single select query
Returns:
The number of deleted rows
"""
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn) -> int:
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
FROM pushers AS p
LEFT JOIN user_threepids AS t
ON t.user_id = p.user_name
AND t.medium = 'email'
AND t.address = p.pushkey
WHERE t.user_id is NULL
AND p.app_id = 'm.email'
AND p.id > ?
ORDER BY p.id ASC
LIMIT ?
"""
txn.execute(sql, (last_pusher, batch_size))
rows = txn.fetchall()
last = None
num_deleted = 0
for row in rows:
last = row[0]
num_deleted += 1
self.db_pool.simple_delete_txn(
txn,
"pushers",
{"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
)
if last is not None:
self.db_pool.updates._background_update_progress_txn(
txn, "remove_deleted_email_pushers", {"last_pusher": last}
)
return num_deleted
number_deleted = await self.db_pool.runInteraction(
"_remove_deleted_email_pushers", _delete_pushers
)
if number_deleted < batch_size:
await self.db_pool.updates._end_background_update(
"remove_deleted_email_pushers"
)
return number_deleted
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:

View file

@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return user_id
def get_user_id_by_threepid_txn(self, txn, medium, address):
def get_user_id_by_threepid_txn(
self, txn, medium: str, address: str
) -> Optional[str]:
"""Returns user id from threepid
Args:
txn (cursor):
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
medium: threepid medium e.g. email
address: threepid address e.g. me@example.com
Returns:
str|None: user id or None if no user id/threepid mapping exists
user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return ret["user_id"]
return None
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
async def user_add_threepid(
self,
user_id: str,
medium: str,
address: str,
validated_at: int,
added_at: int,
) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
async def user_get_threepids(self, user_id):
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"user_get_threepids",
)
async def user_delete_threepid(self, user_id, medium, address) -> None:
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
@ -1157,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
async def registration_token_is_valid(self, token: str) -> bool:
"""Checks if a token can be used to authenticate a registration.
Args:
token: The registration token to be checked
Returns:
True if the token is valid, False otherwise.
"""
res = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
)
# Check if the token exists
if res is None:
return False
# Check if the token has expired
now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now:
return False
# Check if the token has been used up
if (
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
return False
# Otherwise, the token is valid
return True
async def set_registration_token_pending(self, token: str) -> None:
"""Increment the pending registrations counter for a token.
Args:
token: The registration token pending use
"""
def _set_registration_token_pending_txn(txn):
pending = self.db_pool.simple_select_one_onecol_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcol="pending",
)
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues={"pending": pending + 1},
)
return await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn
)
async def use_registration_token(self, token: str) -> None:
"""Complete a use of the given registration token.
The `pending` counter will be decremented, and the `completed`
counter will be incremented.
Args:
token: The registration token to be 'used'
"""
def _use_registration_token_txn(txn):
# Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
) # type: ignore
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues={
"completed": res["completed"] + 1,
"pending": res["pending"] - 1,
},
)
return await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn
)
async def get_registration_tokens(
self, valid: Optional[bool] = None
) -> List[Dict[str, Any]]:
"""List all registration tokens. Used by the admin API.
Args:
valid: If True, only valid tokens are returned.
If False, only invalid tokens are returned.
Default is None: return all tokens regardless of validity.
Returns:
A list of dicts, each containing details of a token.
"""
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
if valid is None:
# Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens")
elif valid:
# Select valid tokens only
sql = (
"SELECT * FROM registration_tokens WHERE "
"(uses_allowed > pending + completed OR uses_allowed IS NULL) "
"AND (expiry_time > ? OR expiry_time IS NULL)"
)
txn.execute(sql, [now])
else:
# Select invalid tokens only
sql = (
"SELECT * FROM registration_tokens WHERE "
"uses_allowed <= pending + completed OR expiry_time <= ?"
)
txn.execute(sql, [now])
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"select_registration_tokens",
select_registration_tokens_txn,
self._clock.time_msec(),
valid,
)
async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Get info about the given registration token. Used by the admin API.
Args:
token: The token to retrieve information about.
Returns:
A dict, or None if token doesn't exist.
"""
return await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
async def generate_registration_token(
self, length: int, chars: str
) -> Optional[str]:
"""Generate a random registration token. Used by the admin API.
Args:
length: The length of the token to generate.
chars: A string of the characters allowed in the generated token.
Returns:
The generated token.
Raises:
SynapseError if a unique registration token could still not be
generated after a few tries.
"""
# Make a few attempts at generating a unique token of the required
# length before failing.
for _i in range(3):
# Generate token
token = "".join(random.choices(chars, k=length))
# Check if the token already exists
existing_token = await self.db_pool.simple_select_one_onecol(
"registration_tokens",
keyvalues={"token": token},
retcol="token",
allow_none=True,
desc="check_if_registration_token_exists",
)
if existing_token is None:
# The generated token doesn't exist yet, return it
return token
raise SynapseError(
500,
"Unable to generate a unique registration token. Try again with a greater length",
Codes.UNKNOWN,
)
async def create_registration_token(
self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
) -> bool:
"""Create a new registration token. Used by the admin API.
Args:
token: The token to create.
uses_allowed: The number of times the token can be used to complete
a registration before it becomes invalid. A value of None indicates
unlimited uses.
expiry_time: The latest time the token is valid. Given as the
number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
None indicates that the token does not expire.
Returns:
Whether the row was inserted or not.
"""
def _create_registration_token_txn(txn):
row = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["token"],
allow_none=True,
)
if row is not None:
# Token already exists
return False
self.db_pool.simple_insert_txn(
txn,
"registration_tokens",
values={
"token": token,
"uses_allowed": uses_allowed,
"pending": 0,
"completed": 0,
"expiry_time": expiry_time,
},
)
return True
return await self.db_pool.runInteraction(
"create_registration_token", _create_registration_token_txn
)
async def update_registration_token(
self, token: str, updatevalues: Dict[str, Optional[int]]
) -> Optional[Dict[str, Any]]:
"""Update a registration token. Used by the admin API.
Args:
token: The token to update.
updatevalues: A dict with the fields to update. E.g.:
`{"uses_allowed": 3}` to update just uses_allowed, or
`{"uses_allowed": 3, "expiry_time": None}` to update both.
This is passed straight to simple_update_one.
Returns:
A dict with all info about the token, or None if token doesn't exist.
"""
def _update_registration_token_txn(txn):
try:
self.db_pool.simple_update_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
updatevalues=updatevalues,
)
except StoreError:
# Update failed because token does not exist
return None
# Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=[
"token",
"uses_allowed",
"pending",
"completed",
"expiry_time",
],
allow_none=True,
)
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
async def delete_registration_token(self, token: str) -> bool:
"""Delete a registration token. Used by the admin API.
Args:
token: The token to delete.
Returns:
Whether the token was successfully deleted or not.
"""
try:
await self.db_pool.simple_delete_one(
"registration_tokens",
keyvalues={"token": token},
desc="delete_registration_token",
)
except StoreError:
# Deletion failed because token does not exist
return False
return True
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""

View file

@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
async def get_invited_rooms_for_local_user(
self, user_id: str
) -> List[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
INNER JOIN rooms AS r USING (room_id)
WHERE
user_id = ?
AND %s
@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
results = [RoomsForUser(*r) for r in txn]
return results
@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
Returns the rooms the user is in currently, along with the stream
ordering of the most recent join for that user and room.
ordering of the most recent join for that user and room, along with
the room version of the room.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
async def get_rooms_for_user(
self, user_id: str, on_invalidate=None
) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently

View file

@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
class SessionStore(SQLBaseStore):
"""
A store for generic session data.
Each type of session should provide a unique type (to separate sessions).
Sessions are automatically removed when they expire.
"""
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Create a background job for culling expired sessions.
if hs.config.run_background_tasks:
self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
async def create_session(
self, session_type: str, value: JsonDict, expiry_ms: int
) -> str:
"""
Creates a new pagination session for the room hierarchy endpoint.
Args:
session_type: The type for this session.
value: The value to store.
expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
Returns:
The newly created session ID.
Raises:
StoreError if a unique session ID cannot be generated.
"""
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
while attempts < 5:
session_id = stringutils.random_string(24)
try:
await self.db_pool.simple_insert(
table="sessions",
values={
"session_id": session_id,
"session_type": session_type,
"value": json_encoder.encode(value),
"expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
},
desc="create_session",
)
return session_id
except self.db_pool.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a session ID.")
async def get_session(self, session_type: str, session_id: str) -> JsonDict:
"""
Retrieve data stored with create_session
Args:
session_type: The type for this session.
session_id: The session ID returned from create_session.
Raises:
StoreError if the session cannot be found.
"""
def _get_session(
txn: LoggingTransaction, session_type: str, session_id: str, ts: int
) -> JsonDict:
# This includes the expiry time since items are only periodically
# deleted, not upon expiry.
select_sql = """
SELECT value FROM sessions WHERE
session_type = ? AND session_id = ? AND expiry_time_ms > ?
"""
txn.execute(select_sql, [session_type, session_id, ts])
row = txn.fetchone()
if not row:
raise StoreError(404, "No session")
return db_to_json(row[0])
return await self.db_pool.runInteraction(
"get_session",
_get_session,
session_type,
session_id,
self._clock.time_msec(),
)
@wrap_as_background_process("delete_expired_sessions")
async def _delete_expired_sessions(self) -> None:
"""Remove sessions with expiry dates that have passed."""
def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
txn.execute(sql, (ts,))
await self.db_pool.runInteraction(
"delete_expired_sessions",
_delete_expired_sessions_txn,
self._clock.time_msec(),
)

View file

@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr
from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={},
)
# If a registration token was used, decrement the pending counter
# before deleting the session.
rows = self.db_pool.simple_select_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
iterable=session_ids,
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
retcols=["result"],
)
# Get the tokens used and how much pending needs to be decremented by.
token_counts: Dict[str, int] = {}
for r in rows:
# If registration was successfully completed, the result of the
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
token = db_to_json(r["result"])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
token_rows = self.db_pool.simple_select_many_txn(
txn,
table="registration_tokens",
column="token",
iterable=list(token_counts.keys()),
keyvalues={},
retcols=["token", "pending"],
)
for token_row in token_rows:
token = token_row["token"]
new_pending = token_row["pending"] - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",
keyvalues={"token": token},
updatevalues={"pending": new_pending},
)
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,

View file

@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
async def update_profile_in_user_dir(
self, user_id: str, display_name: str, avatar_url: str
self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
) -> None:
"""
Update or add a user's profile in the user directory.

View file

@ -14,25 +14,41 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import List, Optional, Tuple
import attr
from synapse.types import PersistedEventPosition
logger = logging.getLogger(__name__)
RoomsForUser = namedtuple(
"RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
)
GetRoomsForUserWithStreamOrdering = namedtuple(
"GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class RoomsForUser:
room_id: str
sender: str
membership: str
event_id: str
stream_ordering: int
room_version_id: str
# We store this using a namedtuple so that we save about 3x space over using a
# dict.
ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class GetRoomsForUserWithStreamOrdering:
room_id: str
event_pos: PersistedEventPosition
# "members" points to a truncated list of (user_id, event_id) tuples for users of
# a given membership type, suitable for use in calculating heroes for a room.
# "count" points to the total numberr of users of a given membership type.
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class ProfileInfo:
avatar_url: Optional[str]
display_name: Optional[str]
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class MemberSummary:
# A truncated list of (user_id, event_id) tuples for users of a given
# membership type, suitable for use in calculating heroes for a room.
members: List[Tuple[str, str]]
# The total number of users of a given membership type.
count: int

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# When updating these values, please leave a short summary of the changes below.
SCHEMA_VERSION = 63
"""Represents the expectations made by the codebase about the database schema

View file

@ -0,0 +1,23 @@
/* Copyright 2021 Callum Brown
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS registration_tokens(
token TEXT NOT NULL, -- The token that can be used for authentication.
uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
pending INT NOT NULL, -- The number of in progress registrations using this token.
completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
UNIQUE (token)
);

View file

@ -0,0 +1,20 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We may not have deleted all pushers for emails that are no longer linked
-- to an account, so we set up a background job to delete them.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(6302, 'remove_deleted_email_pushers', '{}');

View file

@ -0,0 +1,23 @@
/*
* Copyright 2021 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS sessions(
session_type TEXT NOT NULL, -- The unique key for this type of session.
session_id TEXT NOT NULL, -- The session ID passed to the client.
value TEXT NOT NULL, -- A JSON dictionary to persist.
expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds).
UNIQUE (session_type, session_id)
);