Merge remote-tracking branch 'upstream/release-v1.48'

This commit is contained in:
Tulir Asokan 2021-11-25 18:33:37 +02:00
commit 9f4fa40b64
175 changed files with 6413 additions and 1993 deletions

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
@ -62,6 +62,8 @@ from synapse.rest.client import (
if TYPE_CHECKING:
from synapse.server import HomeServer
RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.

View file

@ -28,6 +28,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.background_updates import (
BackgroundUpdateEnabledRestServlet,
BackgroundUpdateRestServlet,
BackgroundUpdateStartJobRestServlet,
)
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
@ -46,6 +47,9 @@ from synapse.rest.admin.registration_tokens import (
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import (
BlockRoomRestServlet,
DeleteRoomStatusByDeleteIdRestServlet,
DeleteRoomStatusByRoomIdRestServlet,
ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
@ -53,6 +57,7 @@ from synapse.rest.admin.rooms import (
RoomEventContextServlet,
RoomMembersRestServlet,
RoomRestServlet,
RoomRestV2Servlet,
RoomStateRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@ -220,10 +225,14 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomRestV2Servlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server)
DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
@ -253,6 +262,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SendServerNoticeServlet(hs).register(http_server)
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(

View file

@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.types import JsonDict
@ -29,37 +34,36 @@ logger = logging.getLogger(__name__)
class BackgroundUpdateEnabledRestServlet(RestServlet):
"""Allows temporarily disabling background updates"""
PATTERNS = admin_patterns("/background_updates/enabled")
PATTERNS = admin_patterns("/background_updates/enabled$")
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()
self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth()
self.data_stores = hs.get_datastores()
self._auth = hs.get_auth()
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
enabled = all(db.updates.enabled for db in self.data_stores.databases)
enabled = all(db.updates.enabled for db in self._data_stores.databases)
return 200, {"enabled": enabled}
return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
body = parse_json_object_from_request(request)
enabled = body.get("enabled", True)
if not isinstance(enabled, bool):
raise SynapseError(400, "'enabled' parameter must be a boolean")
raise SynapseError(
HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean"
)
for db in self.data_stores.databases:
for db in self._data_stores.databases:
db.updates.enabled = enabled
# If we're re-enabling them ensure that we start the background
@ -67,32 +71,29 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
if enabled:
db.updates.start_doing_background_updates()
return 200, {"enabled": enabled}
return HTTPStatus.OK, {"enabled": enabled}
class BackgroundUpdateRestServlet(RestServlet):
"""Fetch information about background updates"""
PATTERNS = admin_patterns("/background_updates/status")
PATTERNS = admin_patterns("/background_updates/status$")
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()
self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth()
self.data_stores = hs.get_datastores()
self._auth = hs.get_auth()
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
enabled = all(db.updates.enabled for db in self.data_stores.databases)
enabled = all(db.updates.enabled for db in self._data_stores.databases)
current_updates = {}
for db in self.data_stores.databases:
for db in self._data_stores.databases:
update = db.updates.get_current_update()
if not update:
continue
@ -104,4 +105,72 @@ class BackgroundUpdateRestServlet(RestServlet):
"average_items_per_ms": update.average_items_per_ms(),
}
return 200, {"enabled": enabled, "current_updates": current_updates}
return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates}
class BackgroundUpdateStartJobRestServlet(RestServlet):
"""Allows to start specific background updates"""
PATTERNS = admin_patterns("/background_updates/start_job")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["job_name"])
job_name = body["job_name"]
if job_name == "populate_stats_process_rooms":
jobs = [
{
"update_name": "populate_stats_process_rooms",
"progress_json": "{}",
},
]
elif job_name == "regenerate_directory":
jobs = [
{
"update_name": "populate_user_directory_createtables",
"progress_json": "{}",
"depends_on": "",
},
{
"update_name": "populate_user_directory_process_rooms",
"progress_json": "{}",
"depends_on": "populate_user_directory_createtables",
},
{
"update_name": "populate_user_directory_process_users",
"progress_json": "{}",
"depends_on": "populate_user_directory_process_rooms",
},
{
"update_name": "populate_user_directory_cleanup",
"progress_json": "{}",
"depends_on": "populate_user_directory_process_users",
},
]
else:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name")
try:
await self._store.db_pool.simple_insert_many(
table="background_updates",
values=jobs,
desc=f"admin_api_run_{job_name}",
)
except self._store.db_pool.engine.module.IntegrityError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Job %s is already in queue of background updates." % (job_name,),
)
self._store.db_pool.updates.start_doing_background_updates()
return HTTPStatus.OK, {}

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, JoinRules, Membership
@ -34,7 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import JsonDict, UserID, create_requester
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -46,6 +46,138 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class RoomRestV2Servlet(RestServlet):
"""Delete a room from server asynchronously with a background task.
It is a combination and improvement of shutdown and purge room.
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
If desired any local aliases will be repointed to a new room
created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
self._pagination_handler = hs.get_pagination_handler()
async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
content = parse_json_object_from_request(request)
block = content.get("block", False)
if not isinstance(block, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'block' must be a boolean, if given",
Codes.BAD_JSON,
)
purge = content.get("purge", True)
if not isinstance(purge, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'purge' must be a boolean, if given",
Codes.BAD_JSON,
)
force_purge = content.get("force_purge", False)
if not isinstance(force_purge, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'force_purge' must be a boolean, if given",
Codes.BAD_JSON,
)
if not RoomID.is_valid(room_id):
raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
if not await self._store.get_room(room_id):
raise NotFoundError("Unknown room id %s" % (room_id,))
delete_id = self._pagination_handler.start_shutdown_and_purge_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
new_room_name=content.get("room_name"),
message=content.get("message"),
requester_user_id=requester.user.to_string(),
block=block,
purge=purge,
force_purge=force_purge,
)
return 200, {"delete_id": delete_id}
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._pagination_handler = hs.get_pagination_handler()
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
if not RoomID.is_valid(room_id):
raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id)
if delete_ids is None:
raise NotFoundError("No delete task for room_id '%s' found" % room_id)
response = []
for delete_id in delete_ids:
delete = self._pagination_handler.get_delete_status(delete_id)
if delete:
response += [
{
"delete_id": delete_id,
**delete.asdict(),
}
]
return 200, {"results": cast(JsonDict, response)}
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._pagination_handler = hs.get_pagination_handler()
async def on_GET(
self, request: SynapseRequest, delete_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
delete_status = self._pagination_handler.get_delete_status(delete_id)
if delete_status is None:
raise NotFoundError("delete id '%s' not found" % delete_id)
return 200, cast(JsonDict, delete_status.asdict())
class ListRoomRestServlet(RestServlet):
"""
List all rooms that are known to the homeserver. Results are returned
@ -239,9 +371,22 @@ class RoomRestServlet(RestServlet):
# Purge room
if purge:
await pagination_handler.purge_room(room_id, force=force_purge)
try:
await pagination_handler.purge_room(room_id, force=force_purge)
except NotFoundError:
if block:
# We can block unknown rooms with this endpoint, in which case
# a failed purge is expected.
pass
else:
# But otherwise, we expect this purge to have succeeded.
raise
return 200, ret
# Cast safety: cast away the knowledge that this is a TypedDict.
# See https://github.com/python/mypy/issues/4976#issuecomment-579883622
# for some discussion on why this is necessary. Either way,
# `ret` is an opaque dictionary blob as far as the rest of the app cares.
return 200, cast(JsonDict, ret)
class RoomMembersRestServlet(RestServlet):
@ -303,7 +448,7 @@ class RoomStateRestServlet(RestServlet):
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
bundle_relations=False,
)
ret = {"state": room_state}
@ -583,6 +728,7 @@ class RoomEventContextServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
@ -600,7 +746,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
@ -630,7 +778,70 @@ class RoomEventContextServlet(RestServlet):
results["state"],
time_now,
# No need to bundle aggregations for state events
bundle_aggregations=False,
bundle_relations=False,
)
return 200, results
class BlockRoomRestServlet(RestServlet):
"""
Manage blocking of rooms.
On PUT: Add or remove a room from blocking list.
On GET: Get blocking status of room and user who has blocked this room.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
if not RoomID.is_valid(room_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
)
blocked_by = await self._store.room_is_blocked_by(room_id)
# Test `not None` if `user_id` is an empty string
# if someone add manually an entry in database
if blocked_by is not None:
response = {"block": True, "user_id": blocked_by}
else:
response = {"block": False}
return HTTPStatus.OK, response
async def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
content = parse_json_object_from_request(request)
if not RoomID.is_valid(room_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
)
assert_params_in_dict(content, ["block"])
block = content.get("block")
if not isinstance(block, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'block' must be a boolean.",
Codes.BAD_JSON,
)
if block:
await self._store.block_room(room_id, requester.user.to_string())
else:
await self._store.unblock_room(room_id)
return HTTPStatus.OK, {"block": block}

View file

@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet):
if auth_user.to_string() == user_id:
raise SynapseError(400, "Cannot use admin API to login as self")
token = await self.auth_handler.get_access_token_for_user_id(
token = await self.auth_handler.create_access_token_for_user_id(
user_id=auth_user.to_string(),
device_id=None,
valid_until_ms=valid_until_ms,
@ -909,7 +909,7 @@ class UserTokenRestServlet(RestServlet):
class ShadowBanRestServlet(RestServlet):
"""An admin API for shadow-banning a user.
"""An admin API for controlling whether a user is shadow-banned.
A shadow-banned users receives successful responses to their client-server
API requests, but the events are not propagated into rooms.
@ -917,11 +917,19 @@ class ShadowBanRestServlet(RestServlet):
Shadow-banning a user should be used as a tool of last resort and may lead
to confusing or broken behaviour for the client.
Example:
Example of shadow-banning a user:
POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
{}
200 OK
{}
Example of removing a user from being shadow-banned:
DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban
{}
200 OK
{}
"""
@ -945,6 +953,18 @@ class ShadowBanRestServlet(RestServlet):
return 200, {}
async def on_DELETE(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Only local users can be shadow-banned")
await self.store.set_shadow_banned(UserID.from_string(user_id), False)
return 200, {}
class RateLimitRestServlet(RestServlet):
"""An admin API to override ratelimiting for an user.

View file

@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
def client_patterns(
path_regex: str,
releases: Iterable[int] = (0,),
releases: Iterable[str] = ("r0", "v3"),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:
@ -52,7 +52,7 @@ def client_patterns(
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
new_prefix = CLIENT_API_PREFIX + f"/{release}"
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns

View file

@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet):
}
"""
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",))
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -61,7 +61,8 @@ class LoginRestServlet(RestServlet):
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
APPSERVICE_TYPE = "m.login.application_service"
APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service"
REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
def __init__(self, hs: "HomeServer"):
@ -71,6 +72,7 @@ class LoginRestServlet(RestServlet):
# JWT configuration variables.
self.jwt_enabled = hs.config.jwt.jwt_enabled
self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
@ -79,7 +81,9 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2.saml2_enabled
self.cas_enabled = hs.config.cas.cas_enabled
self.oidc_enabled = hs.config.oidc.oidc_enabled
self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
self._msc2918_enabled = (
hs.config.registration.refreshable_access_token_lifetime is not None
)
self.auth = hs.get_auth()
@ -143,6 +147,7 @@ class LoginRestServlet(RestServlet):
flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE})
return 200, {"flows": flows}
@ -159,7 +164,10 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token = False
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
if login_submission["type"] in (
LoginRestServlet.APPSERVICE_TYPE,
LoginRestServlet.APPSERVICE_TYPE_UNSTABLE,
):
appservice = self.auth.get_appservice_by_req(request)
if appservice.is_rate_limited():
@ -408,7 +416,7 @@ class LoginRestServlet(RestServlet):
errcode=Codes.FORBIDDEN,
)
user = payload.get("sub", None)
user = payload.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
@ -447,7 +455,9 @@ class RefreshTokenServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.registration.access_token_lifetime
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
@ -457,7 +467,9 @@ class RefreshTokenServlet(RestServlet):
if not isinstance(token, str):
raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
valid_until_ms = (
self._clock.time_msec() + self.refreshable_access_token_lifetime
)
access_token, refresh_token = await self._auth_handler.refresh_token(
token, valid_until_ms
)
@ -556,7 +568,7 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
if hs.config.registration.access_token_lifetime is not None:
if hs.config.registration.refreshable_access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas.cas_enabled:

View file

@ -420,7 +420,9 @@ class RegisterRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.registration.enable_registration
self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
self._msc2918_enabled = (
hs.config.registration.refreshable_access_token_lifetime is not None
)
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler

View file

@ -224,17 +224,17 @@ class RelationPaginationServlet(RestServlet):
)
now = self.clock.time_msec()
# We set bundle_aggregations to False when retrieving the original
# We set bundle_relations to False when retrieving the original
# event because we want the content before relations were applied to
# it.
original_event = await self._event_serializer.serialize_event(
event, now, bundle_aggregations=False
event, now, bundle_relations=False
)
# Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them
# here.
serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False
events, now, bundle_relations=False
)
return_value = pagination_chunk.to_dict()
@ -298,7 +298,9 @@ class RelationAggregationPaginationServlet(RestServlet):
raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
raise SynapseError(
400, f"Relation type must be '{RelationTypes.ANNOTATION}'"
)
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")

View file

@ -554,6 +554,7 @@ class RoomMessageListRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -571,7 +572,9 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@ -676,6 +679,7 @@ class RoomEventContextServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
@ -692,7 +696,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
@ -717,7 +723,7 @@ class RoomEventContextServlet(RestServlet):
results["state"],
time_now,
# No need to bundle aggregations for state events
bundle_aggregations=False,
bundle_relations=False,
)
return 200, results

View file

@ -29,7 +29,7 @@ from typing import (
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.events.utils import (
@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet):
request_key = (user, timeout, since, filter_id, full_state, device_id)
if filter_id is None:
filter_collection = DEFAULT_FILTER_COLLECTION
filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
filter_object = json_decoder.decode(filter_id)
@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet):
except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter_collection = FilterCollection(filter_object)
filter_collection = FilterCollection(self.hs, filter_object)
else:
try:
filter_collection = await self.filtering.get_user_filter(
@ -522,7 +522,7 @@ class SyncRestServlet(RestServlet):
time_now=time_now,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
bundle_relations=False,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,

View file

@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii
from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
logger = logging.getLogger(__name__)
@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
"""Parses the server name, media ID and optional file name from the request URI
Also performs some rough validation on the server name.
Args:
request: The `Request`.
Returns:
A tuple containing the parsed server name, media ID and optional file name.
Raises:
SynapseError(404): if parsing or validation fail for any reason
"""
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath: List[bytes] = request.postpath # type: ignore
@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8")
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
file_name = None
if len(postpath) > 2:
try:

View file

@ -16,7 +16,8 @@
import functools
import os
import re
from typing import Any, Callable, List, TypeVar, cast
import string
from typing import Any, Callable, List, TypeVar, Union, cast
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@ -37,6 +38,85 @@ def _wrap_in_base_path(func: F) -> F:
return cast(F, _wrapped)
GetPathMethod = TypeVar(
"GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
)
def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod:
"""Wraps a path-returning method to check that the returned path(s) do not escape
the media store directory.
The check is not expected to ever fail, unless `func` is missing a call to
`_validate_path_component`, or `_validate_path_component` is buggy.
Args:
func: The `MediaFilePaths` method to wrap. The method may return either a single
path, or a list of paths. Returned paths may be either absolute or relative.
Returns:
The method, wrapped with a check to ensure that the returned path(s) lie within
the media store directory. Raises a `ValueError` if the check fails.
"""
@functools.wraps(func)
def _wrapped(
self: "MediaFilePaths", *args: Any, **kwargs: Any
) -> Union[str, List[str]]:
path_or_paths = func(self, *args, **kwargs)
if isinstance(path_or_paths, list):
paths_to_check = path_or_paths
else:
paths_to_check = [path_or_paths]
for path in paths_to_check:
# path may be an absolute or relative path, depending on the method being
# wrapped. When "appending" an absolute path, `os.path.join` discards the
# previous path, which is desired here.
normalized_path = os.path.normpath(os.path.join(self.real_base_path, path))
if (
os.path.commonpath([normalized_path, self.real_base_path])
!= self.real_base_path
):
raise ValueError(f"Invalid media store path: {path!r}")
return path_or_paths
return cast(GetPathMethod, _wrapped)
ALLOWED_CHARACTERS = set(
string.ascii_letters
+ string.digits
+ "_-"
+ ".[]:" # Domain names, IPv6 addresses and ports in server names
)
FORBIDDEN_NAMES = {
"",
os.path.curdir, # "." for the current platform
os.path.pardir, # ".." for the current platform
}
def _validate_path_component(name: str) -> str:
"""Checks that the given string can be safely used as a path component
Args:
name: The path component to check.
Returns:
The path component if valid.
Raises:
ValueError: If `name` cannot be safely used as a path component.
"""
if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
raise ValueError(f"Invalid path component: {name!r}")
return name
class MediaFilePaths:
"""Describes where files are stored on disk.
@ -48,22 +128,46 @@ class MediaFilePaths:
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
# The media store directory, with all symlinks resolved.
self.real_base_path = os.path.realpath(primary_base_path)
# Refuse to initialize if paths cannot be validated correctly for the current
# platform.
assert os.path.sep not in ALLOWED_CHARACTERS
assert os.path.altsep not in ALLOWED_CHARACTERS
# On Windows, paths have all sorts of weirdness which `_validate_path_component`
# does not consider. In any case, the remote media store can't work correctly
# for certain homeservers there, since ":"s aren't allowed in paths.
assert os.name == "posix"
@_wrap_with_jail_check
def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
return os.path.join(
"local_content",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
@_wrap_with_jail_check
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
"local_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
_validate_path_component(file_name),
)
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
@_wrap_with_jail_check
def local_media_thumbnail_dir(self, media_id: str) -> str:
"""
Retrieve the local store path of thumbnails of a given media_id
@ -76,18 +180,24 @@ class MediaFilePaths:
return os.path.join(
self.base_path,
"local_thumbnails",
media_id[0:2],
media_id[2:4],
media_id[4:],
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
@_wrap_with_jail_check
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
"remote_content",
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
)
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
@_wrap_with_jail_check
def remote_media_thumbnail_rel(
self,
server_name: str,
@ -101,11 +211,11 @@ class MediaFilePaths:
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
server_name,
file_id[0:2],
file_id[2:4],
file_id[4:],
file_name,
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
_validate_path_component(file_name),
)
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
@ -113,6 +223,7 @@ class MediaFilePaths:
# Legacy path that was used to store thumbnails previously.
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
@_wrap_with_jail_check
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
) -> str:
@ -120,43 +231,66 @@ class MediaFilePaths:
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
"remote_thumbnail",
server_name,
file_id[0:2],
file_id[2:4],
file_id[4:],
file_name,
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
_validate_path_component(file_name),
)
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
server_name,
file_id[0:2],
file_id[2:4],
file_id[4:],
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
)
@_wrap_with_jail_check
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
return os.path.join("url_cache", media_id[:10], media_id[11:])
return os.path.join(
"url_cache",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
)
else:
return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
return os.path.join(
"url_cache",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
@_wrap_with_jail_check
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
return [
os.path.join(
self.base_path, "url_cache", _validate_path_component(media_id[:10])
)
]
else:
return [
os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
os.path.join(self.base_path, "url_cache", media_id[0:2]),
os.path.join(
self.base_path,
"url_cache",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
),
os.path.join(
self.base_path, "url_cache", _validate_path_component(media_id[0:2])
),
]
@_wrap_with_jail_check
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
@ -168,37 +302,46 @@ class MediaFilePaths:
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
"url_cache_thumbnails", media_id[:10], media_id[11:], file_name
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
_validate_path_component(file_name),
)
else:
return os.path.join(
"url_cache_thumbnails",
media_id[0:2],
media_id[2:4],
media_id[4:],
file_name,
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
_validate_path_component(file_name),
)
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
@_wrap_with_jail_check
def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:])
return os.path.join(
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
)
else:
return os.path.join(
"url_cache_thumbnails",
media_id[0:2],
media_id[2:4],
media_id[4:],
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
url_cache_thumbnail_directory = _wrap_in_base_path(
url_cache_thumbnail_directory_rel
)
@_wrap_with_jail_check
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
@ -206,21 +349,35 @@ class MediaFilePaths:
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
),
os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
]
else:
return [
os.path.join(
self.base_path,
"url_cache_thumbnails",
media_id[0:2],
media_id[2:4],
media_id[4:],
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
),
os.path.join(
self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
),
os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
]

View file

@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
async def _download_url(self, url: str, user: str) -> MediaInfo:
async def _download_url(self, url: str, user: UserID) -> MediaInfo:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
)
async def _precache_image_url(
self, user: str, media_info: MediaInfo, og: JsonDict
self, user: UserID, media_info: MediaInfo, og: JsonDict
) -> None:
"""
Pre-cache the image (if one exists) for posterity

View file

@ -101,8 +101,8 @@ class Thumbnailer:
fits within the given rectangle::
(w_in / h_in) = (w_out / h_out)
w_out = min(w_max, h_max * (w_in / h_in))
h_out = min(h_max, w_max * (h_in / w_in))
w_out = max(min(w_max, h_max * (w_in / h_in)), 1)
h_out = max(min(h_max, w_max * (h_in / w_in)), 1)
Args:
max_width: The largest possible width.
@ -110,9 +110,9 @@ class Thumbnailer:
"""
if max_width * self.height < max_height * self.width:
return max_width, (max_width * self.height) // self.width
return max_width, max((max_width * self.height) // self.width, 1)
else:
return (max_height * self.width) // self.height, max_height
return max((max_height * self.width) // self.height, 1), max_height
def _resize(self, width: int, height: int) -> Image.Image:
# 1-bit or 8-bit color palette images need converting to RGB