Reduce duplicate code in receipts servlets. (#13198)

This commit is contained in:
Patrick Cloke 2022-07-13 13:23:16 -04:00 committed by GitHub
parent 3371e1abcb
commit 1d5c80b161
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 43 deletions

1
changelog.d/13198.misc Normal file
View File

@ -0,0 +1 @@
Refactor receipts servlet logic to avoid duplicated code.

View File

@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
if hs.config.experimental.msc2285_enabled:
self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
valid_receipt_types = { unrecognized_types = set(body.keys()) - self._known_receipt_types
ReceiptTypes.READ,
ReceiptTypes.FULLY_READ,
ReceiptTypes.READ_PRIVATE,
}
unrecognized_types = set(body.keys()) - valid_receipt_types
if unrecognized_types: if unrecognized_types:
# It's fine if there are unrecognized receipt types, but let's log # It's fine if there are unrecognized receipt types, but let's log
# it to help debug clients that have typoed the receipt type. # it to help debug clients that have typoed the receipt type.
@ -65,31 +63,25 @@ class ReadMarkerRestServlet(RestServlet):
# types. # types.
logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types) logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
read_event_id = body.get(ReceiptTypes.READ, None) for receipt_type in self._known_receipt_types:
if read_event_id: event_id = body.get(receipt_type, None)
await self.receipts_handler.received_client_receipt( # TODO Add validation to reject non-string event IDs.
room_id, if not event_id:
ReceiptTypes.READ, continue
user_id=requester.user.to_string(),
event_id=read_event_id,
)
read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None) if receipt_type == ReceiptTypes.FULLY_READ:
if read_private_event_id and self.config.experimental.msc2285_enabled: await self.read_marker_handler.received_client_read_marker(
await self.receipts_handler.received_client_receipt( room_id,
room_id, user_id=requester.user.to_string(),
ReceiptTypes.READ_PRIVATE, event_id=event_id,
user_id=requester.user.to_string(), )
event_id=read_private_event_id, else:
) await self.receipts_handler.received_client_receipt(
room_id,
read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None) receipt_type,
if read_marker_event_id: user_id=requester.user.to_string(),
await self.read_marker_handler.received_client_read_marker( event_id=event_id,
room_id, )
user_id=requester.user.to_string(),
event_id=read_marker_event_id,
)
return 200, {} return 200, {}

View File

@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler() self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._known_receipt_types = {ReceiptTypes.READ}
if hs.config.experimental.msc2285_enabled:
self._known_receipt_types.update(
(ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
)
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if self.hs.config.experimental.msc2285_enabled and receipt_type not in [ if receipt_type not in self._known_receipt_types:
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.FULLY_READ,
]:
raise SynapseError( raise SynapseError(
400, 400,
"Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'", f"Receipt type must be {', '.join(self._known_receipt_types)}",
) )
elif (
not self.hs.config.experimental.msc2285_enabled
and receipt_type != ReceiptTypes.READ
):
raise SynapseError(400, "Receipt type must be 'm.read'")
parse_json_object_from_request(request, allow_empty_body=False) parse_json_object_from_request(request, allow_empty_body=False)