Merge remote-tracking branch 'upstream/release-v1.50'

This commit is contained in:
Tulir Asokan 2022-01-07 14:21:32 +02:00
commit e9caf56ca0
205 changed files with 4905 additions and 2749 deletions

View file

@ -69,6 +69,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
from synapse.rest.admin.username_available import UsernameAvailableRestServlet
from synapse.rest.admin.users import (
AccountDataRestServlet,
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
@ -108,7 +109,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
)
def __init__(self, hs: "HomeServer"):
@ -195,7 +196,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler()
@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
AccountDataRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)

View file

@ -22,7 +22,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
class BackgroundUpdateStartJobRestServlet(RestServlet):
"""Allows to start specific background updates"""
PATTERNS = admin_patterns("/background_updates/start_job")
PATTERNS = admin_patterns("/background_updates/start_job$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user)
await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["job_name"])

View file

@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str, device_id: str
@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
device = await self.device_handler.get_device(
target_user.to_string(), device_id
)
if device is None:
raise NotFoundError("No device found")
return HTTPStatus.OK, device
async def on_DELETE(
@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs: "HomeServer"):
"""
Args:
hs: server
"""
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, user_id: str
@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())

View file

@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()

View file

@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
200 OK with details of a destination if success otherwise an error.
"""
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()

View file

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups"""
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()

View file

@ -17,7 +17,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = [
*admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
*admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
# This path kept around for legacy reasons
*admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
*admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
]
def __init__(self, hs: "HomeServer"):
@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
"/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
"/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
"/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
logging.info(
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined."""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
logging.info("Protecting local media by ID: %s", media_id)
@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
class UnprotectMediaByID(RestServlet):
"""Unprotect local media from being quarantined."""
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
logging.info("Unprotecting local media by ID: %s", media_id)
@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room."""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
await assert_requester_is_admin(self.auth, request)
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
class DeleteMediaByID(RestServlet):
"""Delete local media by a given ID. Removes it from this server."""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size.
"""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
media that exist given for this user
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")

View file

@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/new$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.store = hs.get_datastore()

View file

@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
# Extract query parameters
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
if order_by not in (
RoomSortOrder.ALPHABETICAL.value,
RoomSortOrder.SIZE.value,
RoomSortOrder.NAME.value,
RoomSortOrder.CANONICAL_ALIAS.value,
RoomSortOrder.JOINED_MEMBERS.value,
RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
RoomSortOrder.VERSION.value,
RoomSortOrder.CREATOR.value,
RoomSortOrder.ENCRYPTION.value,
RoomSortOrder.FEDERATABLE.value,
RoomSortOrder.PUBLIC.value,
RoomSortOrder.JOIN_RULES.value,
RoomSortOrder.GUEST_ACCESS.value,
RoomSortOrder.HISTORY_VISIBILITY.value,
RoomSortOrder.STATE_EVENTS.value,
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown value for order_by: %s" % (order_by,),
errcode=Codes.INVALID_PARAM,
)
order_by = parse_string(
request,
"order_by",
default=RoomSortOrder.NAME.value,
allowed_values=[sort_order.value for sort_order in RoomSortOrder],
)
search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
TODO: Add on_POST to allow room creation without joining the room
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
Get members list of a room.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
Get full state within a room.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, room_identifier: str
@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
assert_params_in_dict(content, ["user_id"])
target_user = UserID.from_string(content["user_id"])
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"This endpoint can only be used with local users",
@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
}
"""
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_creation_handler = hs.get_event_creation_handler()
@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
"""
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_DELETE(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
room_id, _ = await self.resolve_room_id(room_identifier)
@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
async def on_GET(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
await assert_requester_is_admin(self.auth, request)
room_id, _ = await self.resolve_room_id(room_identifier)
@ -771,13 +745,19 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"], time_now
results["events_before"],
time_now,
bundle_aggregations=True,
)
results["event"] = await self._event_serializer.serialize_event(
results["event"], time_now
results["event"],
time_now,
bundle_aggregations=True,
)
results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"], time_now
results["events_after"],
time_now,
bundle_aggregations=True,
)
results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
@ -793,7 +773,7 @@ class BlockRoomRestServlet(RestServlet):
On GET: Get blocking status of room and user who has blocked this room.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()

View file

@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.server_notices_manager = hs.get_server_notices_manager()
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
self.is_mine = hs.is_mine
def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
)
target_user = UserID.from_string(body["user_id"])
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
)

View file

@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
PATTERNS = admin_patterns("/statistics/users/media$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
order_by = parse_string(
request, "order_by", default=UserSortOrder.USER_ID.value
request,
"order_by",
default=UserSortOrder.USER_ID.value,
allowed_values=(
UserSortOrder.MEDIA_LENGTH.value,
UserSortOrder.MEDIA_COUNT.value,
UserSortOrder.USER_ID.value,
UserSortOrder.DISPLAYNAME.value,
),
)
if order_by not in (
UserSortOrder.MEDIA_LENGTH.value,
UserSortOrder.MEDIA_COUNT.value,
UserSortOrder.USER_ID.value,
UserSortOrder.DISPLAYNAME.value,
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown value for order_by: %s" % (order_by,),
errcode=Codes.INVALID_PARAM,
)
start = parse_integer(request, "from", default=0)
if start < 0:

View file

@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
}
"""
PATTERNS = admin_patterns("/username_available")
PATTERNS = admin_patterns("/username_available$")
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()

View file

@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
PATTERNS = admin_patterns("/register")
PATTERNS = admin_patterns("/register$")
NONCE_TIMEOUT = 60
def __init__(self, hs: "HomeServer"):
@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
]
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
if target_user != auth_user:
await assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
ret = await self.admin_handler.get_whois(target_user)
@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, target_user_id: str
@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
# if not is_admin and target_user != auth_user:
# raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
term = parse_string(request, "term", required=True)
@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver",
@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
assert_params_in_dict(body, ["admin"])
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver",
@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
Get room list of an user.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.is_mine_id = hs.is_mine_id
async def on_POST(
self, request: SynapseRequest, user_id: str
@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
if not self.hs.is_mine_id(user_id):
if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
)
@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
{}
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
)
@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
)
@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
}
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
if not self.is_mine_id(user_id):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
)
@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
)
@ -1124,3 +1121,33 @@ class RateLimitRestServlet(RestServlet):
await self.store.delete_ratelimit_for_user(user_id)
return HTTPStatus.OK, {}
class AccountDataRestServlet(RestServlet):
"""Retrieve the given user's account data"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/accountdata")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
if not self._is_mine_id(user_id):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self._store.get_user_by_id(user_id):
raise NotFoundError("User not found")
global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
return HTTPStatus.OK, {
"account_data": {
"global": global_data,
"rooms": by_room_data,
},
}

View file

@ -73,6 +73,9 @@ class CapabilitiesRestServlet(RestServlet):
"enabled": self.config.registration.enable_3pid_changes
}
if self.config.experimental.msc3440_enabled:
response["capabilities"]["io.element.thread"] = {"enabled": True}
return 200, response

View file

@ -17,6 +17,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@ -24,10 +25,9 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet):
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
if device is None:
raise NotFoundError("No device found")
return 200, device
@interactive_auth_handler

View file

@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@ -54,10 +55,10 @@ class NotificationsServlet(RestServlet):
)
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read"
user_id, ReceiptTypes.READ
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
notif_event_ids = [pa.event_id for pa in push_actions]
notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = []
@ -66,30 +67,30 @@ class NotificationsServlet(RestServlet):
for pa in push_actions:
returned_pa = {
"room_id": pa["room_id"],
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
"room_id": pa.room_id,
"profile_tag": pa.profile_tag,
"actions": pa.actions,
"ts": pa.received_ts,
"event": (
await self._event_serializer.serialize_event(
notif_events[pa["event_id"]],
notif_events[pa.event_id],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
)
),
}
if pa["room_id"] not in receipts_by_room:
if pa.room_id not in receipts_by_room:
returned_pa["read"] = False
else:
receipt = receipts_by_room[pa["room_id"]]
receipt = receipts_by_room[pa.room_id]
returned_pa["read"] = (
receipt["topological_ordering"],
receipt["stream_ordering"],
) >= (pa["topological_ordering"], pa["stream_ordering"])
) >= (pa.topological_ordering, pa.stream_ordering)
returned_push_actions.append(returned_pa)
next_token = str(pa["stream_ordering"])
next_token = str(pa.stream_ordering)
return 200, {"notifications": returned_push_actions, "next_token": next_token}

View file

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
read_event_id = body.get(ReceiptTypes.READ, None)
read_extra = body.get("com.beeper.read.extra", None)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
@ -63,7 +63,7 @@ class ReadMarkerRestServlet(RestServlet):
if read_event_id:
await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
ReceiptTypes.READ,
user_id=requester.user.to_string(),
event_id=read_event_id,
hidden=hidden,

View file

@ -16,7 +16,7 @@ import logging
import re
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
if receipt_type != ReceiptTypes.READ:
raise SynapseError(400, "Receipt type must be 'm.read'")
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.

View file

@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
@ -231,7 +232,9 @@ class RelationPaginationServlet(RestServlet):
)
# The relations returned for the requested event do include their
# bundled aggregations.
serialized_events = await self._event_serializer.serialize_events(events, now)
serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=True
)
return_value = pagination_chunk.to_dict()
return_value["chunk"] = serialized_events
@ -317,6 +320,7 @@ class RelationAggregationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
room_id=room_id,
event_type=event_type,
limit=limit,
from_token=from_token,
@ -383,7 +387,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
await self.event_handler.get_event(requester.user, room_id, parent_id)
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@ -402,6 +408,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
result = await self.store.get_relations_for_event(
event_id=parent_id,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,

View file

@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
state_key: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if txn_id:
set_tag("txn_id", txn_id)
@ -666,7 +666,9 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
event_dict = await self._event_serializer.serialize_event(event, time_now)
event_dict = await self._event_serializer.serialize_event(
event, time_now, bundle_aggregations=True
)
return 200, event_dict
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@ -711,13 +713,13 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"], time_now
results["events_before"], time_now, bundle_aggregations=True
)
results["event"] = await self._event_serializer.serialize_event(
results["event"], time_now
results["event"], time_now, bundle_aggregations=True
)
results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"], time_now
results["events_after"], time_now, bundle_aggregations=True
)
results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now

View file

@ -48,6 +48,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@ -222,6 +223,7 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
@trace(opname="sync.encode_response")
async def encode_response(
self,
time_now: int,
@ -293,6 +295,9 @@ class SyncRestServlet(RestServlet):
response[
"org.matrix.msc2732.device_unused_fallback_key_types"
] = sync_result.device_unused_fallback_key_types
response[
"device_unused_fallback_key_types"
] = sync_result.device_unused_fallback_key_types
if joined:
response["rooms"][Membership.JOIN] = joined
@ -329,6 +334,7 @@ class SyncRestServlet(RestServlet):
]
}
@trace(opname="sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
@ -365,6 +371,7 @@ class SyncRestServlet(RestServlet):
return joined
@trace(opname="sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
@ -403,6 +410,7 @@ class SyncRestServlet(RestServlet):
return invited
@trace(opname="sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
@ -457,6 +465,7 @@ class SyncRestServlet(RestServlet):
return knocked
@trace(opname="sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],

View file

@ -93,6 +93,10 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving hidden read receipts as per MSC2285
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
# Adds support for importing historical messages as per MSC2716
"org.matrix.msc2716": self.config.experimental.msc2716_enabled,
# Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
"org.matrix.msc3030": self.config.experimental.msc3030_enabled,
},
},
)

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@ -99,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
def render_GET(self, request: Request) -> int:
def render_GET(self, request: Request) -> Optional[int]:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:

View file

@ -739,14 +739,21 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails: Dict[Tuple[int, int, str], str] = {}
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
elif r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
for requirement in requirements:
if requirement.method == "crop":
thumbnails.setdefault(
(requirement.width, requirement.height, requirement.media_type),
requirement.method,
)
elif requirement.method == "scale":
t_width, t_height = thumbnailer.aspect(
requirement.width, requirement.height
)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
thumbnails[(t_width, t_height, r_type)] = r_method
thumbnails[
(t_width, t_height, requirement.media_type)
] = requirement.method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items():

View file

@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional
import attr
from synapse.rest.media.v1.preview_html import parse_html_description
from synapse.types import JsonDict
from synapse.util import json_decoder
@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
if video_urls:
open_graph_response["og:video"] = video_urls[0]
from synapse.rest.media.v1.preview_url_resource import _calc_description
description = _calc_description(tree)
description = parse_html_description(tree)
if description:
open_graph_response["og:description"] = description

View file

@ -0,0 +1,397 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
from urllib import parse as urlparse
if TYPE_CHECKING:
from lxml import etree
logger = logging.getLogger(__name__)
_charset_match = re.compile(
br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
)
_xml_encoding_match = re.compile(
br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
def _normalise_encoding(encoding: str) -> Optional[str]:
"""Use the Python codec's name as the normalised entry."""
try:
return codecs.lookup(encoding).name
except LookupError:
return None
def _get_html_media_encodings(
body: bytes, content_type: Optional[str]
) -> Iterable[str]:
"""
Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
The precedence used for finding a character encoding is:
1. <meta> tag with a charset declared.
2. The XML document's character encoding attribute.
3. The Content-Type header.
4. Fallback to utf-8.
5. Fallback to windows-1252.
This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
Args:
body: The HTML document, as bytes.
content_type: The Content-Type header.
Returns:
The character encoding of the body, as a string.
"""
# There's no point in returning an encoding more than once.
attempted_encodings: Set[str] = set()
# Limit searches to the first 1kb, since it ought to be at the top.
body_start = body[:1024]
# Check if it has an encoding set in a meta tag.
match = _charset_match.search(body_start)
if match:
encoding = _normalise_encoding(match.group(1).decode("ascii"))
if encoding:
attempted_encodings.add(encoding)
yield encoding
# TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
# Check if it has an XML document with an encoding.
match = _xml_encoding_match.match(body_start)
if match:
encoding = _normalise_encoding(match.group(1).decode("ascii"))
if encoding and encoding not in attempted_encodings:
attempted_encodings.add(encoding)
yield encoding
# Check the HTTP Content-Type header for a character set.
if content_type:
content_match = _content_type_match.match(content_type)
if content_match:
encoding = _normalise_encoding(content_match.group(1))
if encoding and encoding not in attempted_encodings:
attempted_encodings.add(encoding)
yield encoding
# Finally, fallback to UTF-8, then windows-1252.
for fallback in ("utf-8", "cp1252"):
if fallback not in attempted_encodings:
yield fallback
def decode_body(
body: bytes, uri: str, content_type: Optional[str] = None
) -> Optional["etree.Element"]:
"""
This uses lxml to parse the HTML document.
Args:
body: The HTML document, as bytes.
uri: The URI used to download the body.
content_type: The Content-Type header.
Returns:
The parsed HTML body, or None if an error occurred during processed.
"""
# If there's no body, nothing useful is going to be found.
if not body:
return None
# The idea here is that multiple encodings are tried until one works.
# Unfortunately the result is never used and then LXML will decode the string
# again with the found encoding.
for encoding in _get_html_media_encodings(body, content_type):
try:
body.decode(encoding)
except Exception:
pass
else:
break
else:
logger.warning("Unable to decode HTML body for %s", uri)
return None
from lxml import etree
# Create an HTML parser.
parser = etree.HTMLParser(recover=True, encoding=encoding)
# Attempt to parse the body. Returns None if the body was successfully
# parsed, but no tree was found.
return etree.fromstring(body, parser)
def parse_html_to_open_graph(
tree: "etree.Element", media_uri: str
) -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
This uses lxml to search the HTML document for Open Graph data (or
synthesizes it from the document).
Args:
tree: The parsed HTML document.
media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
"""
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss
if len(og) >= 50:
logger.warning("Skipping OG for page with too many 'og:' tags")
return {}
og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if "og:title" not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
if title and title[0].text is not None:
og["og:title"] = title[0].text.strip()
else:
og["og:title"] = None
if "og:image" not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og["og:image"] = rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(
images,
key=lambda i: (
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
),
)
if not images:
images = tree.xpath("//img[@src]")
if images:
og["og:image"] = images[0].attrib["src"]
if "og:description" not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content"
)
if meta_description:
og["og:description"] = meta_description[0]
else:
og["og:description"] = parse_html_description(tree)
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def parse_html_description(tree: "etree.Element") -> Optional[str]:
"""
Calculate a text description based on an HTML document.
Grabs any text nodes which are inside the <body/> tag, unless they are within
an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
if they are within a <script/> or <style/> tag.
This is a very very very coarse approximation to a plain text render of the page.
Args:
tree: The parsed HTML document.
Returns:
The plain text description, or None if one cannot be generated.
"""
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = (
"header",
"nav",
"aside",
"footer",
"script",
"noscript",
"style",
etree.Comment,
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r"\s+", "\n", el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
return summarize_paragraphs(text_nodes)
def _iterate_over_text(
tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
# This is basically a stack that we extend using itertools.chain.
# This will either consist of an element to iterate over *or* a string
# to be returned.
elements = iter([tree])
while True:
el = next(elements, None)
if el is None:
return
if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the elements children, interspersed with
# each child's tail text (if it exists). The tail text of a node
# is text that comes *after* the node, so we always include it even
# if we ignore the child node.
elements = itertools.chain(
itertools.chain.from_iterable( # Basically a flatmap
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
elements,
)
def rebase_url(url: str, base: str) -> str:
base_parts = list(urlparse.urlparse(base))
url_parts = list(urlparse.urlparse(url))
if not url_parts[0]: # fix up schema
url_parts[0] = base_parts[0] or "http"
if not url_parts[1]: # fix up hostname
url_parts[1] = base_parts[1]
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url_parts)
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
"""
Try to get a summary respecting first paragraph and then word boundaries.
Args:
text_nodes: The paragraphs to summarize.
min_size: The minimum number of words to include.
max_size: The maximum number of words to include.
Returns:
A summary of the text nodes, or None if that was not possible.
"""
# TODO: Respect sentences?
description = ""
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
text_node = re.sub(r"[\t \r\n]+", " ", text_node)
description += text_node + "\n\n"
else:
break
description = description.strip()
description = re.sub(r"[\t ]+", " ", description)
description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
if len(description) > max_size:
new_desc = ""
# This splits the paragraph into words, but keeping the
# (preceding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer(r"\s*\S+", description):
word = match.group()
# Keep adding words while the total length is less than
# MAX_SIZE.
if len(word) + len(new_desc) < max_size:
new_desc += word
else:
# At this point the next word *will* take us over
# MAX_SIZE, but we also want to ensure that its not
# a huge word. If it is add it anyway and we'll
# truncate later.
if len(new_desc) < min_size:
new_desc += word
break
# Double check that we're not over the limit
if len(new_desc) > max_size:
new_desc = new_desc[:max_size]
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
description = new_desc.strip() + ""
return description if description else None

View file

@ -12,18 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import datetime
import errno
import fnmatch
import itertools
import logging
import os
import re
import shutil
import sys
import traceback
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from urllib import parse as urlparse
import attr
@ -45,6 +43,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.rest.media.v1.preview_html import (
decode_body,
parse_html_to_open_graph,
rebase_url,
)
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@ -54,21 +57,11 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
if TYPE_CHECKING:
from lxml import etree
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
_charset_match = re.compile(
br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
)
_xml_encoding_match = re.compile(
br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
@ -311,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# If there was no oEmbed URL (or oEmbed parsing failed), attempt
# to generate the Open Graph information from the HTML.
if not oembed_url or not og:
og = _calc_og(tree, media_info.uri)
og = parse_html_to_open_graph(tree, media_info.uri)
await self._precache_image_url(user, media_info, og)
else:
@ -468,7 +461,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
image_info = await self._download_url(
_rebase_url(og["og:image"], media_info.uri), user
rebase_url(og["og:image"], media_info.uri), user
)
if _is_media(image_info.media_type):
@ -632,301 +625,6 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
def _normalise_encoding(encoding: str) -> Optional[str]:
"""Use the Python codec's name as the normalised entry."""
try:
return codecs.lookup(encoding).name
except LookupError:
return None
def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
"""
Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
The precedence used for finding a character encoding is:
1. <meta> tag with a charset declared.
2. The XML document's character encoding attribute.
3. The Content-Type header.
4. Fallback to utf-8.
5. Fallback to windows-1252.
This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
Args:
body: The HTML document, as bytes.
content_type: The Content-Type header.
Returns:
The character encoding of the body, as a string.
"""
# There's no point in returning an encoding more than once.
attempted_encodings: Set[str] = set()
# Limit searches to the first 1kb, since it ought to be at the top.
body_start = body[:1024]
# Check if it has an encoding set in a meta tag.
match = _charset_match.search(body_start)
if match:
encoding = _normalise_encoding(match.group(1).decode("ascii"))
if encoding:
attempted_encodings.add(encoding)
yield encoding
# TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
# Check if it has an XML document with an encoding.
match = _xml_encoding_match.match(body_start)
if match:
encoding = _normalise_encoding(match.group(1).decode("ascii"))
if encoding and encoding not in attempted_encodings:
attempted_encodings.add(encoding)
yield encoding
# Check the HTTP Content-Type header for a character set.
if content_type:
content_match = _content_type_match.match(content_type)
if content_match:
encoding = _normalise_encoding(content_match.group(1))
if encoding and encoding not in attempted_encodings:
attempted_encodings.add(encoding)
yield encoding
# Finally, fallback to UTF-8, then windows-1252.
for fallback in ("utf-8", "cp1252"):
if fallback not in attempted_encodings:
yield fallback
def decode_body(
body: bytes, uri: str, content_type: Optional[str] = None
) -> Optional["etree.Element"]:
"""
This uses lxml to parse the HTML document.
Args:
body: The HTML document, as bytes.
uri: The URI used to download the body.
content_type: The Content-Type header.
Returns:
The parsed HTML body, or None if an error occurred during processed.
"""
# If there's no body, nothing useful is going to be found.
if not body:
return None
# The idea here is that multiple encodings are tried until one works.
# Unfortunately the result is never used and then LXML will decode the string
# again with the found encoding.
for encoding in get_html_media_encodings(body, content_type):
try:
body.decode(encoding)
except Exception:
pass
else:
break
else:
logger.warning("Unable to decode HTML body for %s", uri)
return None
from lxml import etree
# Create an HTML parser.
parser = etree.HTMLParser(recover=True, encoding=encoding)
# Attempt to parse the body. Returns None if the body was successfully
# parsed, but no tree was found.
return etree.fromstring(body, parser)
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
"""
Calculate metadata for an HTML document.
This uses lxml to search the HTML document for Open Graph data.
Args:
tree: The parsed HTML document.
media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
"""
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss
if len(og) >= 50:
logger.warning("Skipping OG for page with too many 'og:' tags")
return {}
og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if "og:title" not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
if title and title[0].text is not None:
og["og:title"] = title[0].text.strip()
else:
og["og:title"] = None
if "og:image" not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og["og:image"] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(
images,
key=lambda i: (
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
),
)
if not images:
images = tree.xpath("//img[@src]")
if images:
og["og:image"] = images[0].attrib["src"]
if "og:description" not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content"
)
if meta_description:
og["og:description"] = meta_description[0]
else:
og["og:description"] = _calc_description(tree)
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def _calc_description(tree: "etree.Element") -> Optional[str]:
"""
Calculate a text description based on an HTML document.
Grabs any text nodes which are inside the <body/> tag, unless they are within
an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
if they are within a <script/> or <style/> tag.
This is a very very very coarse approximation to a plain text render of the page.
Args:
tree: The parsed HTML document.
Returns:
The plain text description, or None if one cannot be generated.
"""
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = (
"header",
"nav",
"aside",
"footer",
"script",
"noscript",
"style",
etree.Comment,
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r"\s+", "\n", el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
return summarize_paragraphs(text_nodes)
def _iterate_over_text(
tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
# This is basically a stack that we extend using itertools.chain.
# This will either consist of an element to iterate over *or* a string
# to be returned.
elements = iter([tree])
while True:
el = next(elements, None)
if el is None:
return
if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the elements children, interspersed with
# each child's tail text (if it exists). The tail text of a node
# is text that comes *after* the node, so we always include it even
# if we ignore the child node.
elements = itertools.chain(
itertools.chain.from_iterable( # Basically a flatmap
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
elements,
)
def _rebase_url(url: str, base: str) -> str:
base_parts = list(urlparse.urlparse(base))
url_parts = list(urlparse.urlparse(url))
if not url_parts[0]: # fix up schema
url_parts[0] = base_parts[0] or "http"
if not url_parts[1]: # fix up hostname
url_parts[1] = base_parts[1]
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url_parts)
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")
@ -940,68 +638,3 @@ def _is_html(content_type: str) -> bool:
def _is_json(content_type: str) -> bool:
return content_type.lower().startswith("application/json")
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
"""
Try to get a summary respecting first paragraph and then word boundaries.
Args:
text_nodes: The paragraphs to summarize.
min_size: The minimum number of words to include.
max_size: The maximum number of words to include.
Returns:
A summary of the text nodes, or None if that was not possible.
"""
# TODO: Respect sentences?
description = ""
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
text_node = re.sub(r"[\t \r\n]+", " ", text_node)
description += text_node + "\n\n"
else:
break
description = description.strip()
description = re.sub(r"[\t ]+", " ", description)
description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
if len(description) > max_size:
new_desc = ""
# This splits the paragraph into words, but keeping the
# (preceding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer(r"\s*\S+", description):
word = match.group()
# Keep adding words while the total length is less than
# MAX_SIZE.
if len(word) + len(new_desc) < max_size:
new_desc += word
else:
# At this point the next word *will* take us over
# MAX_SIZE, but we also want to ensure that its not
# a huge word. If it is add it anyway and we'll
# truncate later.
if len(new_desc) < min_size:
new_desc += word
break
# Double check that we're not over the limit
if len(new_desc) > max_size:
new_desc = new_desc[:max_size]
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
description = new_desc.strip() + ""
return description if description else None