Pass the Requester down to the HttpTransactionCache. (#15200)

This commit is contained in:
Quentin Gliech 2023-03-07 17:05:22 +01:00 committed by GitHub
parent 820f02b70b
commit 47bc84dd53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 217 additions and 131 deletions

View file

@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
def on_PUT(
async def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester)
async def _do(
self, request: SynapseRequest, requester: Requester
) -> Tuple[int, JsonDict]:
room_id, _, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
class RoomStateEventRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super().__init__()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
event_type: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict: JsonDict = {
@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Tuple[int, str]:
return 200, "Not implemented"
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, event_type, None)
def on_PUT(
async def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
return await self.txns.fetch_or_execute_request(
request,
requester,
self._do,
request,
requester,
room_id,
event_type,
txn_id,
)
@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_identifier: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request, allow_empty_body=True)
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_identifier, None)
async def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, room_identifier, txn_id
)
# TODO: Needs unit testing
class PublicRoomListRestServlet(TransactionRestServlet):
class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
def on_PUT(
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
return await self._do(requester, room_id)
async def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, requester, room_id
)
@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
membership_action: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
Membership.LEAVE,
@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
membership_action: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, membership_action, None)
async def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
return await self.txns.fetch_or_execute_request(
request,
requester,
self._do,
request,
requester,
room_id,
membership_action,
txn_id,
)
@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
event_id: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
try:
@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, room_id, event_id, None)
async def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, room_id, event_id, txn_id
)
@ -1224,7 +1280,6 @@ def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
with_get: bool = False,
) -> None:
"""Registers a transaction-based path.
@ -1236,7 +1291,6 @@ def register_txn_path(
regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None)
@ -1254,18 +1308,6 @@ def register_txn_path(
on_PUT,
servlet.__class__.__name__,
)
on_GET = getattr(servlet, "on_GET", None)
if with_get:
if on_GET is None:
raise RuntimeError(
"register_txn_path called with with_get = True, but no on_GET method exists"
)
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
on_GET,
servlet.__class__.__name__,
)
class TimestampLookupRestServlet(RestServlet):