mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-15 19:22:12 -04:00
parent
fe6cfc80ec
commit
2ffd6783c7
12 changed files with 19 additions and 339 deletions
|
@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream
|
|||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import (
|
||||
Cache,
|
||||
_CacheContext,
|
||||
cached,
|
||||
cachedInlineCallbacks,
|
||||
)
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
@cached(tree=True, cache_context=True)
|
||||
async def get_unread_message_count_for_user(
|
||||
self, room_id: str, user_id: str, cache_context: _CacheContext,
|
||||
) -> int:
|
||||
"""Retrieve the count of unread messages for the given room and user.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to count unread messages in.
|
||||
user_id: The ID of the user to count unread messages for.
|
||||
|
||||
Returns:
|
||||
The number of unread messages for the given user in the given room.
|
||||
"""
|
||||
with Measure(self._clock, "get_unread_message_count_for_user"):
|
||||
last_read_event_id = await self.get_last_receipt_event_id_for_user(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
receipt_type="m.read",
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_unread_message_count_for_user",
|
||||
self._get_unread_message_count_for_user_txn,
|
||||
user_id,
|
||||
room_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
|
||||
def _get_unread_message_count_for_user_txn(
|
||||
self,
|
||||
txn: Cursor,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> int:
|
||||
if last_read_event_id:
|
||||
# Get the stream ordering for the last read event.
|
||||
stream_ordering = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="events",
|
||||
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
|
||||
retcol="stream_ordering",
|
||||
)
|
||||
else:
|
||||
# If there's no read receipt for that room, it probably means the user hasn't
|
||||
# opened it yet, in which case use the stream ID of their join event.
|
||||
# We can't just set it to 0 otherwise messages from other local users from
|
||||
# before this user joined will be counted as well.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT stream_ordering FROM local_current_membership
|
||||
LEFT JOIN events USING (event_id, room_id)
|
||||
WHERE membership = 'join'
|
||||
AND user_id = ?
|
||||
AND room_id = ?
|
||||
""",
|
||||
(user_id, room_id),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return 0
|
||||
|
||||
stream_ordering = row[0]
|
||||
|
||||
# Count the messages that qualify as unread after the stream ordering we've just
|
||||
# retrieved.
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM events
|
||||
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue