mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Re-implement unread counts (#7736)
This commit is contained in:
parent
2184f61fae
commit
8dff4a1242
1
changelog.d/7736.feature
Normal file
1
changelog.d/7736.feature
Normal file
@ -0,0 +1 @@
|
||||
Add unread messages count to sync responses.
|
@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db")
|
||||
|
||||
|
||||
BOOLEAN_COLUMNS = {
|
||||
"events": ["processed", "outlier", "contains_url"],
|
||||
"events": ["processed", "outlier", "contains_url", "count_as_unread"],
|
||||
"rooms": ["is_public"],
|
||||
"event_edges": ["is_state"],
|
||||
"presence_list": ["accepted"],
|
||||
|
@ -103,6 +103,7 @@ class JoinedSyncResult:
|
||||
account_data = attr.ib(type=List[JsonDict])
|
||||
unread_notifications = attr.ib(type=JsonDict)
|
||||
summary = attr.ib(type=Optional[JsonDict])
|
||||
unread_count = attr.ib(type=int)
|
||||
|
||||
def __nonzero__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
@ -1886,6 +1887,10 @@ class SyncHandler(object):
|
||||
|
||||
if room_builder.rtype == "joined":
|
||||
unread_notifications = {} # type: Dict[str, str]
|
||||
|
||||
unread_count = await self.store.get_unread_message_count_for_user(
|
||||
room_id, sync_config.user.to_string(),
|
||||
)
|
||||
room_sync = JoinedSyncResult(
|
||||
room_id=room_id,
|
||||
timeline=batch,
|
||||
@ -1894,6 +1899,7 @@ class SyncHandler(object):
|
||||
account_data=account_data_events,
|
||||
unread_notifications=unread_notifications,
|
||||
summary=summary,
|
||||
unread_count=unread_count,
|
||||
)
|
||||
|
||||
if room_sync or always_include:
|
||||
|
@ -21,22 +21,13 @@ async def get_badge_count(store, user_id):
|
||||
invites = await store.get_invited_rooms_for_local_user(user_id)
|
||||
joins = await store.get_rooms_for_user(user_id)
|
||||
|
||||
my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
|
||||
|
||||
badge = len(invites)
|
||||
|
||||
for room_id in joins:
|
||||
if room_id in my_receipts_by_room:
|
||||
last_unread_event_id = my_receipts_by_room[room_id]
|
||||
|
||||
notifs = await (
|
||||
store.get_unread_event_push_actions_by_room_for_user(
|
||||
room_id, user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
|
||||
# return one badge count per conversation, as count per
|
||||
# message is so noisy as to be almost useless
|
||||
badge += 1 if notifs["notify_count"] else 0
|
||||
badge += 1 if unread_count else 0
|
||||
return badge
|
||||
|
||||
|
||||
|
@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
|
||||
result["ephemeral"] = {"events": ephemeral_events}
|
||||
result["unread_notifications"] = room.unread_notifications
|
||||
result["summary"] = room.summary
|
||||
result["org.matrix.msc2654.unread_count"] = room.unread_count
|
||||
|
||||
return result
|
||||
|
||||
|
@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
|
||||
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
||||
|
||||
self.get_unread_message_count_for_user.invalidate_many((room_id,))
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
|
||||
|
||||
if not backfilled:
|
||||
|
@ -53,6 +53,47 @@ event_counter = Counter(
|
||||
["type", "origin_type", "origin_entity"],
|
||||
)
|
||||
|
||||
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
|
||||
EventTypes.Topic,
|
||||
EventTypes.Name,
|
||||
EventTypes.RoomAvatar,
|
||||
EventTypes.Tombstone,
|
||||
}
|
||||
|
||||
|
||||
def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
|
||||
# Exclude rejected and soft-failed events.
|
||||
if context.rejected or event.internal_metadata.is_soft_failed():
|
||||
return False
|
||||
|
||||
# Exclude notices.
|
||||
if (
|
||||
not event.is_state()
|
||||
and event.type == EventTypes.Message
|
||||
and event.content.get("msgtype") == "m.notice"
|
||||
):
|
||||
return False
|
||||
|
||||
# Exclude edits.
|
||||
relates_to = event.content.get("m.relates_to", {})
|
||||
if relates_to.get("rel_type") == RelationTypes.REPLACE:
|
||||
return False
|
||||
|
||||
# Mark events that have a non-empty string body as unread.
|
||||
body = event.content.get("body")
|
||||
if isinstance(body, str) and body:
|
||||
return True
|
||||
|
||||
# Mark some state events as unread.
|
||||
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
|
||||
return True
|
||||
|
||||
# Mark encrypted events as unread.
|
||||
if not event.is_state() and event.type == EventTypes.Encrypted:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def encode_json(json_object):
|
||||
"""
|
||||
@ -196,6 +237,10 @@ class PersistEventsStore:
|
||||
|
||||
event_counter.labels(event.type, origin_type, origin_entity).inc()
|
||||
|
||||
self.store.get_unread_message_count_for_user.invalidate_many(
|
||||
(event.room_id,),
|
||||
)
|
||||
|
||||
for room_id, new_state in current_state_for_room.items():
|
||||
self.store.get_current_state_ids.prefill((room_id,), new_state)
|
||||
|
||||
@ -817,8 +862,9 @@ class PersistEventsStore:
|
||||
"contains_url": (
|
||||
"url" in event.content and isinstance(event.content["url"], str)
|
||||
),
|
||||
"count_as_unread": should_count_as_unread(event, context),
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
for event, context in events_and_contexts
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import (
|
||||
Cache,
|
||||
_CacheContext,
|
||||
cached,
|
||||
cachedInlineCallbacks,
|
||||
)
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
@cached(tree=True, cache_context=True)
|
||||
async def get_unread_message_count_for_user(
|
||||
self, room_id: str, user_id: str, cache_context: _CacheContext,
|
||||
) -> int:
|
||||
"""Retrieve the count of unread messages for the given room and user.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to count unread messages in.
|
||||
user_id: The ID of the user to count unread messages for.
|
||||
|
||||
Returns:
|
||||
The number of unread messages for the given user in the given room.
|
||||
"""
|
||||
with Measure(self._clock, "get_unread_message_count_for_user"):
|
||||
last_read_event_id = await self.get_last_receipt_event_id_for_user(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
receipt_type="m.read",
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_unread_message_count_for_user",
|
||||
self._get_unread_message_count_for_user_txn,
|
||||
user_id,
|
||||
room_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
|
||||
def _get_unread_message_count_for_user_txn(
|
||||
self,
|
||||
txn: Cursor,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> int:
|
||||
if last_read_event_id:
|
||||
# Get the stream ordering for the last read event.
|
||||
stream_ordering = self.db.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="events",
|
||||
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
|
||||
retcol="stream_ordering",
|
||||
)
|
||||
else:
|
||||
# If there's no read receipt for that room, it probably means the user hasn't
|
||||
# opened it yet, in which case use the stream ID of their join event.
|
||||
# We can't just set it to 0 otherwise messages from other local users from
|
||||
# before this user joined will be counted as well.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT stream_ordering FROM local_current_membership
|
||||
LEFT JOIN events USING (event_id, room_id)
|
||||
WHERE membership = 'join'
|
||||
AND user_id = ?
|
||||
AND room_id = ?
|
||||
""",
|
||||
(user_id, room_id),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return 0
|
||||
|
||||
stream_ordering = row[0]
|
||||
|
||||
# Count the messages that qualify as unread after the stream ordering we've just
|
||||
# retrieved.
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM events
|
||||
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
@ -0,0 +1,18 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Store a boolean value in the events table for whether the event should be counted in
|
||||
-- the unread_count property of sync responses.
|
||||
ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
|
@ -143,6 +143,26 @@ class RestHelper(object):
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
request, channel = make_request(
|
||||
self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
|
||||
)
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
assert int(channel.result["code"]) == expect_code, (
|
||||
"Expected: %d, got: %d, resp: %r"
|
||||
% (expect_code, int(channel.result["code"]), channel.result["body"])
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def _read_write_state(
|
||||
self,
|
||||
room_id: str,
|
||||
|
@ -16,9 +16,9 @@
|
||||
import json
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.rest.client.v2_alpha import read_marker, sync
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import TimedOutException
|
||||
@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
|
||||
"GET", sync_url % (access_token, next_batch)
|
||||
)
|
||||
self.assertRaises(TimedOutException, self.render, request)
|
||||
|
||||
|
||||
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
read_marker.register_servlets,
|
||||
room.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.url = "/sync?since=%s"
|
||||
self.next_batch = "s0"
|
||||
|
||||
# Register the first user (used to check the unread counts).
|
||||
self.user_id = self.register_user("kermit", "monkey")
|
||||
self.tok = self.login("kermit", "monkey")
|
||||
|
||||
# Create the room we'll check unread counts for.
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
|
||||
# Register the second user (used to send events to the room).
|
||||
self.user2 = self.register_user("kermit2", "monkey")
|
||||
self.tok2 = self.login("kermit2", "monkey")
|
||||
|
||||
# Change the power levels of the room so that the second user can send state
|
||||
# events.
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
EventTypes.PowerLevels,
|
||||
{
|
||||
"users": {self.user_id: 100, self.user2: 100},
|
||||
"users_default": 0,
|
||||
"events": {
|
||||
"m.room.name": 50,
|
||||
"m.room.power_levels": 100,
|
||||
"m.room.history_visibility": 100,
|
||||
"m.room.canonical_alias": 50,
|
||||
"m.room.avatar": 50,
|
||||
"m.room.tombstone": 100,
|
||||
"m.room.server_acl": 100,
|
||||
"m.room.encryption": 100,
|
||||
},
|
||||
"events_default": 0,
|
||||
"state_default": 50,
|
||||
"ban": 50,
|
||||
"kick": 50,
|
||||
"redact": 50,
|
||||
"invite": 0,
|
||||
},
|
||||
tok=self.tok,
|
||||
)
|
||||
|
||||
def test_unread_counts(self):
|
||||
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
|
||||
|
||||
# Check that our own messages don't increase the unread count.
|
||||
self.helper.send(self.room_id, "hello", tok=self.tok)
|
||||
self._check_unread_count(0)
|
||||
|
||||
# Join the new user and check that this doesn't increase the unread count.
|
||||
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
|
||||
self._check_unread_count(0)
|
||||
|
||||
# Check that the new user sending a message increases our unread count.
|
||||
res = self.helper.send(self.room_id, "hello", tok=self.tok2)
|
||||
self._check_unread_count(1)
|
||||
|
||||
# Send a read receipt to tell the server we've read the latest event.
|
||||
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/read_markers" % self.room_id,
|
||||
body,
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check that the unread counter is back to 0.
|
||||
self._check_unread_count(0)
|
||||
|
||||
# Check that room name changes increase the unread counter.
|
||||
self.helper.send_state(
|
||||
self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
|
||||
)
|
||||
self._check_unread_count(1)
|
||||
|
||||
# Check that room topic changes increase the unread counter.
|
||||
self.helper.send_state(
|
||||
self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
|
||||
)
|
||||
self._check_unread_count(2)
|
||||
|
||||
# Check that encrypted messages increase the unread counter.
|
||||
self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
|
||||
self._check_unread_count(3)
|
||||
|
||||
# Check that custom events with a body increase the unread counter.
|
||||
self.helper.send_event(
|
||||
self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
|
||||
)
|
||||
self._check_unread_count(4)
|
||||
|
||||
# Check that edits don't increase the unread counter.
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"body": "hello",
|
||||
"msgtype": "m.text",
|
||||
"m.relates_to": {"rel_type": RelationTypes.REPLACE},
|
||||
},
|
||||
tok=self.tok2,
|
||||
)
|
||||
self._check_unread_count(4)
|
||||
|
||||
# Check that notices don't increase the unread counter.
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"body": "hello", "msgtype": "m.notice"},
|
||||
tok=self.tok2,
|
||||
)
|
||||
self._check_unread_count(4)
|
||||
|
||||
# Check that tombstone events changes increase the unread counter.
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
EventTypes.Tombstone,
|
||||
{"replacement_room": "!someroom:test"},
|
||||
tok=self.tok2,
|
||||
)
|
||||
self._check_unread_count(5)
|
||||
|
||||
def _check_unread_count(self, expected_count: True):
|
||||
"""Syncs and compares the unread count with the expected value."""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url % self.next_batch, access_token=self.tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
room_entry = channel.json_body["rooms"]["join"][self.room_id]
|
||||
self.assertEqual(
|
||||
room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
|
||||
)
|
||||
|
||||
# Store the next batch for the next request.
|
||||
self.next_batch = channel.json_body["next_batch"]
|
||||
|
Loading…
Reference in New Issue
Block a user