Mirror of https://git.anonymousland.org/anonymousland/synapse.git (synced 2025-08-17 12:30:19 -04:00)
Merge remote-tracking branch 'upstream/release-v1.48'

Commit 9f4fa40b64: 175 changed files with 6413 additions and 1993 deletions
synapse/rest/__init__.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable

 from synapse.http.server import HttpServer, JsonResource
 from synapse.rest import admin
@@ -62,6 +62,8 @@ from synapse.rest.client import (
 if TYPE_CHECKING:
     from synapse.server import HomeServer

+RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
+

 class ClientRestResource(JsonResource):
     """Matrix Client API REST resource.
synapse/rest/admin/__init__.py

@@ -28,6 +28,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.rest.admin.background_updates import (
     BackgroundUpdateEnabledRestServlet,
     BackgroundUpdateRestServlet,
+    BackgroundUpdateStartJobRestServlet,
 )
 from synapse.rest.admin.devices import (
     DeleteDevicesRestServlet,
@@ -46,6 +47,9 @@ from synapse.rest.admin.registration_tokens import (
     RegistrationTokenRestServlet,
 )
 from synapse.rest.admin.rooms import (
+    BlockRoomRestServlet,
+    DeleteRoomStatusByDeleteIdRestServlet,
+    DeleteRoomStatusByRoomIdRestServlet,
     ForwardExtremitiesRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
@@ -53,6 +57,7 @@ from synapse.rest.admin.rooms import (
     RoomEventContextServlet,
     RoomMembersRestServlet,
     RoomRestServlet,
+    RoomRestV2Servlet,
     RoomStateRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -220,10 +225,14 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     Register all the admin servlets.
     """
     register_servlets_for_client_rest_resource(hs, http_server)
+    BlockRoomRestServlet(hs).register(http_server)
     ListRoomRestServlet(hs).register(http_server)
     RoomStateRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
+    RoomRestV2Servlet(hs).register(http_server)
     RoomMembersRestServlet(hs).register(http_server)
+    DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server)
+    DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
     VersionServlet(hs).register(http_server)
     UserAdminServlet(hs).register(http_server)
@@ -253,6 +262,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SendServerNoticeServlet(hs).register(http_server)
     BackgroundUpdateEnabledRestServlet(hs).register(http_server)
     BackgroundUpdateRestServlet(hs).register(http_server)
+    BackgroundUpdateStartJobRestServlet(hs).register(http_server)


 def register_servlets_for_client_rest_resource(
synapse/rest/admin/background_updates.py

@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple

 from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
 from synapse.types import JsonDict
@@ -29,37 +34,36 @@ logger = logging.getLogger(__name__)
 class BackgroundUpdateEnabledRestServlet(RestServlet):
     """Allows temporarily disabling background updates"""

-    PATTERNS = admin_patterns("/background_updates/enabled")
+    PATTERNS = admin_patterns("/background_updates/enabled$")

     def __init__(self, hs: "HomeServer"):
-        self.group_server = hs.get_groups_server_handler()
-        self.is_mine_id = hs.is_mine_id
-        self.auth = hs.get_auth()
-
-        self.data_stores = hs.get_datastores()
+        self._auth = hs.get_auth()
+        self._data_stores = hs.get_datastores()

     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)

         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
-        enabled = all(db.updates.enabled for db in self.data_stores.databases)
+        enabled = all(db.updates.enabled for db in self._data_stores.databases)

-        return 200, {"enabled": enabled}
+        return HTTPStatus.OK, {"enabled": enabled}

     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)

         body = parse_json_object_from_request(request)

         enabled = body.get("enabled", True)

         if not isinstance(enabled, bool):
-            raise SynapseError(400, "'enabled' parameter must be a boolean")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean"
+            )

-        for db in self.data_stores.databases:
+        for db in self._data_stores.databases:
            db.updates.enabled = enabled

            # If we're re-enabling them ensure that we start the background
@@ -67,32 +71,29 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
            if enabled:
                db.updates.start_doing_background_updates()

-        return 200, {"enabled": enabled}
+        return HTTPStatus.OK, {"enabled": enabled}


 class BackgroundUpdateRestServlet(RestServlet):
     """Fetch information about background updates"""

-    PATTERNS = admin_patterns("/background_updates/status")
+    PATTERNS = admin_patterns("/background_updates/status$")

     def __init__(self, hs: "HomeServer"):
-        self.group_server = hs.get_groups_server_handler()
-        self.is_mine_id = hs.is_mine_id
-        self.auth = hs.get_auth()
-
-        self.data_stores = hs.get_datastores()
+        self._auth = hs.get_auth()
+        self._data_stores = hs.get_datastores()

     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)

         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
-        enabled = all(db.updates.enabled for db in self.data_stores.databases)
+        enabled = all(db.updates.enabled for db in self._data_stores.databases)

         current_updates = {}

-        for db in self.data_stores.databases:
+        for db in self._data_stores.databases:
             update = db.updates.get_current_update()
             if not update:
                 continue
@@ -104,4 +105,72 @@ class BackgroundUpdateRestServlet(RestServlet):
                 "average_items_per_ms": update.average_items_per_ms(),
             }

-        return 200, {"enabled": enabled, "current_updates": current_updates}
+        return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates}
+
+
+class BackgroundUpdateStartJobRestServlet(RestServlet):
+    """Allows to start specific background updates"""
+
+    PATTERNS = admin_patterns("/background_updates/start_job")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
+
+        body = parse_json_object_from_request(request)
+        assert_params_in_dict(body, ["job_name"])
+
+        job_name = body["job_name"]
+
+        if job_name == "populate_stats_process_rooms":
+            jobs = [
+                {
+                    "update_name": "populate_stats_process_rooms",
+                    "progress_json": "{}",
+                },
+            ]
+        elif job_name == "regenerate_directory":
+            jobs = [
+                {
+                    "update_name": "populate_user_directory_createtables",
+                    "progress_json": "{}",
+                    "depends_on": "",
+                },
+                {
+                    "update_name": "populate_user_directory_process_rooms",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_createtables",
+                },
+                {
+                    "update_name": "populate_user_directory_process_users",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_process_rooms",
+                },
+                {
+                    "update_name": "populate_user_directory_cleanup",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_process_users",
+                },
+            ]
+        else:
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name")
+
+        try:
+            await self._store.db_pool.simple_insert_many(
+                table="background_updates",
+                values=jobs,
+                desc=f"admin_api_run_{job_name}",
+            )
+        except self._store.db_pool.engine.module.IntegrityError:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Job %s is already in queue of background updates." % (job_name,),
+            )
+
+        self._store.db_pool.updates.start_doing_background_updates()
+
+        return HTTPStatus.OK, {}
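Note: the new start_job endpoint can be exercised as below -- a minimal sketch,
assuming a homeserver on localhost:8008 and an admin access token (both
placeholders), with `requests` standing in for any HTTP client:

    import requests

    BASE = "http://localhost:8008/_synapse/admin/v1/background_updates"
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder

    # Queue one of the two job names accepted by the servlet above.
    resp = requests.post(
        f"{BASE}/start_job", headers=HEADERS, json={"job_name": "regenerate_directory"}
    )
    resp.raise_for_status()

    # Watch progress via the status endpoint.
    status = requests.get(f"{BASE}/status", headers=HEADERS).json()
    print(status["enabled"], status["current_updates"])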
synapse/rest/admin/rooms.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 from urllib import parse as urlparse

 from synapse.api.constants import EventTypes, JoinRules, Membership
@@ -34,7 +34,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, RoomID, UserID, create_requester
 from synapse.util import json_decoder

 if TYPE_CHECKING:
@@ -46,6 +46,138 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+class RoomRestV2Servlet(RestServlet):
+    """Delete a room from server asynchronously with a background task.
+
+    It is a combination and improvement of shutdown and purge room.
+
+    Shuts down a room by removing all local users from the room.
+    Blocking all future invites and joins to the room is optional.
+
+    If desired any local aliases will be repointed to a new room
+    created by `new_room_user_id` and kicked users will be auto-
+    joined to the new room.
+
+    If 'purge' is true, it will remove all traces of a room from the database.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+        self._pagination_handler = hs.get_pagination_handler()
+
+    async def on_DELETE(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
+
+        content = parse_json_object_from_request(request)
+
+        block = content.get("block", False)
+        if not isinstance(block, bool):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'block' must be a boolean, if given",
+                Codes.BAD_JSON,
+            )
+
+        purge = content.get("purge", True)
+        if not isinstance(purge, bool):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'purge' must be a boolean, if given",
+                Codes.BAD_JSON,
+            )
+
+        force_purge = content.get("force_purge", False)
+        if not isinstance(force_purge, bool):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'force_purge' must be a boolean, if given",
+                Codes.BAD_JSON,
+            )
+
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+
+        if not await self._store.get_room(room_id):
+            raise NotFoundError("Unknown room id %s" % (room_id,))
+
+        delete_id = self._pagination_handler.start_shutdown_and_purge_room(
+            room_id=room_id,
+            new_room_user_id=content.get("new_room_user_id"),
+            new_room_name=content.get("room_name"),
+            message=content.get("message"),
+            requester_user_id=requester.user.to_string(),
+            block=block,
+            purge=purge,
+            force_purge=force_purge,
+        )
+
+        return 200, {"delete_id": delete_id}
+
+
+class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
+    """Get the status of the delete room background task."""
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._pagination_handler = hs.get_pagination_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+
+        await assert_requester_is_admin(self._auth, request)
+
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+
+        delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id)
+        if delete_ids is None:
+            raise NotFoundError("No delete task for room_id '%s' found" % room_id)
+
+        response = []
+        for delete_id in delete_ids:
+            delete = self._pagination_handler.get_delete_status(delete_id)
+            if delete:
+                response += [
+                    {
+                        "delete_id": delete_id,
+                        **delete.asdict(),
+                    }
+                ]
+        return 200, {"results": cast(JsonDict, response)}
+
+
+class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
+    """Get the status of the delete room background task."""
+
+    PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._pagination_handler = hs.get_pagination_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, delete_id: str
+    ) -> Tuple[int, JsonDict]:
+
+        await assert_requester_is_admin(self._auth, request)
+
+        delete_status = self._pagination_handler.get_delete_status(delete_id)
+        if delete_status is None:
+            raise NotFoundError("delete id '%s' not found" % delete_id)
+
+        return 200, cast(JsonDict, delete_status.asdict())
+
+
 class ListRoomRestServlet(RestServlet):
     """
     List all rooms that are known to the homeserver. Results are returned
@@ -239,9 +371,22 @@ class RoomRestServlet(RestServlet):

         # Purge room
         if purge:
-            await pagination_handler.purge_room(room_id, force=force_purge)
+            try:
+                await pagination_handler.purge_room(room_id, force=force_purge)
+            except NotFoundError:
+                if block:
+                    # We can block unknown rooms with this endpoint, in which case
+                    # a failed purge is expected.
+                    pass
+                else:
+                    # But otherwise, we expect this purge to have succeeded.
+                    raise

-        return 200, ret
+        # Cast safety: cast away the knowledge that this is a TypedDict.
+        # See https://github.com/python/mypy/issues/4976#issuecomment-579883622
+        # for some discussion on why this is necessary. Either way,
+        # `ret` is an opaque dictionary blob as far as the rest of the app cares.
+        return 200, cast(JsonDict, ret)


 class RoomMembersRestServlet(RestServlet):
@@ -303,7 +448,7 @@ class RoomStateRestServlet(RestServlet):
             now,
             # We don't bother bundling aggregations in when asked for state
             # events, as clients won't use them.
-            bundle_aggregations=False,
+            bundle_relations=False,
         )
         ret = {"state": room_state}

@@ -583,6 +728,7 @@ class RoomEventContextServlet(RestServlet):

     def __init__(self, hs: "HomeServer"):
         super().__init__()
+        self._hs = hs
         self.clock = hs.get_clock()
         self.room_context_handler = hs.get_room_context_handler()
         self._event_serializer = hs.get_event_client_serializer()
@@ -600,7 +746,9 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
+            event_filter: Optional[Filter] = Filter(
+                self._hs, json_decoder.decode(filter_json)
+            )
         else:
             event_filter = None

@@ -630,7 +778,70 @@ class RoomEventContextServlet(RestServlet):
             results["state"],
             time_now,
             # No need to bundle aggregations for state events
-            bundle_aggregations=False,
+            bundle_relations=False,
         )

         return 200, results
+
+
+class BlockRoomRestServlet(RestServlet):
+    """
+    Manage blocking of rooms.
+    On PUT: Add or remove a room from blocking list.
+    On GET: Get blocking status of room and user who has blocked this room.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
+
+        blocked_by = await self._store.room_is_blocked_by(room_id)
+        # Test `not None` if `user_id` is an empty string
+        # if someone add manually an entry in database
+        if blocked_by is not None:
+            response = {"block": True, "user_id": blocked_by}
+        else:
+            response = {"block": False}
+
+        return HTTPStatus.OK, response
+
+    async def on_PUT(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
+
+        content = parse_json_object_from_request(request)
+
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
+
+        assert_params_in_dict(content, ["block"])
+        block = content.get("block")
+        if not isinstance(block, bool):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'block' must be a boolean.",
+                Codes.BAD_JSON,
+            )
+
+        if block:
+            await self._store.block_room(room_id, requester.user.to_string())
+        else:
+            await self._store.unblock_room(room_id)
+
+        return HTTPStatus.OK, {"block": block}
synapse/rest/admin/users.py

@@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet):
         if auth_user.to_string() == user_id:
             raise SynapseError(400, "Cannot use admin API to login as self")

-        token = await self.auth_handler.get_access_token_for_user_id(
+        token = await self.auth_handler.create_access_token_for_user_id(
             user_id=auth_user.to_string(),
             device_id=None,
             valid_until_ms=valid_until_ms,
@@ -909,7 +909,7 @@ class UserTokenRestServlet(RestServlet):


 class ShadowBanRestServlet(RestServlet):
-    """An admin API for shadow-banning a user.
+    """An admin API for controlling whether a user is shadow-banned.

     A shadow-banned users receives successful responses to their client-server
     API requests, but the events are not propagated into rooms.
@@ -917,11 +917,19 @@ class ShadowBanRestServlet(RestServlet):
     Shadow-banning a user should be used as a tool of last resort and may lead
     to confusing or broken behaviour for the client.

-    Example:
+    Example of shadow-banning a user:

         POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
         {}

         200 OK
         {}
+
+    Example of removing a user from being shadow-banned:
+
+        DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban
+        {}
+
+        200 OK
+        {}
     """
@@ -945,6 +953,18 @@ class ShadowBanRestServlet(RestServlet):

         return 200, {}

+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be shadow-banned")
+
+        await self.store.set_shadow_banned(UserID.from_string(user_id), False)
+
+        return 200, {}
+

 class RateLimitRestServlet(RestServlet):
     """An admin API to override ratelimiting for an user.
synapse/rest/client/_base.py

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)

 def client_patterns(
     path_regex: str,
-    releases: Iterable[int] = (0,),
+    releases: Iterable[str] = ("r0", "v3"),
     unstable: bool = True,
     v1: bool = False,
 ) -> Iterable[Pattern]:
@@ -52,7 +52,7 @@ def client_patterns(
         v1_prefix = CLIENT_API_PREFIX + "/api/v1"
         patterns.append(re.compile("^" + v1_prefix + path_regex))
     for release in releases:
-        new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
+        new_prefix = CLIENT_API_PREFIX + f"/{release}"
         patterns.append(re.compile("^" + new_prefix + path_regex))

     return patterns
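Note: with the new defaults, a servlet path is registered under both the legacy
r0 prefix and the Matrix v1.1 v3 prefix. A rough illustration of the prefixes
the loop now builds (not the real function, just the string arithmetic):

    CLIENT_API_PREFIX = "/_matrix/client"
    releases = ("r0", "v3")
    prefixes = [f"{CLIENT_API_PREFIX}/{release}" for release in releases]
    # -> ["/_matrix/client/r0", "/_matrix/client/v3"]; the unstable prefix is
    #    still added separately when unstable=True.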
synapse/rest/client/keys.py

@@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet):
         }
     """

-    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
+    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",))

     def __init__(self, hs: "HomeServer"):
         super().__init__()
synapse/rest/client/login.py

@@ -61,7 +61,8 @@ class LoginRestServlet(RestServlet):
     TOKEN_TYPE = "m.login.token"
     JWT_TYPE = "org.matrix.login.jwt"
     JWT_TYPE_DEPRECATED = "m.login.jwt"
-    APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
+    APPSERVICE_TYPE = "m.login.application_service"
+    APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service"
     REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"

     def __init__(self, hs: "HomeServer"):
@@ -71,6 +72,7 @@ class LoginRestServlet(RestServlet):
         # JWT configuration variables.
         self.jwt_enabled = hs.config.jwt.jwt_enabled
         self.jwt_secret = hs.config.jwt.jwt_secret
+        self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
         self.jwt_algorithm = hs.config.jwt.jwt_algorithm
         self.jwt_issuer = hs.config.jwt.jwt_issuer
         self.jwt_audiences = hs.config.jwt.jwt_audiences
@@ -79,7 +81,9 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2.saml2_enabled
         self.cas_enabled = hs.config.cas.cas_enabled
         self.oidc_enabled = hs.config.oidc.oidc_enabled
-        self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
+        self._msc2918_enabled = (
+            hs.config.registration.refreshable_access_token_lifetime is not None
+        )

         self.auth = hs.get_auth()

@@ -143,6 +147,7 @@ class LoginRestServlet(RestServlet):
         flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())

         flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
+        flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE})

         return 200, {"flows": flows}

@@ -159,7 +164,10 @@ class LoginRestServlet(RestServlet):
         should_issue_refresh_token = False

         try:
-            if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
+            if login_submission["type"] in (
+                LoginRestServlet.APPSERVICE_TYPE,
+                LoginRestServlet.APPSERVICE_TYPE_UNSTABLE,
+            ):
                 appservice = self.auth.get_appservice_by_req(request)

                 if appservice.is_rate_limited():
@@ -408,7 +416,7 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )

-        user = payload.get("sub", None)
+        user = payload.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

@@ -447,7 +455,9 @@ class RefreshTokenServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self._auth_handler = hs.get_auth_handler()
         self._clock = hs.get_clock()
-        self.access_token_lifetime = hs.config.registration.access_token_lifetime
+        self.refreshable_access_token_lifetime = (
+            hs.config.registration.refreshable_access_token_lifetime
+        )

     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         refresh_submission = parse_json_object_from_request(request)
@@ -457,7 +467,9 @@ class RefreshTokenServlet(RestServlet):
         if not isinstance(token, str):
             raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)

-        valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
+        valid_until_ms = (
+            self._clock.time_msec() + self.refreshable_access_token_lifetime
+        )
         access_token, refresh_token = await self._auth_handler.refresh_token(
             token, valid_until_ms
         )
@@ -556,7 +568,7 @@ class CasTicketServlet(RestServlet):

 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     LoginRestServlet(hs).register(http_server)
-    if hs.config.registration.access_token_lifetime is not None:
+    if hs.config.registration.refreshable_access_token_lifetime is not None:
         RefreshTokenServlet(hs).register(http_server)
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas.cas_enabled:
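Note: a sketch of what an appservice login body looks like with the newly
stabilised type (the user ID is a placeholder; the request is authenticated
with the appservice's as_token as the access token):

    login_body = {
        # Stable name; the unstable uk.half-shot.msc2778.* type is still accepted.
        "type": "m.login.application_service",
        "identifier": {"type": "m.id.user", "user": "@_bridge_alice:example.com"},
    }
    # POSTed to /_matrix/client/v3/login (or the legacy r0 path).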
synapse/rest/client/register.py

@@ -420,7 +420,9 @@ class RegisterRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
         self._registration_enabled = self.hs.config.registration.enable_registration
-        self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
+        self._msc2918_enabled = (
+            hs.config.registration.refreshable_access_token_lifetime is not None
+        )

         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
synapse/rest/client/relations.py

@@ -224,17 +224,17 @@ class RelationPaginationServlet(RestServlet):
         )

         now = self.clock.time_msec()
-        # We set bundle_aggregations to False when retrieving the original
+        # We set bundle_relations to False when retrieving the original
         # event because we want the content before relations were applied to
         # it.
         original_event = await self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=False
+            event, now, bundle_relations=False
         )
         # Similarly, we don't allow relations to be applied to relations, so we
         # return the original relations without any aggregations on top of them
         # here.
         serialized_events = await self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=False
+            events, now, bundle_relations=False
         )

         return_value = pagination_chunk.to_dict()
@@ -298,7 +298,9 @@ class RelationAggregationPaginationServlet(RestServlet):
             raise SynapseError(404, "Unknown parent event.")

         if relation_type not in (RelationTypes.ANNOTATION, None):
-            raise SynapseError(400, "Relation type must be 'annotation'")
+            raise SynapseError(
+                400, f"Relation type must be '{RelationTypes.ANNOTATION}'"
+            )

         limit = parse_integer(request, "limit", default=5)
         from_token_str = parse_string(request, "from")
synapse/rest/client/room.py

@@ -554,6 +554,7 @@ class RoomMessageListRestServlet(RestServlet):

     def __init__(self, hs: "HomeServer"):
         super().__init__()
+        self._hs = hs
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -571,7 +572,9 @@ class RoomMessageListRestServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
+            event_filter: Optional[Filter] = Filter(
+                self._hs, json_decoder.decode(filter_json)
+            )
             if (
                 event_filter
                 and event_filter.filter_json.get("event_format", "client")
@@ -676,6 +679,7 @@ class RoomEventContextServlet(RestServlet):

     def __init__(self, hs: "HomeServer"):
         super().__init__()
+        self._hs = hs
         self.clock = hs.get_clock()
         self.room_context_handler = hs.get_room_context_handler()
         self._event_serializer = hs.get_event_client_serializer()
@@ -692,7 +696,9 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
+            event_filter: Optional[Filter] = Filter(
+                self._hs, json_decoder.decode(filter_json)
+            )
         else:
             event_filter = None

@@ -717,7 +723,7 @@ class RoomEventContextServlet(RestServlet):
             results["state"],
             time_now,
             # No need to bundle aggregations for state events
-            bundle_aggregations=False,
+            bundle_relations=False,
         )

         return 200, results
synapse/rest/client/sync.py

@@ -29,7 +29,7 @@ from typing import (

 from synapse.api.constants import Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
 from synapse.events.utils import (
@@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet):
         request_key = (user, timeout, since, filter_id, full_state, device_id)

         if filter_id is None:
-            filter_collection = DEFAULT_FILTER_COLLECTION
+            filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
         elif filter_id.startswith("{"):
             try:
                 filter_object = json_decoder.decode(filter_id)
@@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet):
             except Exception:
                 raise SynapseError(400, "Invalid filter JSON")
             self.filtering.check_valid_filter(filter_object)
-            filter_collection = FilterCollection(filter_object)
+            filter_collection = FilterCollection(self.hs, filter_object)
         else:
             try:
                 filter_collection = await self.filtering.get_user_filter(
@@ -522,7 +522,7 @@ class SyncRestServlet(RestServlet):
                     time_now=time_now,
                     # We don't bundle "live" events, as otherwise clients
                     # will end up double counting annotations.
-                    bundle_aggregations=False,
+                    bundle_relations=False,
                     token_id=token_id,
                     event_format=event_formatter,
                     only_event_fields=only_fields,
synapse/rest/media/v1/_base.py

@@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util.stringutils import is_ascii
+from synapse.util.stringutils import is_ascii, parse_and_validate_server_name

 logger = logging.getLogger(__name__)

@@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [


 def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
+    """Parses the server name, media ID and optional file name from the request URI
+
+    Also performs some rough validation on the server name.
+
+    Args:
+        request: The `Request`.
+
+    Returns:
+        A tuple containing the parsed server name, media ID and optional file name.
+
+    Raises:
+        SynapseError(404): if parsing or validation fail for any reason
+    """
     try:
         # The type on postpath seems incorrect in Twisted 21.2.0.
         postpath: List[bytes] = request.postpath  # type: ignore
@@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
         server_name = server_name_bytes.decode("utf-8")
         media_id = media_id_bytes.decode("utf8")

+        # Validate the server name, raising if invalid
+        parse_and_validate_server_name(server_name)
+
         file_name = None
         if len(postpath) > 2:
             try:
synapse/rest/media/v1/filepath.py

@@ -16,7 +16,8 @@
 import functools
 import os
 import re
-from typing import Any, Callable, List, TypeVar, cast
+import string
+from typing import Any, Callable, List, TypeVar, Union, cast

 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")

@@ -37,6 +38,85 @@ def _wrap_in_base_path(func: F) -> F:
     return cast(F, _wrapped)


+GetPathMethod = TypeVar(
+    "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
+)
+
+
+def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod:
+    """Wraps a path-returning method to check that the returned path(s) do not escape
+    the media store directory.
+
+    The check is not expected to ever fail, unless `func` is missing a call to
+    `_validate_path_component`, or `_validate_path_component` is buggy.
+
+    Args:
+        func: The `MediaFilePaths` method to wrap. The method may return either a single
+            path, or a list of paths. Returned paths may be either absolute or relative.
+
+    Returns:
+        The method, wrapped with a check to ensure that the returned path(s) lie within
+        the media store directory. Raises a `ValueError` if the check fails.
+    """
+
+    @functools.wraps(func)
+    def _wrapped(
+        self: "MediaFilePaths", *args: Any, **kwargs: Any
+    ) -> Union[str, List[str]]:
+        path_or_paths = func(self, *args, **kwargs)
+
+        if isinstance(path_or_paths, list):
+            paths_to_check = path_or_paths
+        else:
+            paths_to_check = [path_or_paths]
+
+        for path in paths_to_check:
+            # path may be an absolute or relative path, depending on the method being
+            # wrapped. When "appending" an absolute path, `os.path.join` discards the
+            # previous path, which is desired here.
+            normalized_path = os.path.normpath(os.path.join(self.real_base_path, path))
+            if (
+                os.path.commonpath([normalized_path, self.real_base_path])
+                != self.real_base_path
+            ):
+                raise ValueError(f"Invalid media store path: {path!r}")
+
+        return path_or_paths
+
+    return cast(GetPathMethod, _wrapped)
+
+
+ALLOWED_CHARACTERS = set(
+    string.ascii_letters
+    + string.digits
+    + "_-"
+    + ".[]:"  # Domain names, IPv6 addresses and ports in server names
+)
+FORBIDDEN_NAMES = {
+    "",
+    os.path.curdir,  # "." for the current platform
+    os.path.pardir,  # ".." for the current platform
+}
+
+
+def _validate_path_component(name: str) -> str:
+    """Checks that the given string can be safely used as a path component
+
+    Args:
+        name: The path component to check.
+
+    Returns:
+        The path component if valid.
+
+    Raises:
+        ValueError: If `name` cannot be safely used as a path component.
+    """
+    if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
+        raise ValueError(f"Invalid path component: {name!r}")
+
+    return name
+
+
 class MediaFilePaths:
     """Describes where files are stored on disk.

@@ -48,22 +128,46 @@ class MediaFilePaths:
     def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path

+        # The media store directory, with all symlinks resolved.
+        self.real_base_path = os.path.realpath(primary_base_path)
+
+        # Refuse to initialize if paths cannot be validated correctly for the current
+        # platform.
+        assert os.path.sep not in ALLOWED_CHARACTERS
+        assert os.path.altsep not in ALLOWED_CHARACTERS
+        # On Windows, paths have all sorts of weirdness which `_validate_path_component`
+        # does not consider. In any case, the remote media store can't work correctly
+        # for certain homeservers there, since ":"s aren't allowed in paths.
+        assert os.name == "posix"
+
+    @_wrap_with_jail_check
     def local_media_filepath_rel(self, media_id: str) -> str:
-        return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
+        return os.path.join(
+            "local_content",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+        )

     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)

+    @_wrap_with_jail_check
     def local_media_thumbnail_rel(
         self, media_id: str, width: int, height: int, content_type: str, method: str
     ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
-            "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
+            "local_thumbnails",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+            _validate_path_component(file_name),
         )

     local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)

+    @_wrap_with_jail_check
     def local_media_thumbnail_dir(self, media_id: str) -> str:
         """
         Retrieve the local store path of thumbnails of a given media_id
@@ -76,18 +180,24 @@ class MediaFilePaths:
         return os.path.join(
             self.base_path,
             "local_thumbnails",
-            media_id[0:2],
-            media_id[2:4],
-            media_id[4:],
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
         )

+    @_wrap_with_jail_check
     def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
-            "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
+            "remote_content",
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
         )

     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)

+    @_wrap_with_jail_check
     def remote_media_thumbnail_rel(
         self,
         server_name: str,
@@ -101,11 +211,11 @@ class MediaFilePaths:
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
             "remote_thumbnail",
-            server_name,
-            file_id[0:2],
-            file_id[2:4],
-            file_id[4:],
-            file_name,
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+            _validate_path_component(file_name),
         )

     remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
@@ -113,6 +223,7 @@ class MediaFilePaths:
     # Legacy path that was used to store thumbnails previously.
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
+    @_wrap_with_jail_check
     def remote_media_thumbnail_rel_legacy(
         self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ) -> str:
@@ -120,43 +231,66 @@ class MediaFilePaths:
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
         return os.path.join(
             "remote_thumbnail",
-            server_name,
-            file_id[0:2],
-            file_id[2:4],
-            file_id[4:],
-            file_name,
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+            _validate_path_component(file_name),
         )

     def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
-            server_name,
-            file_id[0:2],
-            file_id[2:4],
-            file_id[4:],
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
         )

+    @_wrap_with_jail_check
     def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
-            return os.path.join("url_cache", media_id[:10], media_id[11:])
+            return os.path.join(
+                "url_cache",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+            )
         else:
-            return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
+            return os.path.join(
+                "url_cache",
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+            )

     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)

+    @_wrap_with_jail_check
     def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
-            return [os.path.join(self.base_path, "url_cache", media_id[:10])]
+            return [
+                os.path.join(
+                    self.base_path, "url_cache", _validate_path_component(media_id[:10])
+                )
+            ]
         else:
             return [
-                os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
-                os.path.join(self.base_path, "url_cache", media_id[0:2]),
+                os.path.join(
+                    self.base_path,
+                    "url_cache",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                ),
+                os.path.join(
+                    self.base_path, "url_cache", _validate_path_component(media_id[0:2])
+                ),
             ]

+    @_wrap_with_jail_check
     def url_cache_thumbnail_rel(
         self, media_id: str, width: int, height: int, content_type: str, method: str
     ) -> str:
@@ -168,37 +302,46 @@ class MediaFilePaths:

         if NEW_FORMAT_ID_RE.match(media_id):
             return os.path.join(
-                "url_cache_thumbnails", media_id[:10], media_id[11:], file_name
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+                _validate_path_component(file_name),
             )
         else:
             return os.path.join(
                 "url_cache_thumbnails",
-                media_id[0:2],
-                media_id[2:4],
-                media_id[4:],
-                file_name,
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+                _validate_path_component(file_name),
             )

     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)

+    @_wrap_with_jail_check
     def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf

         if NEW_FORMAT_ID_RE.match(media_id):
-            return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:])
+            return os.path.join(
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+            )
         else:
             return os.path.join(
                 "url_cache_thumbnails",
-                media_id[0:2],
-                media_id[2:4],
-                media_id[4:],
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
             )

     url_cache_thumbnail_directory = _wrap_in_base_path(
         url_cache_thumbnail_directory_rel
     )

+    @_wrap_with_jail_check
     def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
@@ -206,21 +349,35 @@ class MediaFilePaths:
         if NEW_FORMAT_ID_RE.match(media_id):
             return [
                 os.path.join(
-                    self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[:10]),
+                    _validate_path_component(media_id[11:]),
                 ),
-                os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[:10]),
+                ),
             ]
         else:
             return [
                 os.path.join(
                     self.base_path,
                     "url_cache_thumbnails",
-                    media_id[0:2],
-                    media_id[2:4],
-                    media_id[4:],
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                    _validate_path_component(media_id[4:]),
                 ),
                 os.path.join(
-                    self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                ),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
                 ),
-                os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
             ]
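Note: a quick sketch of the behaviour the validators above enforce
(illustrative; these exact calls are assumptions, not taken from the test suite):

    from synapse.rest.media.v1.filepath import MediaFilePaths

    paths = MediaFilePaths("/media_store")
    paths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg")  # a normal media ID: fine
    paths.local_media_filepath("../../../etc/passwd")
    # raises ValueError: ".." is in FORBIDDEN_NAMES and "/" is not in
    # ALLOWED_CHARACTERS; the @_wrap_with_jail_check decorator additionally
    # rejects any result that normalises to a path outside the media store.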
synapse/rest/media/v1/preview_url_resource.py

@@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         og = await make_deferred_yieldable(observable.observe())
         respond_with_json_bytes(request, 200, og, send_cors=True)

-    async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
+    async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
         """Check the db, and download the URL and build a preview

         Args:
@@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource):

         return jsonog.encode("utf8")

-    async def _download_url(self, url: str, user: str) -> MediaInfo:
+    async def _download_url(self, url: str, user: UserID) -> MediaInfo:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         )

     async def _precache_image_url(
-        self, user: str, media_info: MediaInfo, og: JsonDict
+        self, user: UserID, media_info: MediaInfo, og: JsonDict
     ) -> None:
         """
         Pre-cache the image (if one exists) for posterity
synapse/rest/media/v1/thumbnailer.py

@@ -101,8 +101,8 @@ class Thumbnailer:
        fits within the given rectangle::

            (w_in / h_in) = (w_out / h_out)
-            w_out = min(w_max, h_max * (w_in / h_in))
-            h_out = min(h_max, w_max * (h_in / w_in))
+            w_out = max(min(w_max, h_max * (w_in / h_in)), 1)
+            h_out = max(min(h_max, w_max * (h_in / w_in)), 1)

         Args:
             max_width: The largest possible width.
@@ -110,9 +110,9 @@ class Thumbnailer:
         """

         if max_width * self.height < max_height * self.width:
-            return max_width, (max_width * self.height) // self.width
+            return max_width, max((max_width * self.height) // self.width, 1)
         else:
-            return (max_height * self.width) // self.height, max_height
+            return max((max_height * self.width) // self.height, 1), max_height

     def _resize(self, width: int, height: int) -> Image.Image:
         # 1-bit or 8-bit color palette images need converting to RGB
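Note: a worked example of why the clamp matters -- an extreme 1000x1 image
scaled into a 64x64 box used to round down to a zero-height thumbnail:

    w_in, h_in, w_max, h_max = 1000, 1, 64, 64
    h_out = (w_max * h_in) // w_in          # 0 before this change
    h_out = max((w_max * h_in) // w_in, 1)  # 1, so a 0-size image is never requested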