Merge remote-tracking branch 'upstream/release-v1.29.0'

This commit is contained in:
Tulir Asokan 2021-03-04 12:49:37 +02:00
commit adb990d8ba
126 changed files with 2399 additions and 765 deletions

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@ -20,8 +21,12 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id, device_id):
async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
)
return 200, device
async def on_DELETE(self, request, user_id, device_id):
async def on_DELETE(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(target_user.to_string(), device_id)
return 200, {}
async def on_PUT(self, request, user_id, device_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)

View file

@ -14,10 +14,16 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, report_id):
async def on_GET(
self, request: SynapseRequest, report_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
message = (
"The report_id parameter must be a string representing a positive integer."
)
try:
report_id = int(report_id)
resolved_report_id = int(report_id)
except ValueError:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
if report_id < 0:
if resolved_report_id < 0:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
ret = await self.store.get_event_report(report_id)
ret = await self.store.get_event_report(resolved_report_id)
if not ret:
raise NotFoundError("Event report not found")

View file

@ -17,7 +17,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer

View file

@ -44,6 +44,48 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ResolveRoomIdMixin:
def __init__(self, hs: "HomeServer"):
self.room_member_handler = hs.get_room_member_handler()
async def resolve_room_id(
self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
) -> Tuple[str, Optional[List[str]]]:
"""
Resolve a room identifier to a room ID, if necessary.
This also performanes checks to ensure the room ID is of the proper form.
Args:
room_identifier: The room ID or alias.
remote_room_hosts: The potential remote room hosts to use.
Returns:
The resolved room ID.
Raises:
SynapseError if the room ID is of the wrong form.
"""
if RoomID.is_valid(room_identifier):
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
(
room_id,
remote_room_hosts,
) = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id, remote_room_hosts
class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet):
return 200, ret
class JoinRoomAliasServlet(RestServlet):
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
@ -362,22 +404,16 @@ class JoinRoomAliasServlet(RestServlet):
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
# Get the room ID from the identifier.
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
room_id, remote_room_hosts = await self.resolve_room_id(
room_identifier, remote_room_hosts
)
fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id}
class MakeRoomAdminRestServlet(RestServlet):
class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get power in a room if a local user has power in
a room. Will also invite the user if they're not in the room and it's a
private room. Can specify another user (rather than the admin user) to be
@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
async def on_POST(self, request, room_identifier):
async def on_POST(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request, allow_empty_body=True)
# Resolve to a room ID, if necessary.
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
room_id, _ = await self.resolve_room_id(room_identifier)
# Which user to grant room admin rights to.
user_to_add = content.get("user_id", requester.user.to_string())
@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet):
return 200, {}
class ForwardExtremitiesRestServlet(RestServlet):
class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get or clear forward extremities.
Clearing does not require restarting the server.
@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.store = hs.get_datastore()
async def resolve_room_id(self, room_identifier: str) -> str:
"""Resolve to a room ID, if necessary."""
if RoomID.is_valid(room_identifier):
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id
async def on_DELETE(self, request, room_identifier):
async def on_DELETE(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier)
room_id, _ = await self.resolve_room_id(room_identifier)
deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
return 200, {"deleted": deleted_count}
async def on_GET(self, request, room_identifier):
async def on_GET(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier)
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities}
@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id):
async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await assert_user_is_admin(self.auth, requester.user)

View file

@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.databases.main.media_repository import MediaSortOrder
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@ -46,13 +47,15 @@ logger = logging.getLogger(__name__)
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, List[JsonDict]]:
target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request)
@ -152,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@ -164,7 +167,9 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@ -178,7 +183,9 @@ class UserRestServletV2(RestServlet):
return 200, ret
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -272,6 +279,8 @@ class UserRestServletV2(RestServlet):
)
user = await self.admin_handler.get_user(target_user)
assert user is not None
return 200, user
else: # create user
@ -329,9 +338,10 @@ class UserRestServletV2(RestServlet):
target_user, requester, body["avatar_url"], True
)
ret = await self.admin_handler.get_user(target_user)
user = await self.admin_handler.get_user(target_user)
assert user is not None
return 201, ret
return 201, user
class UserRegisterServlet(RestServlet):
@ -345,10 +355,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
self.nonces = {}
self.nonces = {} # type: Dict[str, int]
self.hs = hs
def _clear_old_nonces(self):
@ -361,7 +371,7 @@ class UserRegisterServlet(RestServlet):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]
def on_GET(self, request):
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Generate a new nonce.
"""
@ -371,7 +381,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce}
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
@ -477,12 +487,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True)
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
@ -507,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -549,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
@ -583,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
async def on_POST(self, request, target_user_id):
async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
@ -625,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, target_user_id):
async def on_GET(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, Optional[List[JsonDict]]]:
"""Get request to search user table for specific users according to
search term.
This needs user to have a administrator access in Synapse.
@ -681,12 +699,14 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@ -698,7 +718,9 @@ class UserAdminServlet(RestServlet):
return 200, {"admin": is_admin}
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@ -729,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
room_ids = await self.store.get_rooms_for_user(user_id)
@ -757,7 +781,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@ -798,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -832,8 +856,33 @@ class UserMediaRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
# If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
direction = "b"
else:
order_by = parse_string(
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
)
media, total = await self.store.get_local_media_by_user_paginate(
start, limit, user_id
start, limit, user_id, order_by, direction
)
ret = {"media": media, "total": total}
@ -865,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@ -917,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):

View file

@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None:
if not self._public_baseurl:
raise SynapseError(400, "SSO requires a valid public_baseurl")
# if this isn't the expected hostname, redirect to the right one, so that we
# get our cookies back.
requested_uri = get_request_uri(request)
baseurl_bytes = self._public_baseurl.encode("utf-8")
if not requested_uri.startswith(baseurl_bytes):
# swap out the incorrect base URL for the right one.
#
# The idea here is to redirect from
# https://foo.bar/whatever/_matrix/...
# to
# https://public.baseurl/_matrix/...
#
i = requested_uri.index(b"/_matrix")
new_uri = baseurl_bytes[:-1] + requested_uri[i:]
logger.info(
"Requested URI %s is not canonical: redirecting to %s",
requested_uri.decode("utf-8", errors="replace"),
new_uri.decode("utf-8", errors="replace"),
)
request.redirect(new_uri)
finish_request(request)
return
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)

View file

@ -18,7 +18,7 @@ import logging
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.constants import (
MAX_GROUP_CATEGORYID_LENGTH,

View file

@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("messages",))
sender_user_id = requester.user.to_string()
await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"]
requester, message_type, content["messages"]
)
response = (200, {}) # type: Tuple[int, dict]

View file

@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath = request.postpath # type: List[bytes] # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
if isinstance(server_name, bytes):
server_name = server_name.decode("utf-8")
media_id = media_id.decode("utf8")
server_name_bytes, media_id_bytes = postpath[:2]
server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8")
file_name = None
if len(request.postpath) > 2:
if len(postpath) > 2:
try:
file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name

View file

@ -17,7 +17,7 @@
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json

View file

@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean

View file

@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,
@ -509,7 +509,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
url_cache: str,
url_cache: Optional[str],
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)

View file

@ -244,7 +244,7 @@ class MediaStorage:
await consumer.wait()
return local_path
raise Exception("file could not be found")
raise NotFoundError()
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.

View file

@ -29,7 +29,7 @@ from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@ -137,9 +137,8 @@ class PreviewUrlResource(DirectServeJsonResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
user_agent=f"{hs.version_string} UrlPreviewBot"
use_proxy=True,
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path

View file

@ -18,7 +18,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
@ -114,6 +114,7 @@ class ThumbnailResource(DirectServeJsonResource):
m_type,
thumbnail_infos,
media_id,
media_id,
url_cache=media_info["url_cache"],
server_name=None,
)
@ -269,6 +270,7 @@ class ThumbnailResource(DirectServeJsonResource):
method,
m_type,
thumbnail_infos,
media_id,
media_info["filesystem_id"],
url_cache=None,
server_name=server_name,
@ -282,6 +284,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
media_id: str,
file_id: str,
url_cache: Optional[str] = None,
server_name: Optional[str] = None,
@ -316,9 +319,60 @@ class ThumbnailResource(DirectServeJsonResource):
respond_404(request)
return
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
file_info.thumbnail_type,
file_info.thumbnail_length,
)
return
# If we can't find the thumbnail we regenerate it. This can happen
# if e.g. we've deleted the thumbnails but still have the original
# image somewhere.
#
# Since we have an entry for the thumbnail in the DB we a) know we
# have have successfully generated the thumbnail in the past (so we
# don't need to worry about repeatedly failing to generate
# thumbnails), and b) have already calculated that appropriate
# width/height/method so we can just call the "generate exact"
# methods.
# First let's check that we do actually have the original image
# still. This will throw a 404 if we don't.
# TODO: We should refetch the thumbnails for remote media.
await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
if server_name:
await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id=file_id,
media_id=media_id,
t_width=file_info.thumbnail_width,
t_height=file_info.thumbnail_height,
t_method=file_info.thumbnail_method,
t_type=file_info.thumbnail_type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
t_width=file_info.thumbnail_width,
t_height=file_info.thumbnail_height,
t_method=file_info.thumbnail_method,
t_type=file_info.thumbnail_type,
url_cache=url_cache,
)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
request,
responder,
file_info.thumbnail_type,
file_info.thumbnail_length,
)
else:
logger.info("Failed to find any generated thumbnails")

View file

@ -15,9 +15,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import IO, TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource):
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
content_type_headers = headers.getRawHeaders(b"Content-Type")
assert content_type_headers # for mypy
media_type = content_type_headers[0].decode("ascii")
else:
raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
@ -88,8 +90,9 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-dispostion
try:
content = request.content # type: IO # type: ignore
content_uri = await self.media_repo.create_content(
media_type, upload_name, request.content, content_length, requester.user
media_type, upload_name, content, content_length, requester.user
)
except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of

View file

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View file

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour

View file

@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, List
from twisted.web.http import Request
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View file

@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request