540 lines
18 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
2016-01-07 04:26:29 +00:00
# Copyright 2014-2016 OpenMarket Ltd
2018-02-16 12:17:14 +00:00
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2018-07-09 16:09:20 +10:00
import abc
import logging
from canonicaljson import json
2018-07-09 16:09:20 +10:00
from twisted.internet import defer
2019-10-02 19:07:07 +01:00
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
2019-10-02 19:07:07 +01:00
from synapse.storage.util.id_generators import StreamIdGenerator
2018-07-09 16:09:20 +10:00
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2018-02-16 12:17:14 +00:00
class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
    """Queries over the `receipts_linearized` table and their caches.

    This is an abstract base class where subclasses must implement
    `get_max_receipt_stream_id` which can be called in the initializer.

    The metaclass is declared in the class statement rather than through a
    `__metaclass__` attribute: Python 3 ignores that attribute, which would
    leave `abc.abstractmethod` unenforced.
    """

    def __init__(self, database: Database, db_conn, hs):
        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)

        # Seeded with the current max receipt stream ID, so subclasses must
        # be able to answer `get_max_receipt_stream_id` by this point.
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
        )

    @abc.abstractmethod
    def get_max_receipt_stream_id(self):
        """Get the current max stream ID for receipts stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cachedInlineCallbacks()
    def get_users_with_read_receipts_in_room(self, room_id):
        """Get the users that have an "m.read" receipt in the given room.

        Args:
            room_id (str): The room id.

        Returns:
            Deferred[set]: The `user_id` of every "m.read" receipt row for
            the room.
        """
        receipts = yield self.get_receipts_for_room(room_id, "m.read")
        return {r["user_id"] for r in receipts}

    @cached(num_args=2)
    def get_receipts_for_room(self, room_id, receipt_type):
        """Get the receipts of a given type in a room.

        Args:
            room_id (str): The room id.
            receipt_type (str): The receipt type, e.g. "m.read".

        Returns:
            The result of selecting the `user_id` and `event_id` columns of
            the matching `receipts_linearized` rows.
        """
        return self.db.simple_select_list(
            table="receipts_linearized",
            keyvalues={"room_id": room_id, "receipt_type": receipt_type},
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
        """Get the event a user's receipt of a given type in a room points at.

        Args:
            user_id (str): The user id.
            room_id (str): The room id.
            receipt_type (str): The receipt type, e.g. "m.read".

        Returns:
            The result of selecting the `event_id` column of the matching
            row; the lookup is made with `allow_none=True`, so a missing row
            is not an error.
        """
        return self.db.simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cachedInlineCallbacks(num_args=2)
    def get_receipts_for_user(self, user_id, receipt_type):
        """Get a user's receipts of a given type across all rooms.

        Args:
            user_id (str): The user id.
            receipt_type (str): The receipt type, e.g. "m.read".

        Returns:
            Deferred[dict]: Map from room id to the event id of the receipt.
        """
        rows = yield self.db.simple_select_list(
            table="receipts_linearized",
            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )

        return {row["room_id"]: row["event_id"] for row in rows}

    @defer.inlineCallbacks
    def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
        """Get a user's receipts of a given type, with event orderings.

        Args:
            user_id (str): The user id.
            receipt_type (str): The receipt type, e.g. "m.read".

        Returns:
            Deferred[dict]: Map from room id to a dict with the keys
            "event_id", "topological_ordering" and "stream_ordering".
        """

        def f(txn):
            # The receipt type is part of the filter so that receipts of
            # other types are not mixed into the result.
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND rl.user_id = ?"
                " AND rl.receipt_type = ?"
            )
            txn.execute(sql, (user_id, receipt_type))
            return txn.fetchall()

        rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }

    @defer.inlineCallbacks
    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids (list): List of room_ids.
            to_key (int): Max stream id to fetch receipts upto.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            list: A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = yield self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = yield self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        # Flatten the per-room lists into a single list of events.
        return [ev for res in results.values() for ev in res]

    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
        """Get receipts for a single room for sending to clients.

        Args:
            room_id (str): The room id.
            to_key (int): Max stream id to fetch receipts upto.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            Deferred[list]: A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return defer.succeed([])

        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cachedInlineCallbacks(num_args=3, tree=True)
    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
        """See get_linearized_receipts_for_room
        """

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, to_key))

            rows = self.db.cursor_to_dict(txn)

            return rows

        rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        # Batch every row into one event: event_id -> receipt_type -> user_id.
        content = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = json.loads(row["data"])

        return [{"type": "m.receipt", "room_id": room_id, "content": content}]

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
        inlineCallbacks=True,
    )
    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        """Batched form of `_get_linearized_receipts_for_room`.

        Returns:
            Deferred[dict]: Map from each requested room id to a list holding
            that room's single "m.receipt" event, or an empty list if the
            query returned no rows for the room.
        """
        if not room_ids:
            return {}

        def f(txn):
            # The room filter is the same whichever stream range is used.
            clause, args = make_in_list_sql_clause(
                self.database_engine, "room_id", room_ids
            )

            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """
                txn.execute(sql + clause, [to_key] + list(args))

            return self.db.cursor_to_dict(txn)

        txn_results = yield self.db.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        events_by_room = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = events_by_room.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = json.loads(row["data"])

        # Every requested room gets an entry, even if it has no receipts.
        return {
            room_id: [events_by_room[room_id]] if room_id in events_by_room else []
            for room_id in room_ids
        }

    def get_all_updated_receipts(self, last_id, current_id, limit=None):
        """Get the receipt rows with a stream id in `(last_id, current_id]`.

        Args:
            last_id (int): Exclusive lower bound on the stream id.
            current_id (int): Inclusive upper bound on the stream id.
            limit (int|None): Maximum number of rows to fetch, or None for
                no limit.

        Returns:
            Deferred[list]: Rows ordered by ascending stream id, each of the
            form (stream_id, room_id, receipt_type, user_id, event_id, data)
            where `data` has been decoded from JSON.
        """
        if last_id == current_id:
            return defer.succeed([])

        def get_all_updated_receipts_txn(txn):
            sql = (
                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
                " FROM receipts_linearized"
                " WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
            )
            args = [last_id, current_id]
            if limit is not None:
                sql += " LIMIT ?"
                args.append(limit)
            txn.execute(sql, args)

            return [r[0:5] + (json.loads(r[5]),) for r in txn]

        return self.db.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

    def _invalidate_get_users_with_receipts_in_room(
        self, room_id, receipt_type, user_id
    ):
        """Invalidate `get_users_with_read_receipts_in_room` only if needed.

        Nothing is done for receipt types other than "m.read", nor when the
        cached set for the room already contains the user.
        """
        if receipt_type != "m.read":
            return

        # Returns either an ObservableDeferred or the raw result
        res = self.get_users_with_read_receipts_in_room.cache.get(
            room_id, None, update_metrics=False
        )

        # first handle the Deferred case
        if isinstance(res, defer.Deferred):
            if res.called:
                res = res.result
            else:
                res = None

        if res and user_id in res:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return

        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
class ReceiptsStore(ReceiptsWorkerStore):
    """Receipts store that can also write receipts.

    Owns the stream ID generator for the `receipts_linearized` table.
    """

    def __init__(self, database: Database, db_conn, hs):
        # We instantiate this first as the ReceiptsWorkerStore constructor
        # needs to be able to call get_max_receipt_stream_id
        self._receipts_id_gen = StreamIdGenerator(
            db_conn, "receipts_linearized", "stream_id"
        )

        super(ReceiptsStore, self).__init__(database, db_conn, hs)

    def get_max_receipt_stream_id(self):
        """Get the current max stream ID for receipts stream

        Returns:
            The current token of the receipts stream ID generator.
        """
        return self._receipts_id_gen.get_current_token()

    def insert_linearized_receipt_txn(
        self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
    ):
        """Inserts a read-receipt into the database if it's newer than the current RR

        Args:
            txn: The database transaction to run in.
            room_id (str): The room id.
            receipt_type (str): The receipt type, e.g. "m.read".
            user_id (str): The user the receipt belongs to.
            event_id (str): The event the receipt points at.
            data (dict): The receipt data, stored as JSON.
            stream_id (int): The stream id to store the receipt under.

        Returns: int|None
            None if the RR is older than the current RR
            otherwise, the rx timestamp of the event that the RR corresponds to
                (or 0 if the event is unknown)
        """
        res = self.db.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        # Queue the cache invalidations via `txn.call_after`.
        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(
            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        txn.call_after(
            self.get_last_receipt_event_id_for_user.invalidate,
            (user_id, room_id, receipt_type),
        )

        # Replace any existing receipt for this (room, type, user).
        self.db.simple_delete_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )

        self.db.simple_insert_txn(
            txn,
            table="receipts_linearized",
            values={
                "stream_id": stream_id,
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_id": event_id,
                "data": json.dumps(data),
            },
        )

        if receipt_type == "m.read" and stream_ordering is not None:
            self._remove_old_push_actions_before_txn(
                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
            )

        return rx_ts

    @defer.inlineCallbacks
    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.

        Args:
            room_id (str): The room id.
            receipt_type (str): The receipt type, e.g. "m.read".
            user_id (str): The user the receipt belongs to.
            event_ids (list): The event ids the receipt points at.
            data (dict): The receipt data.

        Returns:
            Deferred[tuple|None]: None if `event_ids` is empty or the receipt
            was ignored in favour of an existing one for a later event;
            otherwise `(stream_id, max_persisted_id)`.

        Raises:
            RuntimeError: If several event ids were given and none of them
                could be resolved to an event in the room.
        """
        if not event_ids:
            return

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to points in graph -> linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn):
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "event_id", event_ids
                )

                # Pick the given event with the highest stream ordering.
                # Both the outer query and the subquery need a FROM clause
                # to be valid SQL.
                sql = """
                    SELECT event_id FROM events
                    WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) FROM events WHERE %s
                    )
                """ % (
                    clause,
                )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

            linearized_event_id = yield self.db.runInteraction(
                "insert_receipt_conv", graph_to_linear
            )

        stream_id_manager = self._receipts_id_gen.get_next()
        with stream_id_manager as stream_id:
            event_ts = yield self.db.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Store the graph form of a receipt in its own transaction.

        Runs `insert_graph_receipt_txn` through `runInteraction` and returns
        its result.
        """
        return self.db.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

    def insert_graph_receipt_txn(
        self, txn, room_id, receipt_type, user_id, event_ids, data
    ):
        """Replace the `receipts_graph` row for a (room, type, user).

        Args:
            txn: The database transaction to run in.
            room_id (str): The room id.
            receipt_type (str): The receipt type, e.g. "m.read".
            user_id (str): The user the receipt belongs to.
            event_ids (list): The event ids the receipt points at, stored as
                JSON.
            data (dict): The receipt data, stored as JSON.
        """
        # Queue the cache invalidations via `txn.call_after`.
        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(
            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
        )

        self.db.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json.dumps(event_ids),
                "data": json.dumps(data),
            },
        )