mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Reduce duplicate code in receipts servlets. (#13198)
This commit is contained in:
parent
3371e1abcb
commit
1d5c80b161
1
changelog.d/13198.misc
Normal file
1
changelog.d/13198.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor receipts servlet logic to avoid duplicated code.
|
@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet):
|
|||||||
self.read_marker_handler = hs.get_read_marker_handler()
|
self.read_marker_handler = hs.get_read_marker_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
|
self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
|
||||||
|
if hs.config.experimental.msc2285_enabled:
|
||||||
|
self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, request: SynapseRequest, room_id: str
|
self, request: SynapseRequest, room_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet):
|
|||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
valid_receipt_types = {
|
unrecognized_types = set(body.keys()) - self._known_receipt_types
|
||||||
ReceiptTypes.READ,
|
|
||||||
ReceiptTypes.FULLY_READ,
|
|
||||||
ReceiptTypes.READ_PRIVATE,
|
|
||||||
}
|
|
||||||
|
|
||||||
unrecognized_types = set(body.keys()) - valid_receipt_types
|
|
||||||
if unrecognized_types:
|
if unrecognized_types:
|
||||||
# It's fine if there are unrecognized receipt types, but let's log
|
# It's fine if there are unrecognized receipt types, but let's log
|
||||||
# it to help debug clients that have typoed the receipt type.
|
# it to help debug clients that have typoed the receipt type.
|
||||||
@ -65,30 +63,24 @@ class ReadMarkerRestServlet(RestServlet):
|
|||||||
# types.
|
# types.
|
||||||
logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
|
logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
|
||||||
|
|
||||||
read_event_id = body.get(ReceiptTypes.READ, None)
|
for receipt_type in self._known_receipt_types:
|
||||||
if read_event_id:
|
event_id = body.get(receipt_type, None)
|
||||||
await self.receipts_handler.received_client_receipt(
|
# TODO Add validation to reject non-string event IDs.
|
||||||
room_id,
|
if not event_id:
|
||||||
ReceiptTypes.READ,
|
continue
|
||||||
user_id=requester.user.to_string(),
|
|
||||||
event_id=read_event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
|
if receipt_type == ReceiptTypes.FULLY_READ:
|
||||||
if read_private_event_id and self.config.experimental.msc2285_enabled:
|
|
||||||
await self.receipts_handler.received_client_receipt(
|
|
||||||
room_id,
|
|
||||||
ReceiptTypes.READ_PRIVATE,
|
|
||||||
user_id=requester.user.to_string(),
|
|
||||||
event_id=read_private_event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
|
|
||||||
if read_marker_event_id:
|
|
||||||
await self.read_marker_handler.received_client_read_marker(
|
await self.read_marker_handler.received_client_read_marker(
|
||||||
room_id,
|
room_id,
|
||||||
user_id=requester.user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
event_id=read_marker_event_id,
|
event_id=event_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.receipts_handler.received_client_receipt(
|
||||||
|
room_id,
|
||||||
|
receipt_type,
|
||||||
|
user_id=requester.user.to_string(),
|
||||||
|
event_id=event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hs = hs
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.receipts_handler = hs.get_receipts_handler()
|
self.receipts_handler = hs.get_receipts_handler()
|
||||||
self.read_marker_handler = hs.get_read_marker_handler()
|
self.read_marker_handler = hs.get_read_marker_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
|
self._known_receipt_types = {ReceiptTypes.READ}
|
||||||
|
if hs.config.experimental.msc2285_enabled:
|
||||||
|
self._known_receipt_types.update(
|
||||||
|
(ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
|
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
|
if receipt_type not in self._known_receipt_types:
|
||||||
ReceiptTypes.READ,
|
|
||||||
ReceiptTypes.READ_PRIVATE,
|
|
||||||
ReceiptTypes.FULLY_READ,
|
|
||||||
]:
|
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
|
f"Receipt type must be {', '.join(self._known_receipt_types)}",
|
||||||
)
|
)
|
||||||
elif (
|
|
||||||
not self.hs.config.experimental.msc2285_enabled
|
|
||||||
and receipt_type != ReceiptTypes.READ
|
|
||||||
):
|
|
||||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
|
||||||
|
|
||||||
parse_json_object_from_request(request, allow_empty_body=False)
|
parse_json_object_from_request(request, allow_empty_body=False)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user