mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-07-22 15:40:36 -04:00
Pass the Requester down to the HttpTransactionCache. (#15200)
This commit is contained in:
parent
820f02b70b
commit
47bc84dd53
6 changed files with 217 additions and 131 deletions
|
@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.cancellation import cancellable
|
||||
|
@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/createRoom"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
def on_PUT(
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
set_tag("txn_id", txn_id)
|
||||
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester
|
||||
)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
return await self._do(request, requester)
|
||||
|
||||
async def _do(
|
||||
self, request: SynapseRequest, requester: Requester
|
||||
) -> Tuple[int, JsonDict]:
|
||||
room_id, _, _ = await self._room_creation_handler.create_room(
|
||||
requester, self.get_room_config(request)
|
||||
)
|
||||
|
@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
|
||||
|
||||
# TODO: Needs unit testing for generic events
|
||||
class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
class RoomStateEventRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
super().__init__()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.message_handler = hs.get_message_handler()
|
||||
|
@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
def register(self, http_server: HttpServer) -> None:
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server, with_get=True)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event_dict: JsonDict = {
|
||||
|
@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_GET(
|
||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
||||
) -> Tuple[int, str]:
|
||||
return 200, "Not implemented"
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_id, event_type, None)
|
||||
|
||||
def on_PUT(
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, event_type, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._do,
|
||||
request,
|
||||
requester,
|
||||
room_id,
|
||||
event_type,
|
||||
txn_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
|||
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_identifier: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request, allow_empty_body=True)
|
||||
|
||||
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
|
||||
|
@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
|||
|
||||
return 200, {"room_id": room_id}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_identifier: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_identifier, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_identifier: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_identifier, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester, room_identifier, txn_id
|
||||
)
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class PublicRoomListRestServlet(TransactionRestServlet):
|
||||
class PublicRoomListRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/publicRooms$", v1=True)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
|
@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
|
||||
async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
|
||||
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||
|
||||
return 200, {}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
return await self._do(requester, room_id)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, requester, room_id
|
||||
)
|
||||
|
||||
|
||||
|
@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
membership_action: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
Membership.LEAVE,
|
||||
|
@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
|
||||
return 200, return_value
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
membership_action: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_id, membership_action, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, membership_action, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._do,
|
||||
request,
|
||||
requester,
|
||||
room_id,
|
||||
membership_action,
|
||||
txn_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
|
@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
return await self._do(request, requester, room_id, event_id, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, event_id, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester, room_id, event_id, txn_id
|
||||
)
|
||||
|
||||
|
||||
|
@ -1224,7 +1280,6 @@ def register_txn_path(
|
|||
servlet: RestServlet,
|
||||
regex_string: str,
|
||||
http_server: HttpServer,
|
||||
with_get: bool = False,
|
||||
) -> None:
|
||||
"""Registers a transaction-based path.
|
||||
|
||||
|
@ -1236,7 +1291,6 @@ def register_txn_path(
|
|||
regex_string: The regex string to register. Must NOT have a
|
||||
trailing $ as this string will be appended to.
|
||||
http_server: The http_server to register paths with.
|
||||
with_get: True to also register respective GET paths for the PUTs.
|
||||
"""
|
||||
on_POST = getattr(servlet, "on_POST", None)
|
||||
on_PUT = getattr(servlet, "on_PUT", None)
|
||||
|
@ -1254,18 +1308,6 @@ def register_txn_path(
|
|||
on_PUT,
|
||||
servlet.__class__.__name__,
|
||||
)
|
||||
on_GET = getattr(servlet, "on_GET", None)
|
||||
if with_get:
|
||||
if on_GET is None:
|
||||
raise RuntimeError(
|
||||
"register_txn_path called with with_get = True, but no on_GET method exists"
|
||||
)
|
||||
http_server.register_paths(
|
||||
"GET",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
on_GET,
|
||||
servlet.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class TimestampLookupRestServlet(RestServlet):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue