mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Pass the Requester down to the HttpTransactionCache. (#15200)
This commit is contained in:
parent
820f02b70b
commit
47bc84dd53
1
changelog.d/15200.misc
Normal file
1
changelog.d/15200.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key.
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import NotFoundError, SynapseError
|
from synapse.api.errors import NotFoundError, SynapseError
|
||||||
@ -23,10 +23,10 @@ from synapse.http.servlet import (
|
|||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.admin import assert_requester_is_admin
|
from synapse.logging.opentracing import set_tag
|
||||||
from synapse.rest.admin._base import admin_patterns
|
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
|
||||||
from synapse.rest.client.transactions import HttpTransactionCache
|
from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet):
|
|||||||
self.__class__.__name__,
|
self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_POST(
|
async def _do(
|
||||||
self, request: SynapseRequest, txn_id: Optional[str] = None
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
requester: Requester,
|
||||||
|
txn_id: Optional[str],
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_user_is_admin(self.auth, requester)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
assert_params_in_dict(body, ("user_id", "content"))
|
assert_params_in_dict(body, ("user_id", "content"))
|
||||||
event_type = body.get("type", EventTypes.Message)
|
event_type = body.get("type", EventTypes.Message)
|
||||||
@ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet):
|
|||||||
|
|
||||||
return HTTPStatus.OK, {"event_id": event.event_id}
|
return HTTPStatus.OK, {"event_id": event.event_id}
|
||||||
|
|
||||||
def on_PUT(
|
async def on_POST(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
return await self._do(request, requester, None)
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, txn_id: str
|
self, request: SynapseRequest, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
return self.txns.fetch_or_execute_request(
|
requester = await self.auth.get_user_by_req(request)
|
||||||
request, self.on_POST, request, txn_id
|
set_tag("txn_id", txn_id)
|
||||||
|
return await self.txns.fetch_or_execute_request(
|
||||||
|
request, requester, self._do, request, requester, txn_id
|
||||||
)
|
)
|
||||||
|
@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||||||
from synapse.rest.client._base import client_patterns
|
from synapse.rest.client._base import client_patterns
|
||||||
from synapse.rest.client.transactions import HttpTransactionCache
|
from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
|
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.cancellation import cancellable
|
from synapse.util.cancellation import cancellable
|
||||||
@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||||||
PATTERNS = "/createRoom"
|
PATTERNS = "/createRoom"
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
def on_PUT(
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, txn_id: str
|
self, request: SynapseRequest, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
|
return await self.txns.fetch_or_execute_request(
|
||||||
|
request, requester, self._do, request, requester
|
||||||
|
)
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
return await self._do(request, requester)
|
||||||
|
|
||||||
|
async def _do(
|
||||||
|
self, request: SynapseRequest, requester: Requester
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
room_id, _, _ = await self._room_creation_handler.create_room(
|
room_id, _, _ = await self._room_creation_handler.create_room(
|
||||||
requester, self.get_room_config(request)
|
requester, self.get_room_config(request)
|
||||||
)
|
)
|
||||||
@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||||||
|
|
||||||
|
|
||||||
# TODO: Needs unit testing for generic events
|
# TODO: Needs unit testing for generic events
|
||||||
class RoomStateEventRestServlet(TransactionRestServlet):
|
class RoomStateEventRestServlet(RestServlet):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.room_member_handler = hs.get_room_member_handler()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
self.message_handler = hs.get_message_handler()
|
self.message_handler = hs.get_message_handler()
|
||||||
@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||||||
def register(self, http_server: HttpServer) -> None:
|
def register(self, http_server: HttpServer) -> None:
|
||||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
||||||
register_txn_path(self, PATTERNS, http_server, with_get=True)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
async def on_POST(
|
async def _do(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
requester: Requester,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_type: str,
|
event_type: str,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str],
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
event_dict: JsonDict = {
|
event_dict: JsonDict = {
|
||||||
@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||||||
set_tag("event_id", event_id)
|
set_tag("event_id", event_id)
|
||||||
return 200, {"event_id": event_id}
|
return 200, {"event_id": event_id}
|
||||||
|
|
||||||
def on_GET(
|
async def on_POST(
|
||||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
self,
|
||||||
) -> Tuple[int, str]:
|
request: SynapseRequest,
|
||||||
return 200, "Not implemented"
|
room_id: str,
|
||||||
|
event_type: str,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
return await self._do(request, requester, room_id, event_type, None)
|
||||||
|
|
||||||
def on_PUT(
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
|
|
||||||
return self.txns.fetch_or_execute_request(
|
return await self.txns.fetch_or_execute_request(
|
||||||
request, self.on_POST, request, room_id, event_type, txn_id
|
request,
|
||||||
|
requester,
|
||||||
|
self._do,
|
||||||
|
request,
|
||||||
|
requester,
|
||||||
|
room_id,
|
||||||
|
event_type,
|
||||||
|
txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
|||||||
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
async def on_POST(
|
async def _do(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
requester: Requester,
|
||||||
room_identifier: str,
|
room_identifier: str,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str],
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
|
||||||
|
|
||||||
content = parse_json_object_from_request(request, allow_empty_body=True)
|
content = parse_json_object_from_request(request, allow_empty_body=True)
|
||||||
|
|
||||||
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
|
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
|
||||||
@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
|||||||
|
|
||||||
return 200, {"room_id": room_id}
|
return 200, {"room_id": room_id}
|
||||||
|
|
||||||
def on_PUT(
|
async def on_POST(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
room_identifier: str,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
return await self._do(request, requester, room_identifier, None)
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, room_identifier: str, txn_id: str
|
self, request: SynapseRequest, room_identifier: str, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
|
|
||||||
return self.txns.fetch_or_execute_request(
|
return await self.txns.fetch_or_execute_request(
|
||||||
request, self.on_POST, request, room_identifier, txn_id
|
request, requester, self._do, request, requester, room_identifier, txn_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Needs unit testing
|
# TODO: Needs unit testing
|
||||||
class PublicRoomListRestServlet(TransactionRestServlet):
|
class PublicRoomListRestServlet(RestServlet):
|
||||||
PATTERNS = client_patterns("/publicRooms$", v1=True)
|
PATTERNS = client_patterns("/publicRooms$", v1=True)
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
async def on_POST(
|
async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
|
||||||
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
|
|
||||||
) -> Tuple[int, JsonDict]:
|
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
|
||||||
|
|
||||||
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
def on_PUT(
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||||
|
return await self._do(requester, room_id)
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, room_id: str, txn_id: str
|
self, request: SynapseRequest, room_id: str, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
|
|
||||||
return self.txns.fetch_or_execute_request(
|
return await self.txns.fetch_or_execute_request(
|
||||||
request, self.on_POST, request, room_id, txn_id
|
request, requester, self._do, requester, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||||||
)
|
)
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
async def on_POST(
|
async def _do(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
requester: Requester,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
membership_action: str,
|
membership_action: str,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str],
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
|
||||||
|
|
||||||
if requester.is_guest and membership_action not in {
|
if requester.is_guest and membership_action not in {
|
||||||
Membership.JOIN,
|
Membership.JOIN,
|
||||||
Membership.LEAVE,
|
Membership.LEAVE,
|
||||||
@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||||||
|
|
||||||
return 200, return_value
|
return 200, return_value
|
||||||
|
|
||||||
def on_PUT(
|
async def on_POST(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
room_id: str,
|
||||||
|
membership_action: str,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
return await self._do(request, requester, room_id, membership_action, None)
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
|
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
|
|
||||||
return self.txns.fetch_or_execute_request(
|
return await self.txns.fetch_or_execute_request(
|
||||||
request, self.on_POST, request, room_id, membership_action, txn_id
|
request,
|
||||||
|
requester,
|
||||||
|
self._do,
|
||||||
|
request,
|
||||||
|
requester,
|
||||||
|
room_id,
|
||||||
|
membership_action,
|
||||||
|
txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
||||||
register_txn_path(self, PATTERNS, http_server)
|
register_txn_path(self, PATTERNS, http_server)
|
||||||
|
|
||||||
async def on_POST(
|
async def _do(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
requester: Requester,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str],
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||||||
set_tag("event_id", event_id)
|
set_tag("event_id", event_id)
|
||||||
return 200, {"event_id": event_id}
|
return 200, {"event_id": event_id}
|
||||||
|
|
||||||
def on_PUT(
|
async def on_POST(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
room_id: str,
|
||||||
|
event_id: str,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
return await self._do(request, requester, room_id, event_id, None)
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
|
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
set_tag("txn_id", txn_id)
|
set_tag("txn_id", txn_id)
|
||||||
|
|
||||||
return self.txns.fetch_or_execute_request(
|
return await self.txns.fetch_or_execute_request(
|
||||||
request, self.on_POST, request, room_id, event_id, txn_id
|
request, requester, self._do, request, requester, room_id, event_id, txn_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1224,7 +1280,6 @@ def register_txn_path(
|
|||||||
servlet: RestServlet,
|
servlet: RestServlet,
|
||||||
regex_string: str,
|
regex_string: str,
|
||||||
http_server: HttpServer,
|
http_server: HttpServer,
|
||||||
with_get: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Registers a transaction-based path.
|
"""Registers a transaction-based path.
|
||||||
|
|
||||||
@ -1236,7 +1291,6 @@ def register_txn_path(
|
|||||||
regex_string: The regex string to register. Must NOT have a
|
regex_string: The regex string to register. Must NOT have a
|
||||||
trailing $ as this string will be appended to.
|
trailing $ as this string will be appended to.
|
||||||
http_server: The http_server to register paths with.
|
http_server: The http_server to register paths with.
|
||||||
with_get: True to also register respective GET paths for the PUTs.
|
|
||||||
"""
|
"""
|
||||||
on_POST = getattr(servlet, "on_POST", None)
|
on_POST = getattr(servlet, "on_POST", None)
|
||||||
on_PUT = getattr(servlet, "on_PUT", None)
|
on_PUT = getattr(servlet, "on_PUT", None)
|
||||||
@ -1254,18 +1308,6 @@ def register_txn_path(
|
|||||||
on_PUT,
|
on_PUT,
|
||||||
servlet.__class__.__name__,
|
servlet.__class__.__name__,
|
||||||
)
|
)
|
||||||
on_GET = getattr(servlet, "on_GET", None)
|
|
||||||
if with_get:
|
|
||||||
if on_GET is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"register_txn_path called with with_get = True, but no on_GET method exists"
|
|
||||||
)
|
|
||||||
http_server.register_paths(
|
|
||||||
"GET",
|
|
||||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
|
||||||
on_GET,
|
|
||||||
servlet.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TimestampLookupRestServlet(RestServlet):
|
class TimestampLookupRestServlet(RestServlet):
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Tuple
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from synapse.http import servlet
|
from synapse.http import servlet
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
@ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_r
|
|||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.opentracing import set_tag
|
from synapse.logging.opentracing import set_tag
|
||||||
from synapse.rest.client.transactions import HttpTransactionCache
|
from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, Requester
|
||||||
|
|
||||||
from ._base import client_patterns
|
from ._base import client_patterns
|
||||||
|
|
||||||
@ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||||||
self.txns = HttpTransactionCache(hs)
|
self.txns = HttpTransactionCache(hs)
|
||||||
self.device_message_handler = hs.get_device_message_handler()
|
self.device_message_handler = hs.get_device_message_handler()
|
||||||
|
|
||||||
def on_PUT(
|
async def on_PUT(
|
||||||
self, request: SynapseRequest, message_type: str, txn_id: str
|
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
|
||||||
set_tag("txn_id", txn_id)
|
|
||||||
return self.txns.fetch_or_execute_request(
|
|
||||||
request, self._put, request, message_type, txn_id
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _put(
|
|
||||||
self, request: SynapseRequest, message_type: str, txn_id: str
|
self, request: SynapseRequest, message_type: str, txn_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
set_tag("txn_id", txn_id)
|
||||||
|
return await self.txns.fetch_or_execute_request(
|
||||||
|
request,
|
||||||
|
requester,
|
||||||
|
self._put,
|
||||||
|
request,
|
||||||
|
requester,
|
||||||
|
message_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _put(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
requester: Requester,
|
||||||
|
message_type: str,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
assert_params_in_dict(content, ("messages",))
|
assert_params_in_dict(content, ("messages",))
|
||||||
|
|
||||||
|
@ -15,16 +15,16 @@
|
|||||||
"""This module contains logic for storing HTTP PUT transactions. This is used
|
"""This module contains logic for storing HTTP PUT transactions. This is used
|
||||||
to ensure idempotency when performing PUTs using the REST API."""
|
to ensure idempotency when performing PUTs using the REST API."""
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple
|
||||||
|
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web.server import Request
|
from twisted.web.iweb import IRequest
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, Requester
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -41,53 +41,47 @@ P = ParamSpec("P")
|
|||||||
class HttpTransactionCache:
|
class HttpTransactionCache:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = self.hs.get_auth()
|
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
||||||
self.transactions: Dict[
|
self.transactions: Dict[
|
||||||
str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
|
Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
|
||||||
] = {}
|
] = {}
|
||||||
# Try to clean entries every 30 mins. This means entries will exist
|
# Try to clean entries every 30 mins. This means entries will exist
|
||||||
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
||||||
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
||||||
|
|
||||||
def _get_transaction_key(self, request: Request) -> str:
|
def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
|
||||||
"""A helper function which returns a transaction key that can be used
|
"""A helper function which returns a transaction key that can be used
|
||||||
with TransactionCache for idempotent requests.
|
with TransactionCache for idempotent requests.
|
||||||
|
|
||||||
Idempotency is based on the returned key being the same for separate
|
Idempotency is based on the returned key being the same for separate
|
||||||
requests to the same endpoint. The key is formed from the HTTP request
|
requests to the same endpoint. The key is formed from the HTTP request
|
||||||
path and the access_token for the requesting user.
|
path and attributes from the requester: the access_token_id for regular users,
|
||||||
|
the user ID for guest users, and the appservice ID for appservice users.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming request. Must contain an access_token.
|
request: The incoming request.
|
||||||
|
requester: The requester doing the request.
|
||||||
Returns:
|
Returns:
|
||||||
A transaction key
|
A transaction key
|
||||||
"""
|
"""
|
||||||
assert request.path is not None
|
assert request.path is not None
|
||||||
token = self.auth.get_access_token_from_request(request)
|
path: str = request.path.decode("utf8")
|
||||||
return request.path.decode("utf8") + "/" + token
|
if requester.is_guest:
|
||||||
|
assert requester.user is not None, "Guest requester must have a user ID set"
|
||||||
|
return (path, "guest", requester.user)
|
||||||
|
elif requester.app_service is not None:
|
||||||
|
return (path, "appservice", requester.app_service.id)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
requester.access_token_id is not None
|
||||||
|
), "Requester must have an access_token_id"
|
||||||
|
return (path, "user", requester.access_token_id)
|
||||||
|
|
||||||
def fetch_or_execute_request(
|
def fetch_or_execute_request(
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: IRequest,
|
||||||
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
requester: Requester,
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
|
||||||
"""A helper function for fetch_or_execute which extracts
|
|
||||||
a transaction key from the given request.
|
|
||||||
|
|
||||||
See:
|
|
||||||
fetch_or_execute
|
|
||||||
"""
|
|
||||||
return self.fetch_or_execute(
|
|
||||||
self._get_transaction_key(request), fn, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def fetch_or_execute(
|
|
||||||
self,
|
|
||||||
txn_key: str,
|
|
||||||
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
**kwargs: P.kwargs,
|
**kwargs: P.kwargs,
|
||||||
@ -96,14 +90,15 @@ class HttpTransactionCache:
|
|||||||
to produce a response for this transaction.
|
to produce a response for this transaction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn_key: A key to ensure idempotency should fetch_or_execute be
|
request:
|
||||||
called again at a later point in time.
|
requester:
|
||||||
fn: A function which returns a tuple of (response_code, response_dict).
|
fn: A function which returns a tuple of (response_code, response_dict).
|
||||||
*args: Arguments to pass to fn.
|
*args: Arguments to pass to fn.
|
||||||
**kwargs: Keyword arguments to pass to fn.
|
**kwargs: Keyword arguments to pass to fn.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred which resolves to a tuple of (response_code, response_dict).
|
Deferred which resolves to a tuple of (response_code, response_dict).
|
||||||
"""
|
"""
|
||||||
|
txn_key = self._get_transaction_key(request, requester)
|
||||||
if txn_key in self.transactions:
|
if txn_key in self.transactions:
|
||||||
observable = self.transactions[txn_key][0]
|
observable = self.transactions[txn_key][0]
|
||||||
else:
|
else:
|
||||||
|
@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||||||
self.cache = HttpTransactionCache(self.hs)
|
self.cache = HttpTransactionCache(self.hs)
|
||||||
|
|
||||||
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
|
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
|
||||||
self.mock_key = "foo"
|
|
||||||
|
# Here we make sure that we're setting all the fields that HttpTransactionCache
|
||||||
|
# uses to build the transaction key.
|
||||||
|
self.mock_request = Mock()
|
||||||
|
self.mock_request.path = b"/foo/bar"
|
||||||
|
self.mock_requester = Mock()
|
||||||
|
self.mock_requester.app_service = None
|
||||||
|
self.mock_requester.is_guest = False
|
||||||
|
self.mock_requester.access_token_id = 1234
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_executes_given_function(
|
def test_executes_given_function(
|
||||||
self,
|
self,
|
||||||
) -> Generator["defer.Deferred[Any]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||||
res = yield self.cache.fetch_or_execute(
|
res = yield self.cache.fetch_or_execute_request(
|
||||||
self.mock_key, cb, "some_arg", keyword="arg"
|
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
|
||||||
)
|
)
|
||||||
cb.assert_called_once_with("some_arg", keyword="arg")
|
cb.assert_called_once_with("some_arg", keyword="arg")
|
||||||
self.assertEqual(res, self.mock_http_response)
|
self.assertEqual(res, self.mock_http_response)
|
||||||
@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||||||
) -> Generator["defer.Deferred[Any]", object, None]:
|
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||||
for i in range(3): # invoke multiple times
|
for i in range(3): # invoke multiple times
|
||||||
res = yield self.cache.fetch_or_execute(
|
res = yield self.cache.fetch_or_execute_request(
|
||||||
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
|
self.mock_request,
|
||||||
|
self.mock_requester,
|
||||||
|
cb,
|
||||||
|
"some_arg",
|
||||||
|
keyword="arg",
|
||||||
|
changing_args=i,
|
||||||
)
|
)
|
||||||
self.assertEqual(res, self.mock_http_response)
|
self.assertEqual(res, self.mock_http_response)
|
||||||
# expect only a single call to do the work
|
# expect only a single call to do the work
|
||||||
@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test() -> Generator["defer.Deferred[Any]", object, None]:
|
def test() -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
with LoggingContext("c") as c1:
|
with LoggingContext("c") as c1:
|
||||||
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
|
res = yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb
|
||||||
|
)
|
||||||
self.assertIs(current_context(), c1)
|
self.assertIs(current_context(), c1)
|
||||||
self.assertEqual(res, (1, {}))
|
self.assertEqual(res, (1, {}))
|
||||||
|
|
||||||
@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
with LoggingContext("test") as test_context:
|
with LoggingContext("test") as test_context:
|
||||||
try:
|
try:
|
||||||
yield self.cache.fetch_or_execute(self.mock_key, cb)
|
yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.assertEqual(e.args[0], "boo")
|
self.assertEqual(e.args[0], "boo")
|
||||||
self.assertIs(current_context(), test_context)
|
self.assertIs(current_context(), test_context)
|
||||||
|
|
||||||
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
|
res = yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb
|
||||||
|
)
|
||||||
self.assertEqual(res, self.mock_http_response)
|
self.assertEqual(res, self.mock_http_response)
|
||||||
self.assertIs(current_context(), test_context)
|
self.assertIs(current_context(), test_context)
|
||||||
|
|
||||||
@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
with LoggingContext("test") as test_context:
|
with LoggingContext("test") as test_context:
|
||||||
try:
|
try:
|
||||||
yield self.cache.fetch_or_execute(self.mock_key, cb)
|
yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.assertEqual(e.args[0], "boo")
|
self.assertEqual(e.args[0], "boo")
|
||||||
self.assertIs(current_context(), test_context)
|
self.assertIs(current_context(), test_context)
|
||||||
|
|
||||||
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
|
res = yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb
|
||||||
|
)
|
||||||
self.assertEqual(res, self.mock_http_response)
|
self.assertEqual(res, self.mock_http_response)
|
||||||
self.assertIs(current_context(), test_context)
|
self.assertIs(current_context(), test_context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
|
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb, "an arg"
|
||||||
|
)
|
||||||
# should NOT have cleaned up yet
|
# should NOT have cleaned up yet
|
||||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
|
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
|
||||||
|
|
||||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb, "an arg"
|
||||||
|
)
|
||||||
# still using cache
|
# still using cache
|
||||||
cb.assert_called_once_with("an arg")
|
cb.assert_called_once_with("an arg")
|
||||||
|
|
||||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
|
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
|
||||||
|
|
||||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
yield self.cache.fetch_or_execute_request(
|
||||||
|
self.mock_request, self.mock_requester, cb, "an arg"
|
||||||
|
)
|
||||||
# no longer using cache
|
# no longer using cache
|
||||||
self.assertEqual(cb.call_count, 2)
|
self.assertEqual(cb.call_count, 2)
|
||||||
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
|
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
|
||||||
|
Loading…
Reference in New Issue
Block a user