Add a constant for receipt types (m.read). (#11531)

And expand some type hints in the receipts storage module.
This commit is contained in:
Patrick Cloke 2021-12-08 12:26:29 -05:00 committed by GitHub
parent 7ecaa3b976
commit d93362d87f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 87 additions and 45 deletions

1
changelog.d/11531.misc Normal file
View File

@ -0,0 +1 @@
Add a receipt types constant for `m.read`.

View File

@ -253,5 +253,9 @@ class GuestAccess:
FORBIDDEN: Final = "forbidden" FORBIDDEN: Final = "forbidden"
class ReceiptTypes:
READ: Final = "m.read"
class ReadReceiptEventFields: class ReadReceiptEventFields:
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"

View File

@ -14,7 +14,7 @@
import logging import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
for event_id in content.keys(): for event_id in content.keys():
event_content = content.get(event_id, {}) event_content = content.get(event_id, {})
m_read = event_content.get("m.read", {}) m_read = event_content.get(ReceiptTypes.READ, {})
# If m_read is missing copy over the original event_content as there is nothing to process here # If m_read is missing copy over the original event_content as there is nothing to process here
if not m_read: if not m_read:
@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
# Set new users unless empty # Set new users unless empty
if len(new_users.keys()) > 0: if len(new_users.keys()) > 0:
new_event["content"][event_id] = {"m.read": new_users} new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
# Append new_event to visible_events unless empty # Append new_event to visible_events unless empty
if len(new_event["content"].keys()) > 0: if len(new_event["content"].keys()) > 0:

View File

@ -28,7 +28,7 @@ from typing import (
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@ -1046,7 +1046,7 @@ class SyncHandler:
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
room_id=room_id, room_id=room_id,
receipt_type="m.read", receipt_type=ReceiptTypes.READ,
) )
notifs = await self.store.get_unread_event_push_actions_by_room_for_user( notifs = await self.store.get_unread_event_push_actions_by_room_for_user(

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Dict
from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage from synapse.storage import Storage
@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
invites = await store.get_invited_rooms_for_local_user(user_id) invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id) joins = await store.get_rooms_for_user(user_id)
my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
badge = len(invites) badge = len(invites)

View File

@ -15,6 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet):
) )
receipts_by_room = await self.store.get_receipts_for_user_with_orderings( receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read" user_id, ReceiptTypes.READ
) )
notif_event_ids = [pa["event_id"] for pa in push_actions] notif_event_ids = [pa["event_id"] for pa in push_actions]

View File

@ -15,7 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
await self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None) read_event_id = body.get(ReceiptTypes.READ, None)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
if not isinstance(hidden, bool): if not isinstance(hidden, bool):
@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet):
if read_event_id: if read_event_id:
await self.receipts_handler.received_client_receipt( await self.receipts_handler.received_client_receipt(
room_id, room_id,
"m.read", ReceiptTypes.READ,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
event_id=read_event_id, event_id=read_event_id,
hidden=hidden, hidden=hidden,

View File

@ -16,7 +16,7 @@ import logging
import re import re
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != ReceiptTypes.READ:
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.

View File

@ -14,14 +14,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
) )
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream """Get the current max stream ID for receipts stream"""
Returns:
int
"""
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
@cached() @cached()
async def get_users_with_read_receipts_in_room(self, room_id): async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
receipts = await self.get_receipts_for_room(room_id, "m.read") receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts} return {r["user_id"] for r in receipts}
@cached(num_args=2) @cached(num_args=2)
@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
@cached(num_args=2) @cached(num_args=2)
async def get_receipts_for_user(self, user_id, receipt_type): async def get_receipts_for_user(
self, user_id: str, receipt_type: str
) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type}, keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows} return {row["room_id"]: row["event_id"] for row in rows}
async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): async def get_receipts_for_user_with_orderings(
def f(txn): self, user_id: str, receipt_type: str
) -> JsonDict:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = ( sql = (
"SELECT rl.room_id, rl.event_id," "SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering" " e.topological_ordering, e.stream_ordering"
@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True) @cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room( async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]: ) -> List[JsonDict]:
"""See get_linearized_receipts_for_room""" """See get_linearized_receipts_for_room"""
def f(txn): def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key: if from_key:
sql = ( sql = (
"SELECT * FROM receipts_linearized WHERE" "SELECT * FROM receipts_linearized WHERE"
@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
list_name="room_ids", list_name="room_ids",
num_args=3, num_args=3,
) )
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
) -> Dict[str, List[JsonDict]]:
if not room_ids: if not room_ids:
return {} return {}
def f(txn): def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key: if from_key:
sql = """ sql = """
SELECT * FROM receipts_linearized WHERE SELECT * FROM receipts_linearized WHERE
@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts. A dictionary of roomids to a list of receipts.
""" """
def f(txn): def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key: if from_key:
sql = """ sql = """
SELECT * FROM receipts_linearized WHERE SELECT * FROM receipts_linearized WHERE
@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return defer.succeed([])
def _get_users_sent_receipts_between_txn(txn): def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """ sql = """
SELECT DISTINCT user_id FROM receipts_linearized SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_updated_receipts_txn(txn): def get_all_updated_receipts_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """ sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized FROM receipts_linearized
@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _invalidate_get_users_with_receipts_in_room( def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str self, room_id: str, receipt_type: str, user_id: str
): ) -> None:
if receipt_type != "m.read": if receipt_type != ReceiptTypes.READ:
return return
res = self.get_users_with_read_receipts_in_room.cache.get_immediate( res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.get_users_with_read_receipts_in_room.invalidate((room_id,)) self.get_users_with_read_receipts_in_room.invalidate((room_id,))
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def invalidate_caches_for_receipt(
self, room_id: str, receipt_type: str, user_id: str
) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,)) self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate( self.get_last_receipt_event_id_for_user.invalidate(
@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn( def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id self,
): txn: LoggingTransaction,
room_id: str,
receipt_type: str,
user_id: str,
event_id: str,
data: JsonDict,
stream_id: int,
) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR """Inserts a read-receipt into the database if it's newer than the current RR
Returns: int|None Returns:
None if the RR is older than the current RR None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown) (or 0 if the event is unknown)
@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False, lock=False,
) )
if receipt_type == "m.read" and stream_ordering is not None: if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn( self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
) )
@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
else: else:
# we need to points in graph -> linearized form. # we need to points in graph -> linearized form.
# TODO: Make this better. # TODO: Make this better.
def graph_to_linear(txn): def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids self.database_engine, "event_id", event_ids
) )
@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return stream_id, max_persisted_id return stream_id, max_persisted_id
async def insert_graph_receipt( async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data self,
): room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts assert self._can_write_to_receipts
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"insert_graph_receipt", "insert_graph_receipt",
self.insert_graph_receipt_txn, self.insert_graph_receipt_txn,
room_id, room_id,
@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
def insert_graph_receipt_txn( def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data self,
): txn: LoggingTransaction,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))