mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-16 01:33:34 -04:00
Merge remote-tracking branch 'upstream/release-v1.50'
This commit is contained in:
commit
e9caf56ca0
205 changed files with 4905 additions and 2749 deletions
|
@ -69,6 +69,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
|
|||
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
|
||||
from synapse.rest.admin.username_available import UsernameAvailableRestServlet
|
||||
from synapse.rest.admin.users import (
|
||||
AccountDataRestServlet,
|
||||
AccountValidityRenewServlet,
|
||||
DeactivateAccountRestServlet,
|
||||
PushersRestServlet,
|
||||
|
@ -108,7 +109,7 @@ class VersionServlet(RestServlet):
|
|||
|
||||
class PurgeHistoryRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns(
|
||||
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
|
||||
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -195,7 +196,7 @@ class PurgeHistoryRestServlet(RestServlet):
|
|||
|
||||
|
||||
class PurgeHistoryStatusRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.pagination_handler = hs.get_pagination_handler()
|
||||
|
@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
UserMediaStatisticsRestServlet(hs).register(http_server)
|
||||
EventReportDetailRestServlet(hs).register(http_server)
|
||||
EventReportsRestServlet(hs).register(http_server)
|
||||
AccountDataRestServlet(hs).register(http_server)
|
||||
PushersRestServlet(hs).register(http_server)
|
||||
MakeRoomAdminRestServlet(hs).register(http_server)
|
||||
ShadowBanRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -22,7 +22,7 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
|
|||
self._data_stores = hs.get_datastores()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
# We need to check that all configured databases have updates enabled.
|
||||
# (They *should* all be in sync.)
|
||||
|
@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
|
|||
return HTTPStatus.OK, {"enabled": enabled}
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
|
|||
self._data_stores = hs.get_datastores()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
# We need to check that all configured databases have updates enabled.
|
||||
# (They *should* all be in sync.)
|
||||
|
@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
|
|||
class BackgroundUpdateStartJobRestServlet(RestServlet):
|
||||
"""Allows to start specific background updates"""
|
||||
|
||||
PATTERNS = admin_patterns("/background_updates/start_job")
|
||||
PATTERNS = admin_patterns("/background_updates/start_job$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
self._store = hs.get_datastore()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ["job_name"])
|
||||
|
|
|
@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str, device_id: str
|
||||
|
@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
|
@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
|
|||
device = await self.device_handler.get_device(
|
||||
target_user.to_string(), device_id
|
||||
)
|
||||
if device is None:
|
||||
raise NotFoundError("No device found")
|
||||
return HTTPStatus.OK, device
|
||||
|
||||
async def on_DELETE(
|
||||
|
@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
|
@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
|
@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
"""
|
||||
Args:
|
||||
hs: server
|
||||
"""
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
|
@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
|
@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
|
@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
|
|
|
@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/event_reports$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
|
|||
200 OK with details of a destination if success otherwise an error.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
|
||||
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
|
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
class DeleteGroupAdminRestServlet(RestServlet):
|
||||
"""Allows deleting of local groups"""
|
||||
|
||||
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.group_server = hs.get_groups_server_handler()
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
|
|||
"""
|
||||
|
||||
PATTERNS = [
|
||||
*admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
|
||||
*admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
|
||||
# This path kept around for legacy reasons
|
||||
*admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
|
||||
*admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
|
||||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
|
|||
this server.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
|
||||
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
|
|||
"""
|
||||
|
||||
PATTERNS = admin_patterns(
|
||||
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
|
||||
"/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
|
|||
"""
|
||||
|
||||
PATTERNS = admin_patterns(
|
||||
"/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
|
||||
"/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
|
|||
async def on_POST(
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
logging.info(
|
||||
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
|
||||
|
@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
|
|||
class ProtectMediaByID(RestServlet):
|
||||
"""Protect local media from being quarantined."""
|
||||
|
||||
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
|
|||
async def on_POST(
|
||||
self, request: SynapseRequest, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
logging.info("Protecting local media by ID: %s", media_id)
|
||||
|
||||
|
@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
|
|||
class UnprotectMediaByID(RestServlet):
|
||||
"""Unprotect local media from being quarantined."""
|
||||
|
||||
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
|
|||
async def on_POST(
|
||||
self, request: SynapseRequest, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
logging.info("Unprotecting local media by ID: %s", media_id)
|
||||
|
||||
|
@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
|
|||
class ListMediaInRoom(RestServlet):
|
||||
"""Lists all of the media in a given room."""
|
||||
|
||||
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
|
||||
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
|
|||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
|
||||
|
||||
|
@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
|
|||
class DeleteMediaByID(RestServlet):
|
||||
"""Delete local media by a given ID. Removes it from this server."""
|
||||
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
|
|||
timestamp and size.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
|
|||
media that exist given for this user
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.is_mine = hs.is_mine
|
||||
|
@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
|
|||
request,
|
||||
"order_by",
|
||||
default=MediaSortOrder.CREATED_TS.value,
|
||||
allowed_values=(
|
||||
MediaSortOrder.MEDIA_ID.value,
|
||||
MediaSortOrder.UPLOAD_NAME.value,
|
||||
MediaSortOrder.CREATED_TS.value,
|
||||
MediaSortOrder.LAST_ACCESS_TS.value,
|
||||
MediaSortOrder.MEDIA_LENGTH.value,
|
||||
MediaSortOrder.MEDIA_TYPE.value,
|
||||
MediaSortOrder.QUARANTINED_BY.value,
|
||||
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
|
||||
),
|
||||
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
|
||||
)
|
||||
direction = parse_string(
|
||||
request, "dir", default="f", allowed_values=("f", "b")
|
||||
|
@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
|
|||
request,
|
||||
"order_by",
|
||||
default=MediaSortOrder.CREATED_TS.value,
|
||||
allowed_values=(
|
||||
MediaSortOrder.MEDIA_ID.value,
|
||||
MediaSortOrder.UPLOAD_NAME.value,
|
||||
MediaSortOrder.CREATED_TS.value,
|
||||
MediaSortOrder.LAST_ACCESS_TS.value,
|
||||
MediaSortOrder.MEDIA_LENGTH.value,
|
||||
MediaSortOrder.MEDIA_TYPE.value,
|
||||
MediaSortOrder.QUARANTINED_BY.value,
|
||||
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
|
||||
),
|
||||
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
|
||||
)
|
||||
direction = parse_string(
|
||||
request, "dir", default="f", allowed_values=("f", "b")
|
||||
|
|
|
@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/registration_tokens$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/registration_tokens/new$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
|
|
@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
|
|||
If 'purge' is true, it will remove all traces of a room from the database.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
|
|||
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
|
||||
"""Get the status of the delete room background task."""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
|
|||
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
|
||||
"""Get the status of the delete room background task."""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
|
||||
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet):
|
|||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
# Extract query parameters
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
|
||||
if order_by not in (
|
||||
RoomSortOrder.ALPHABETICAL.value,
|
||||
RoomSortOrder.SIZE.value,
|
||||
RoomSortOrder.NAME.value,
|
||||
RoomSortOrder.CANONICAL_ALIAS.value,
|
||||
RoomSortOrder.JOINED_MEMBERS.value,
|
||||
RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
|
||||
RoomSortOrder.VERSION.value,
|
||||
RoomSortOrder.CREATOR.value,
|
||||
RoomSortOrder.ENCRYPTION.value,
|
||||
RoomSortOrder.FEDERATABLE.value,
|
||||
RoomSortOrder.PUBLIC.value,
|
||||
RoomSortOrder.JOIN_RULES.value,
|
||||
RoomSortOrder.GUEST_ACCESS.value,
|
||||
RoomSortOrder.HISTORY_VISIBILITY.value,
|
||||
RoomSortOrder.STATE_EVENTS.value,
|
||||
):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown value for order_by: %s" % (order_by,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
order_by = parse_string(
|
||||
request,
|
||||
"order_by",
|
||||
default=RoomSortOrder.NAME.value,
|
||||
allowed_values=[sort_order.value for sort_order in RoomSortOrder],
|
||||
)
|
||||
|
||||
search_term = parse_string(request, "search_term", encoding="utf-8")
|
||||
if search_term == "":
|
||||
|
@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
|
|||
TODO: Add on_POST to allow room creation without joining the room
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||
|
@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
|
|||
Get members list of a room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
|
|||
Get full state within a room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
|
|||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
|
@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
|
|||
|
||||
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
|
@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
|||
assert_params_in_dict(content, ["user_id"])
|
||||
target_user = UserID.from_string(content["user_id"])
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"This endpoint can only be used with local users",
|
||||
|
@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
room_id, _ = await self.resolve_room_id(room_identifier)
|
||||
|
||||
|
@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
async def on_GET(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
room_id, _ = await self.resolve_room_id(room_identifier)
|
||||
|
||||
|
@ -771,13 +745,19 @@ class RoomEventContextServlet(RestServlet):
|
|||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = await self._event_serializer.serialize_events(
|
||||
results["events_before"], time_now
|
||||
results["events_before"],
|
||||
time_now,
|
||||
bundle_aggregations=True,
|
||||
)
|
||||
results["event"] = await self._event_serializer.serialize_event(
|
||||
results["event"], time_now
|
||||
results["event"],
|
||||
time_now,
|
||||
bundle_aggregations=True,
|
||||
)
|
||||
results["events_after"] = await self._event_serializer.serialize_events(
|
||||
results["events_after"], time_now
|
||||
results["events_after"],
|
||||
time_now,
|
||||
bundle_aggregations=True,
|
||||
)
|
||||
results["state"] = await self._event_serializer.serialize_events(
|
||||
results["state"], time_now
|
||||
|
@ -793,7 +773,7 @@ class BlockRoomRestServlet(RestServlet):
|
|||
On GET: Get blocking status of room and user who has blocked this room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
|
|
@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
|
|||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.server_notices_manager = hs.get_server_notices_manager()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
def register(self, json_resource: HttpServer) -> None:
|
||||
PATTERN = "/send_server_notice"
|
||||
|
@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
|
|||
)
|
||||
|
||||
target_user = UserID.from_string(body["user_id"])
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
|
||||
)
|
||||
|
|
|
@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/statistics/users/media$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet):
|
|||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
order_by = parse_string(
|
||||
request, "order_by", default=UserSortOrder.USER_ID.value
|
||||
request,
|
||||
"order_by",
|
||||
default=UserSortOrder.USER_ID.value,
|
||||
allowed_values=(
|
||||
UserSortOrder.MEDIA_LENGTH.value,
|
||||
UserSortOrder.MEDIA_COUNT.value,
|
||||
UserSortOrder.USER_ID.value,
|
||||
UserSortOrder.DISPLAYNAME.value,
|
||||
),
|
||||
)
|
||||
if order_by not in (
|
||||
UserSortOrder.MEDIA_LENGTH.value,
|
||||
UserSortOrder.MEDIA_COUNT.value,
|
||||
UserSortOrder.USER_ID.value,
|
||||
UserSortOrder.DISPLAYNAME.value,
|
||||
):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown value for order_by: %s" % (order_by,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
start = parse_integer(request, "from", default=0)
|
||||
if start < 0:
|
||||
|
|
|
@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
|
|||
}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/username_available")
|
||||
PATTERNS = admin_patterns("/username_available$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth = hs.get_auth()
|
||||
|
|
|
@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
|
|||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
|
|||
|
||||
|
||||
class UserRestServletV2(RestServlet):
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
|
||||
|
||||
"""Get request to list user details.
|
||||
This needs user to have administrator access in Synapse.
|
||||
|
@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
|
|||
nonce to the time it was generated, in int seconds.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/register")
|
||||
PATTERNS = admin_patterns("/register$")
|
||||
NONCE_TIMEOUT = 60
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
|
|||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
|
@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
|
|||
if target_user != auth_user:
|
||||
await assert_user_is_admin(self.auth, auth_user)
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
|
||||
|
||||
ret = await self.admin_handler.get_whois(target_user)
|
||||
|
@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
|
|||
|
||||
|
||||
class DeactivateAccountRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
|
@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/account_validity/validity$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.account_activity_handler = hs.get_account_validity_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
|
@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
|
|||
200 OK with empty object if success otherwise an error.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self._set_password_handler = hs.get_set_password_handler()
|
||||
|
@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
|
|||
200 OK with json object {list[dict[str, Any]], count} or empty object.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, target_user_id: str
|
||||
|
@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
|
|||
# if not is_admin and target_user != auth_user:
|
||||
# raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
|
||||
|
||||
term = parse_string(request, "term", required=True)
|
||||
|
@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
|
@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
|
|||
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Only local users can be admins of this homeserver",
|
||||
|
@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
|
|||
|
||||
assert_params_in_dict(body, ["admin"])
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Only local users can be admins of this homeserver",
|
||||
|
@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
|
|||
Get room list of an user.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.is_mine = hs.is_mine
|
||||
|
@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
|
|||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
|
@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
|
|||
await assert_user_is_admin(self.auth, requester.user)
|
||||
auth_user = requester.user
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
|
||||
)
|
||||
|
@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
|
|||
{}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
|
||||
)
|
||||
|
@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
|
||||
)
|
||||
|
@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
|
|||
}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
|
||||
|
||||
if not await self.store.get_user_by_id(user_id):
|
||||
|
@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
|
||||
)
|
||||
|
@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
|
||||
)
|
||||
|
@ -1124,3 +1121,33 @@ class RateLimitRestServlet(RestServlet):
|
|||
await self.store.delete_ratelimit_for_user(user_id)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class AccountDataRestServlet(RestServlet):
|
||||
"""Retrieve the given user's account data"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/accountdata")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
self._store = hs.get_datastore()
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
if not self._is_mine_id(user_id):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
|
||||
|
||||
if not await self._store.get_user_by_id(user_id):
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
|
||||
return HTTPStatus.OK, {
|
||||
"account_data": {
|
||||
"global": global_data,
|
||||
"rooms": by_room_data,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -73,6 +73,9 @@ class CapabilitiesRestServlet(RestServlet):
|
|||
"enabled": self.config.registration.enable_3pid_changes
|
||||
}
|
||||
|
||||
if self.config.experimental.msc3440_enabled:
|
||||
response["capabilities"]["io.element.thread"] = {"enabled": True}
|
||||
|
||||
return 200, response
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
|
@ -24,10 +25,9 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns, interactive_auth_handler
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet):
|
|||
device = await self.device_handler.get_device(
|
||||
requester.user.to_string(), device_id
|
||||
)
|
||||
if device is None:
|
||||
raise NotFoundError("No device found")
|
||||
return 200, device
|
||||
|
||||
@interactive_auth_handler
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.events.utils import format_event_for_client_v2_without_room_id
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
|
@ -54,10 +55,10 @@ class NotificationsServlet(RestServlet):
|
|||
)
|
||||
|
||||
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
|
||||
user_id, "m.read"
|
||||
user_id, ReceiptTypes.READ
|
||||
)
|
||||
|
||||
notif_event_ids = [pa["event_id"] for pa in push_actions]
|
||||
notif_event_ids = [pa.event_id for pa in push_actions]
|
||||
notif_events = await self.store.get_events(notif_event_ids)
|
||||
|
||||
returned_push_actions = []
|
||||
|
@ -66,30 +67,30 @@ class NotificationsServlet(RestServlet):
|
|||
|
||||
for pa in push_actions:
|
||||
returned_pa = {
|
||||
"room_id": pa["room_id"],
|
||||
"profile_tag": pa["profile_tag"],
|
||||
"actions": pa["actions"],
|
||||
"ts": pa["received_ts"],
|
||||
"room_id": pa.room_id,
|
||||
"profile_tag": pa.profile_tag,
|
||||
"actions": pa.actions,
|
||||
"ts": pa.received_ts,
|
||||
"event": (
|
||||
await self._event_serializer.serialize_event(
|
||||
notif_events[pa["event_id"]],
|
||||
notif_events[pa.event_id],
|
||||
self.clock.time_msec(),
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
if pa["room_id"] not in receipts_by_room:
|
||||
if pa.room_id not in receipts_by_room:
|
||||
returned_pa["read"] = False
|
||||
else:
|
||||
receipt = receipts_by_room[pa["room_id"]]
|
||||
receipt = receipts_by_room[pa.room_id]
|
||||
|
||||
returned_pa["read"] = (
|
||||
receipt["topological_ordering"],
|
||||
receipt["stream_ordering"],
|
||||
) >= (pa["topological_ordering"], pa["stream_ordering"])
|
||||
) >= (pa.topological_ordering, pa.stream_ordering)
|
||||
returned_push_actions.append(returned_pa)
|
||||
next_token = str(pa["stream_ordering"])
|
||||
next_token = str(pa.stream_ordering)
|
||||
|
||||
return 200, {"notifications": returned_push_actions, "next_token": next_token}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
|
@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
|
|||
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
read_event_id = body.get("m.read", None)
|
||||
read_event_id = body.get(ReceiptTypes.READ, None)
|
||||
read_extra = body.get("com.beeper.read.extra", None)
|
||||
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
|
||||
|
||||
|
@ -63,7 +63,7 @@ class ReadMarkerRestServlet(RestServlet):
|
|||
if read_event_id:
|
||||
await self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
"m.read",
|
||||
ReceiptTypes.READ,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=read_event_id,
|
||||
hidden=hidden,
|
||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http import get_request_user_agent
|
||||
from synapse.http.server import HttpServer
|
||||
|
@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
if receipt_type != "m.read":
|
||||
if receipt_type != ReceiptTypes.READ:
|
||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||
|
||||
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
|
||||
|
|
|
@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
|
|||
|
||||
pagination_chunk = await self.store.get_relations_for_event(
|
||||
event_id=parent_id,
|
||||
room_id=room_id,
|
||||
relation_type=relation_type,
|
||||
event_type=event_type,
|
||||
limit=limit,
|
||||
|
@ -231,7 +232,9 @@ class RelationPaginationServlet(RestServlet):
|
|||
)
|
||||
# The relations returned for the requested event do include their
|
||||
# bundled aggregations.
|
||||
serialized_events = await self._event_serializer.serialize_events(events, now)
|
||||
serialized_events = await self._event_serializer.serialize_events(
|
||||
events, now, bundle_aggregations=True
|
||||
)
|
||||
|
||||
return_value = pagination_chunk.to_dict()
|
||||
return_value["chunk"] = serialized_events
|
||||
|
@ -317,6 +320,7 @@ class RelationAggregationPaginationServlet(RestServlet):
|
|||
|
||||
pagination_chunk = await self.store.get_aggregation_groups_for_event(
|
||||
event_id=parent_id,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
limit=limit,
|
||||
from_token=from_token,
|
||||
|
@ -383,7 +387,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
|||
|
||||
# This checks that a) the event exists and b) the user is allowed to
|
||||
# view it.
|
||||
await self.event_handler.get_event(requester.user, room_id, parent_id)
|
||||
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
|
||||
if event is None:
|
||||
raise SynapseError(404, "Unknown parent event.")
|
||||
|
||||
if relation_type != RelationTypes.ANNOTATION:
|
||||
raise SynapseError(400, "Relation type must be 'annotation'")
|
||||
|
@ -402,6 +408,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
|||
|
||||
result = await self.store.get_relations_for_event(
|
||||
event_id=parent_id,
|
||||
room_id=room_id,
|
||||
relation_type=relation_type,
|
||||
event_type=event_type,
|
||||
aggregation_key=key,
|
||||
|
|
|
@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
state_key: str,
|
||||
txn_id: Optional[str] = None,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if txn_id:
|
||||
set_tag("txn_id", txn_id)
|
||||
|
@ -666,7 +666,9 @@ class RoomEventServlet(RestServlet):
|
|||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
event_dict = await self._event_serializer.serialize_event(event, time_now)
|
||||
event_dict = await self._event_serializer.serialize_event(
|
||||
event, time_now, bundle_aggregations=True
|
||||
)
|
||||
return 200, event_dict
|
||||
|
||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
@ -711,13 +713,13 @@ class RoomEventContextServlet(RestServlet):
|
|||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = await self._event_serializer.serialize_events(
|
||||
results["events_before"], time_now
|
||||
results["events_before"], time_now, bundle_aggregations=True
|
||||
)
|
||||
results["event"] = await self._event_serializer.serialize_event(
|
||||
results["event"], time_now
|
||||
results["event"], time_now, bundle_aggregations=True
|
||||
)
|
||||
results["events_after"] = await self._event_serializer.serialize_events(
|
||||
results["events_after"], time_now
|
||||
results["events_after"], time_now, bundle_aggregations=True
|
||||
)
|
||||
results["state"] = await self._event_serializer.serialize_events(
|
||||
results["state"], time_now
|
||||
|
|
|
@ -48,6 +48,7 @@ from synapse.handlers.sync import (
|
|||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
from synapse.util import json_decoder
|
||||
|
||||
|
@ -222,6 +223,7 @@ class SyncRestServlet(RestServlet):
|
|||
logger.debug("Event formatting complete")
|
||||
return 200, response_content
|
||||
|
||||
@trace(opname="sync.encode_response")
|
||||
async def encode_response(
|
||||
self,
|
||||
time_now: int,
|
||||
|
@ -293,6 +295,9 @@ class SyncRestServlet(RestServlet):
|
|||
response[
|
||||
"org.matrix.msc2732.device_unused_fallback_key_types"
|
||||
] = sync_result.device_unused_fallback_key_types
|
||||
response[
|
||||
"device_unused_fallback_key_types"
|
||||
] = sync_result.device_unused_fallback_key_types
|
||||
|
||||
if joined:
|
||||
response["rooms"][Membership.JOIN] = joined
|
||||
|
@ -329,6 +334,7 @@ class SyncRestServlet(RestServlet):
|
|||
]
|
||||
}
|
||||
|
||||
@trace(opname="sync.encode_joined")
|
||||
async def encode_joined(
|
||||
self,
|
||||
rooms: List[JoinedSyncResult],
|
||||
|
@ -365,6 +371,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
return joined
|
||||
|
||||
@trace(opname="sync.encode_invited")
|
||||
async def encode_invited(
|
||||
self,
|
||||
rooms: List[InvitedSyncResult],
|
||||
|
@ -403,6 +410,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
return invited
|
||||
|
||||
@trace(opname="sync.encode_knocked")
|
||||
async def encode_knocked(
|
||||
self,
|
||||
rooms: List[KnockedSyncResult],
|
||||
|
@ -457,6 +465,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
return knocked
|
||||
|
||||
@trace(opname="sync.encode_archived")
|
||||
async def encode_archived(
|
||||
self,
|
||||
rooms: List[ArchivedSyncResult],
|
||||
|
|
|
@ -93,6 +93,10 @@ class VersionsRestServlet(RestServlet):
|
|||
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
|
||||
# Supports receiving hidden read receipts as per MSC2285
|
||||
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
|
||||
# Adds support for importing historical messages as per MSC2716
|
||||
"org.matrix.msc2716": self.config.experimental.msc2716_enabled,
|
||||
# Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
|
||||
"org.matrix.msc3030": self.config.experimental.msc3030_enabled,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.sign import sign_json
|
||||
|
@ -99,7 +99,7 @@ class LocalKey(Resource):
|
|||
json_object = sign_json(json_object, self.config.server.server_name, key)
|
||||
return json_object
|
||||
|
||||
def render_GET(self, request: Request) -> int:
|
||||
def render_GET(self, request: Request) -> Optional[int]:
|
||||
time_now = self.clock.time_msec()
|
||||
# Update the expiry time if less than half the interval remains.
|
||||
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
|
||||
|
|
|
@ -739,14 +739,21 @@ class MediaRepository:
|
|||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
# they have the same dimensions of a scaled one.
|
||||
thumbnails: Dict[Tuple[int, int, str], str] = {}
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "crop":
|
||||
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
||||
elif r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
for requirement in requirements:
|
||||
if requirement.method == "crop":
|
||||
thumbnails.setdefault(
|
||||
(requirement.width, requirement.height, requirement.media_type),
|
||||
requirement.method,
|
||||
)
|
||||
elif requirement.method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(
|
||||
requirement.width, requirement.height
|
||||
)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
thumbnails[(t_width, t_height, r_type)] = r_method
|
||||
thumbnails[
|
||||
(t_width, t_height, requirement.media_type)
|
||||
] = requirement.method
|
||||
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.items():
|
||||
|
|
|
@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional
|
|||
|
||||
import attr
|
||||
|
||||
from synapse.rest.media.v1.preview_html import parse_html_description
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_decoder
|
||||
|
||||
|
@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
|
|||
if video_urls:
|
||||
open_graph_response["og:video"] = video_urls[0]
|
||||
|
||||
from synapse.rest.media.v1.preview_url_resource import _calc_description
|
||||
|
||||
description = _calc_description(tree)
|
||||
description = parse_html_description(tree)
|
||||
if description:
|
||||
open_graph_response["og:description"] = description
|
||||
|
|
397
synapse/rest/media/v1/preview_html.py
Normal file
397
synapse/rest/media/v1/preview_html.py
Normal file
|
@ -0,0 +1,397 @@
|
|||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import codecs
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
|
||||
from urllib import parse as urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lxml import etree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_charset_match = re.compile(
|
||||
br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
|
||||
)
|
||||
_xml_encoding_match = re.compile(
|
||||
br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
|
||||
)
|
||||
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
||||
|
||||
|
||||
def _normalise_encoding(encoding: str) -> Optional[str]:
|
||||
"""Use the Python codec's name as the normalised entry."""
|
||||
try:
|
||||
return codecs.lookup(encoding).name
|
||||
except LookupError:
|
||||
return None
|
||||
|
||||
|
||||
def _get_html_media_encodings(
|
||||
body: bytes, content_type: Optional[str]
|
||||
) -> Iterable[str]:
|
||||
"""
|
||||
Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
|
||||
|
||||
The precedence used for finding a character encoding is:
|
||||
|
||||
1. <meta> tag with a charset declared.
|
||||
2. The XML document's character encoding attribute.
|
||||
3. The Content-Type header.
|
||||
4. Fallback to utf-8.
|
||||
5. Fallback to windows-1252.
|
||||
|
||||
This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
|
||||
|
||||
Args:
|
||||
body: The HTML document, as bytes.
|
||||
content_type: The Content-Type header.
|
||||
|
||||
Returns:
|
||||
The character encoding of the body, as a string.
|
||||
"""
|
||||
# There's no point in returning an encoding more than once.
|
||||
attempted_encodings: Set[str] = set()
|
||||
|
||||
# Limit searches to the first 1kb, since it ought to be at the top.
|
||||
body_start = body[:1024]
|
||||
|
||||
# Check if it has an encoding set in a meta tag.
|
||||
match = _charset_match.search(body_start)
|
||||
if match:
|
||||
encoding = _normalise_encoding(match.group(1).decode("ascii"))
|
||||
if encoding:
|
||||
attempted_encodings.add(encoding)
|
||||
yield encoding
|
||||
|
||||
# TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
|
||||
# Check if it has an XML document with an encoding.
|
||||
match = _xml_encoding_match.match(body_start)
|
||||
if match:
|
||||
encoding = _normalise_encoding(match.group(1).decode("ascii"))
|
||||
if encoding and encoding not in attempted_encodings:
|
||||
attempted_encodings.add(encoding)
|
||||
yield encoding
|
||||
|
||||
# Check the HTTP Content-Type header for a character set.
|
||||
if content_type:
|
||||
content_match = _content_type_match.match(content_type)
|
||||
if content_match:
|
||||
encoding = _normalise_encoding(content_match.group(1))
|
||||
if encoding and encoding not in attempted_encodings:
|
||||
attempted_encodings.add(encoding)
|
||||
yield encoding
|
||||
|
||||
# Finally, fallback to UTF-8, then windows-1252.
|
||||
for fallback in ("utf-8", "cp1252"):
|
||||
if fallback not in attempted_encodings:
|
||||
yield fallback
|
||||
|
||||
|
||||
def decode_body(
|
||||
body: bytes, uri: str, content_type: Optional[str] = None
|
||||
) -> Optional["etree.Element"]:
|
||||
"""
|
||||
This uses lxml to parse the HTML document.
|
||||
|
||||
Args:
|
||||
body: The HTML document, as bytes.
|
||||
uri: The URI used to download the body.
|
||||
content_type: The Content-Type header.
|
||||
|
||||
Returns:
|
||||
The parsed HTML body, or None if an error occurred during processed.
|
||||
"""
|
||||
# If there's no body, nothing useful is going to be found.
|
||||
if not body:
|
||||
return None
|
||||
|
||||
# The idea here is that multiple encodings are tried until one works.
|
||||
# Unfortunately the result is never used and then LXML will decode the string
|
||||
# again with the found encoding.
|
||||
for encoding in _get_html_media_encodings(body, content_type):
|
||||
try:
|
||||
body.decode(encoding)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
else:
|
||||
logger.warning("Unable to decode HTML body for %s", uri)
|
||||
return None
|
||||
|
||||
from lxml import etree
|
||||
|
||||
# Create an HTML parser.
|
||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
||||
|
||||
# Attempt to parse the body. Returns None if the body was successfully
|
||||
# parsed, but no tree was found.
|
||||
return etree.fromstring(body, parser)
|
||||
|
||||
|
||||
def parse_html_to_open_graph(
|
||||
tree: "etree.Element", media_uri: str
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Parse the HTML document into an Open Graph response.
|
||||
|
||||
This uses lxml to search the HTML document for Open Graph data (or
|
||||
synthesizes it from the document).
|
||||
|
||||
Args:
|
||||
tree: The parsed HTML document.
|
||||
media_url: The URI used to download the body.
|
||||
|
||||
Returns:
|
||||
The Open Graph response as a dictionary.
|
||||
"""
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
# (although the client could choose to do this by asking for previews of those
|
||||
# URLs to avoid DoSing the server)
|
||||
|
||||
# "og:type" : "video",
|
||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||
# "og:site_name" : "YouTube",
|
||||
# "og:video:type" : "application/x-shockwave-flash",
|
||||
# "og:description" : "Fun stuff happening here",
|
||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||
# "og:video:width" : "1280"
|
||||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og: Dict[str, Optional[str]] = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if "content" in tag.attrib:
|
||||
# if we've got more than 50 tags, someone is taking the piss
|
||||
if len(og) >= 50:
|
||||
logger.warning("Skipping OG for page with too many 'og:' tags")
|
||||
return {}
|
||||
og[tag.attrib["property"]] = tag.attrib["content"]
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||
# "article:tag" content="baby" />
|
||||
# "article:section" content="Breaking News" />
|
||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if "og:title" not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
if title and title[0].text is not None:
|
||||
og["og:title"] = title[0].text.strip()
|
||||
else:
|
||||
og["og:title"] = None
|
||||
|
||||
if "og:image" not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og["og:image"] = rebase_url(meta_image[0], media_uri)
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(
|
||||
images,
|
||||
key=lambda i: (
|
||||
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
|
||||
),
|
||||
)
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og["og:image"] = images[0].attrib["src"]
|
||||
|
||||
if "og:description" not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content"
|
||||
)
|
||||
if meta_description:
|
||||
og["og:description"] = meta_description[0]
|
||||
else:
|
||||
og["og:description"] = parse_html_description(tree)
|
||||
elif og["og:description"]:
|
||||
# This must be a non-empty string at this point.
|
||||
assert isinstance(og["og:description"], str)
|
||||
og["og:description"] = summarize_paragraphs([og["og:description"]])
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
return og
|
||||
|
||||
|
||||
def parse_html_description(tree: "etree.Element") -> Optional[str]:
|
||||
"""
|
||||
Calculate a text description based on an HTML document.
|
||||
|
||||
Grabs any text nodes which are inside the <body/> tag, unless they are within
|
||||
an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
|
||||
if they are within a <script/> or <style/> tag.
|
||||
|
||||
This is a very very very coarse approximation to a plain text render of the page.
|
||||
|
||||
Args:
|
||||
tree: The parsed HTML document.
|
||||
|
||||
Returns:
|
||||
The plain text description, or None if one cannot be generated.
|
||||
"""
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
from lxml import etree
|
||||
|
||||
TAGS_TO_REMOVE = (
|
||||
"header",
|
||||
"nav",
|
||||
"aside",
|
||||
"footer",
|
||||
"script",
|
||||
"noscript",
|
||||
"style",
|
||||
etree.Comment,
|
||||
)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r"\s+", "\n", el).strip()
|
||||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||
)
|
||||
return summarize_paragraphs(text_nodes)
|
||||
|
||||
|
||||
def _iterate_over_text(
|
||||
tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
|
||||
) -> Generator[str, None, None]:
|
||||
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||
skipping text nodes inside certain tags.
|
||||
"""
|
||||
# This is basically a stack that we extend using itertools.chain.
|
||||
# This will either consist of an element to iterate over *or* a string
|
||||
# to be returned.
|
||||
elements = iter([tree])
|
||||
while True:
|
||||
el = next(elements, None)
|
||||
if el is None:
|
||||
return
|
||||
|
||||
if isinstance(el, str):
|
||||
yield el
|
||||
elif el.tag not in tags_to_ignore:
|
||||
# el.text is the text before the first child, so we can immediately
|
||||
# return it if the text exists.
|
||||
if el.text:
|
||||
yield el.text
|
||||
|
||||
# We add to the stack all the elements children, interspersed with
|
||||
# each child's tail text (if it exists). The tail text of a node
|
||||
# is text that comes *after* the node, so we always include it even
|
||||
# if we ignore the child node.
|
||||
elements = itertools.chain(
|
||||
itertools.chain.from_iterable( # Basically a flatmap
|
||||
[child, child.tail] if child.tail else [child]
|
||||
for child in el.iterchildren()
|
||||
),
|
||||
elements,
|
||||
)
|
||||
|
||||
|
||||
def rebase_url(url: str, base: str) -> str:
|
||||
base_parts = list(urlparse.urlparse(base))
|
||||
url_parts = list(urlparse.urlparse(url))
|
||||
if not url_parts[0]: # fix up schema
|
||||
url_parts[0] = base_parts[0] or "http"
|
||||
if not url_parts[1]: # fix up hostname
|
||||
url_parts[1] = base_parts[1]
|
||||
if not url_parts[2].startswith("/"):
|
||||
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
|
||||
return urlparse.urlunparse(url_parts)
|
||||
|
||||
|
||||
def summarize_paragraphs(
|
||||
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Try to get a summary respecting first paragraph and then word boundaries.
|
||||
|
||||
Args:
|
||||
text_nodes: The paragraphs to summarize.
|
||||
min_size: The minimum number of words to include.
|
||||
max_size: The maximum number of words to include.
|
||||
|
||||
Returns:
|
||||
A summary of the text nodes, or None if that was not possible.
|
||||
"""
|
||||
|
||||
# TODO: Respect sentences?
|
||||
|
||||
description = ""
|
||||
|
||||
# Keep adding paragraphs until we get to the MIN_SIZE.
|
||||
for text_node in text_nodes:
|
||||
if len(description) < min_size:
|
||||
text_node = re.sub(r"[\t \r\n]+", " ", text_node)
|
||||
description += text_node + "\n\n"
|
||||
else:
|
||||
break
|
||||
|
||||
description = description.strip()
|
||||
description = re.sub(r"[\t ]+", " ", description)
|
||||
description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
|
||||
|
||||
# If the concatenation of paragraphs to get above MIN_SIZE
|
||||
# took us over MAX_SIZE, then we need to truncate mid paragraph
|
||||
if len(description) > max_size:
|
||||
new_desc = ""
|
||||
|
||||
# This splits the paragraph into words, but keeping the
|
||||
# (preceding) whitespace intact so we can easily concat
|
||||
# words back together.
|
||||
for match in re.finditer(r"\s*\S+", description):
|
||||
word = match.group()
|
||||
|
||||
# Keep adding words while the total length is less than
|
||||
# MAX_SIZE.
|
||||
if len(word) + len(new_desc) < max_size:
|
||||
new_desc += word
|
||||
else:
|
||||
# At this point the next word *will* take us over
|
||||
# MAX_SIZE, but we also want to ensure that its not
|
||||
# a huge word. If it is add it anyway and we'll
|
||||
# truncate later.
|
||||
if len(new_desc) < min_size:
|
||||
new_desc += word
|
||||
break
|
||||
|
||||
# Double check that we're not over the limit
|
||||
if len(new_desc) > max_size:
|
||||
new_desc = new_desc[:max_size]
|
||||
|
||||
# We always add an ellipsis because at the very least
|
||||
# we chopped mid paragraph.
|
||||
description = new_desc.strip() + "…"
|
||||
return description if description else None
|
|
@ -12,18 +12,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import codecs
|
||||
import datetime
|
||||
import errno
|
||||
import fnmatch
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Iterable, Optional, Tuple
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import attr
|
||||
|
@ -45,6 +43,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
from synapse.rest.media.v1.oembed import OEmbedProvider
|
||||
from synapse.rest.media.v1.preview_html import (
|
||||
decode_body,
|
||||
parse_html_to_open_graph,
|
||||
rebase_url,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
|
@ -54,21 +57,11 @@ from synapse.util.stringutils import random_string
|
|||
from ._base import FileInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lxml import etree
|
||||
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_charset_match = re.compile(
|
||||
br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
|
||||
)
|
||||
_xml_encoding_match = re.compile(
|
||||
br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
|
||||
)
|
||||
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
||||
|
||||
OG_TAG_NAME_MAXLEN = 50
|
||||
OG_TAG_VALUE_MAXLEN = 1000
|
||||
|
||||
|
@ -311,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
# If there was no oEmbed URL (or oEmbed parsing failed), attempt
|
||||
# to generate the Open Graph information from the HTML.
|
||||
if not oembed_url or not og:
|
||||
og = _calc_og(tree, media_info.uri)
|
||||
og = parse_html_to_open_graph(tree, media_info.uri)
|
||||
|
||||
await self._precache_image_url(user, media_info, og)
|
||||
else:
|
||||
|
@ -468,7 +461,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
image_info = await self._download_url(
|
||||
_rebase_url(og["og:image"], media_info.uri), user
|
||||
rebase_url(og["og:image"], media_info.uri), user
|
||||
)
|
||||
|
||||
if _is_media(image_info.media_type):
|
||||
|
@ -632,301 +625,6 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
logger.debug("No media removed from url cache")
|
||||
|
||||
|
||||
def _normalise_encoding(encoding: str) -> Optional[str]:
|
||||
"""Use the Python codec's name as the normalised entry."""
|
||||
try:
|
||||
return codecs.lookup(encoding).name
|
||||
except LookupError:
|
||||
return None
|
||||
|
||||
|
||||
def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
|
||||
"""
|
||||
Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
|
||||
|
||||
The precedence used for finding a character encoding is:
|
||||
|
||||
1. <meta> tag with a charset declared.
|
||||
2. The XML document's character encoding attribute.
|
||||
3. The Content-Type header.
|
||||
4. Fallback to utf-8.
|
||||
5. Fallback to windows-1252.
|
||||
|
||||
This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
|
||||
|
||||
Args:
|
||||
body: The HTML document, as bytes.
|
||||
content_type: The Content-Type header.
|
||||
|
||||
Returns:
|
||||
The character encoding of the body, as a string.
|
||||
"""
|
||||
# There's no point in returning an encoding more than once.
|
||||
attempted_encodings: Set[str] = set()
|
||||
|
||||
# Limit searches to the first 1kb, since it ought to be at the top.
|
||||
body_start = body[:1024]
|
||||
|
||||
# Check if it has an encoding set in a meta tag.
|
||||
match = _charset_match.search(body_start)
|
||||
if match:
|
||||
encoding = _normalise_encoding(match.group(1).decode("ascii"))
|
||||
if encoding:
|
||||
attempted_encodings.add(encoding)
|
||||
yield encoding
|
||||
|
||||
# TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
|
||||
# Check if it has an XML document with an encoding.
|
||||
match = _xml_encoding_match.match(body_start)
|
||||
if match:
|
||||
encoding = _normalise_encoding(match.group(1).decode("ascii"))
|
||||
if encoding and encoding not in attempted_encodings:
|
||||
attempted_encodings.add(encoding)
|
||||
yield encoding
|
||||
|
||||
# Check the HTTP Content-Type header for a character set.
|
||||
if content_type:
|
||||
content_match = _content_type_match.match(content_type)
|
||||
if content_match:
|
||||
encoding = _normalise_encoding(content_match.group(1))
|
||||
if encoding and encoding not in attempted_encodings:
|
||||
attempted_encodings.add(encoding)
|
||||
yield encoding
|
||||
|
||||
# Finally, fallback to UTF-8, then windows-1252.
|
||||
for fallback in ("utf-8", "cp1252"):
|
||||
if fallback not in attempted_encodings:
|
||||
yield fallback
|
||||
|
||||
|
||||
def decode_body(
|
||||
body: bytes, uri: str, content_type: Optional[str] = None
|
||||
) -> Optional["etree.Element"]:
|
||||
"""
|
||||
This uses lxml to parse the HTML document.
|
||||
|
||||
Args:
|
||||
body: The HTML document, as bytes.
|
||||
uri: The URI used to download the body.
|
||||
content_type: The Content-Type header.
|
||||
|
||||
Returns:
|
||||
The parsed HTML body, or None if an error occurred during processed.
|
||||
"""
|
||||
# If there's no body, nothing useful is going to be found.
|
||||
if not body:
|
||||
return None
|
||||
|
||||
# The idea here is that multiple encodings are tried until one works.
|
||||
# Unfortunately the result is never used and then LXML will decode the string
|
||||
# again with the found encoding.
|
||||
for encoding in get_html_media_encodings(body, content_type):
|
||||
try:
|
||||
body.decode(encoding)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
else:
|
||||
logger.warning("Unable to decode HTML body for %s", uri)
|
||||
return None
|
||||
|
||||
from lxml import etree
|
||||
|
||||
# Create an HTML parser.
|
||||
parser = etree.HTMLParser(recover=True, encoding=encoding)
|
||||
|
||||
# Attempt to parse the body. Returns None if the body was successfully
|
||||
# parsed, but no tree was found.
|
||||
return etree.fromstring(body, parser)
|
||||
|
||||
|
||||
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Calculate metadata for an HTML document.
|
||||
|
||||
This uses lxml to search the HTML document for Open Graph data.
|
||||
|
||||
Args:
|
||||
tree: The parsed HTML document.
|
||||
media_url: The URI used to download the body.
|
||||
|
||||
Returns:
|
||||
The Open Graph response as a dictionary.
|
||||
"""
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
# (although the client could choose to do this by asking for previews of those
|
||||
# URLs to avoid DoSing the server)
|
||||
|
||||
# "og:type" : "video",
|
||||
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
|
||||
# "og:site_name" : "YouTube",
|
||||
# "og:video:type" : "application/x-shockwave-flash",
|
||||
# "og:description" : "Fun stuff happening here",
|
||||
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
|
||||
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
|
||||
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
|
||||
# "og:video:width" : "1280"
|
||||
# "og:video:height" : "720",
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og: Dict[str, Optional[str]] = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if "content" in tag.attrib:
|
||||
# if we've got more than 50 tags, someone is taking the piss
|
||||
if len(og) >= 50:
|
||||
logger.warning("Skipping OG for page with too many 'og:' tags")
|
||||
return {}
|
||||
og[tag.attrib["property"]] = tag.attrib["content"]
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
# "article:publisher" : "https://www.facebook.com/thethudonline" />
|
||||
# "article:author" content="https://www.facebook.com/thethudonline" />
|
||||
# "article:tag" content="baby" />
|
||||
# "article:section" content="Breaking News" />
|
||||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if "og:title" not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
if title and title[0].text is not None:
|
||||
og["og:title"] = title[0].text.strip()
|
||||
else:
|
||||
og["og:title"] = None
|
||||
|
||||
if "og:image" not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og["og:image"] = _rebase_url(meta_image[0], media_uri)
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(
|
||||
images,
|
||||
key=lambda i: (
|
||||
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
|
||||
),
|
||||
)
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og["og:image"] = images[0].attrib["src"]
|
||||
|
||||
if "og:description" not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content"
|
||||
)
|
||||
if meta_description:
|
||||
og["og:description"] = meta_description[0]
|
||||
else:
|
||||
og["og:description"] = _calc_description(tree)
|
||||
elif og["og:description"]:
|
||||
# This must be a non-empty string at this point.
|
||||
assert isinstance(og["og:description"], str)
|
||||
og["og:description"] = summarize_paragraphs([og["og:description"]])
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
return og
|
||||
|
||||
|
||||
def _calc_description(tree: "etree.Element") -> Optional[str]:
|
||||
"""
|
||||
Calculate a text description based on an HTML document.
|
||||
|
||||
Grabs any text nodes which are inside the <body/> tag, unless they are within
|
||||
an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
|
||||
if they are within a <script/> or <style/> tag.
|
||||
|
||||
This is a very very very coarse approximation to a plain text render of the page.
|
||||
|
||||
Args:
|
||||
tree: The parsed HTML document.
|
||||
|
||||
Returns:
|
||||
The plain text description, or None if one cannot be generated.
|
||||
"""
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
from lxml import etree
|
||||
|
||||
TAGS_TO_REMOVE = (
|
||||
"header",
|
||||
"nav",
|
||||
"aside",
|
||||
"footer",
|
||||
"script",
|
||||
"noscript",
|
||||
"style",
|
||||
etree.Comment,
|
||||
)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r"\s+", "\n", el).strip()
|
||||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||
)
|
||||
return summarize_paragraphs(text_nodes)
|
||||
|
||||
|
||||
def _iterate_over_text(
|
||||
tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
|
||||
) -> Generator[str, None, None]:
|
||||
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||
skipping text nodes inside certain tags.
|
||||
"""
|
||||
# This is basically a stack that we extend using itertools.chain.
|
||||
# This will either consist of an element to iterate over *or* a string
|
||||
# to be returned.
|
||||
elements = iter([tree])
|
||||
while True:
|
||||
el = next(elements, None)
|
||||
if el is None:
|
||||
return
|
||||
|
||||
if isinstance(el, str):
|
||||
yield el
|
||||
elif el.tag not in tags_to_ignore:
|
||||
# el.text is the text before the first child, so we can immediately
|
||||
# return it if the text exists.
|
||||
if el.text:
|
||||
yield el.text
|
||||
|
||||
# We add to the stack all the elements children, interspersed with
|
||||
# each child's tail text (if it exists). The tail text of a node
|
||||
# is text that comes *after* the node, so we always include it even
|
||||
# if we ignore the child node.
|
||||
elements = itertools.chain(
|
||||
itertools.chain.from_iterable( # Basically a flatmap
|
||||
[child, child.tail] if child.tail else [child]
|
||||
for child in el.iterchildren()
|
||||
),
|
||||
elements,
|
||||
)
|
||||
|
||||
|
||||
def _rebase_url(url: str, base: str) -> str:
|
||||
base_parts = list(urlparse.urlparse(base))
|
||||
url_parts = list(urlparse.urlparse(url))
|
||||
if not url_parts[0]: # fix up schema
|
||||
url_parts[0] = base_parts[0] or "http"
|
||||
if not url_parts[1]: # fix up hostname
|
||||
url_parts[1] = base_parts[1]
|
||||
if not url_parts[2].startswith("/"):
|
||||
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
|
||||
return urlparse.urlunparse(url_parts)
|
||||
|
||||
|
||||
def _is_media(content_type: str) -> bool:
|
||||
return content_type.lower().startswith("image/")
|
||||
|
||||
|
@ -940,68 +638,3 @@ def _is_html(content_type: str) -> bool:
|
|||
|
||||
def _is_json(content_type: str) -> bool:
|
||||
return content_type.lower().startswith("application/json")
|
||||
|
||||
|
||||
def summarize_paragraphs(
|
||||
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Try to get a summary respecting first paragraph and then word boundaries.
|
||||
|
||||
Args:
|
||||
text_nodes: The paragraphs to summarize.
|
||||
min_size: The minimum number of words to include.
|
||||
max_size: The maximum number of words to include.
|
||||
|
||||
Returns:
|
||||
A summary of the text nodes, or None if that was not possible.
|
||||
"""
|
||||
|
||||
# TODO: Respect sentences?
|
||||
|
||||
description = ""
|
||||
|
||||
# Keep adding paragraphs until we get to the MIN_SIZE.
|
||||
for text_node in text_nodes:
|
||||
if len(description) < min_size:
|
||||
text_node = re.sub(r"[\t \r\n]+", " ", text_node)
|
||||
description += text_node + "\n\n"
|
||||
else:
|
||||
break
|
||||
|
||||
description = description.strip()
|
||||
description = re.sub(r"[\t ]+", " ", description)
|
||||
description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
|
||||
|
||||
# If the concatenation of paragraphs to get above MIN_SIZE
|
||||
# took us over MAX_SIZE, then we need to truncate mid paragraph
|
||||
if len(description) > max_size:
|
||||
new_desc = ""
|
||||
|
||||
# This splits the paragraph into words, but keeping the
|
||||
# (preceding) whitespace intact so we can easily concat
|
||||
# words back together.
|
||||
for match in re.finditer(r"\s*\S+", description):
|
||||
word = match.group()
|
||||
|
||||
# Keep adding words while the total length is less than
|
||||
# MAX_SIZE.
|
||||
if len(word) + len(new_desc) < max_size:
|
||||
new_desc += word
|
||||
else:
|
||||
# At this point the next word *will* take us over
|
||||
# MAX_SIZE, but we also want to ensure that its not
|
||||
# a huge word. If it is add it anyway and we'll
|
||||
# truncate later.
|
||||
if len(new_desc) < min_size:
|
||||
new_desc += word
|
||||
break
|
||||
|
||||
# Double check that we're not over the limit
|
||||
if len(new_desc) > max_size:
|
||||
new_desc = new_desc[:max_size]
|
||||
|
||||
# We always add an ellipsis because at the very least
|
||||
# we chopped mid paragraph.
|
||||
description = new_desc.strip() + "…"
|
||||
return description if description else None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue