Merge branch 'develop' into babolivier/mark_unread

This commit is contained in:
Brendan Abolivier 2020-06-10 11:42:30 +01:00
commit ec0a7b9034
963 changed files with 66272 additions and 30755 deletions


@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2018,2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,518 +14,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import calendar
import logging
import time
"""
The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
databases). The `Database` class represents a single physical database. The
`data_stores` are classes that talk directly to a `Database` instance and have
associated schemas, background updates, etc. On top of those there are classes
that provide high level interfaces that combine calls to multiple `data_stores`.
from twisted.internet import defer
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .directory import DirectoryStore
from .e2e_room_keys import EndToEndRoomKeyStore
from .end_to_end_keys import EndToEndKeyStore
from .engines import PostgresEngine
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events import EventsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
from .push_rule import PushRuleStore
from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
from .user_directory import UserDirectoryStore
from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
logger = logging.getLogger(__name__)
__all__ = ["DataStores", "DataStore"]
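The module docstring above describes three layers; a rough sketch of how they compose, using only names that appear in this diff (the wiring normally lives in the HomeServer setup code, which is assumed here rather than shown):

# Hedged sketch, not a verbatim excerpt of Synapse's setup code.
stores = DataStores(DataStore, hs)   # one low-level store per physical database
storage = Storage(hs, stores)        # high-level interfaces layered on top
storage.main                         # the "main" data store (direct DB access)
storage.persistence                  # EventsPersistenceStorage, combining stores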
class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
RegistrationStore,
StreamStore,
ProfileStore,
PresenceStore,
TransactionStore,
DirectoryStore,
KeyStore,
StateStore,
SignatureStore,
ApplicationServiceStore,
EventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
ReceiptsStore,
EndToEndKeyStore,
EndToEndRoomKeyStore,
SearchStore,
TagsStore,
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
class Storage(object):
"""The high level interfaces for talking to various storage layers.
"""
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
extra_tables=[("local_invites", "stream_id")],
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_max_stream_id", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id"
)
def __init__(self, hs, stores: DataStores):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
self.main = stores.main
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id"
)
else:
self._cache_id_gen = None
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache",
min_presence_val,
prefilled_cache=presence_cache_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
min_device_inbox_id,
prefilled_cache=device_inbox_prefill,
)
# The federation outbox and the local device inbox use the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache",
min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache",
min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
sql = (
"SELECT user_id, state, last_active_ts, last_federation_update_ts,"
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.cursor_to_dict(txn)
txn.close()
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.runInteraction("count_daily_users", self._count_users, yesterday)
def count_monthly_users(self):
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
from the mau figure in synapse.storage.monthly_active_users which,
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn, time_from):
"""
Returns number of users seen in the past time_from period
"""
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (time_from,))
count, = txn.fetchone()
return count
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
Returns counts globally as well as broken down by platform.
"""
def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
sql = """
SELECT platform, COALESCE(count(*), 0) FROM (
SELECT
users.name, platform, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
INNER JOIN (
SELECT
user_id,
last_seen,
CASE
WHEN user_agent LIKE '%%Android%%' THEN 'android'
WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
ELSE 'unknown'
END
AS platform
FROM user_ips
) uip
ON users.name = uip.user_id
AND users.appservice_id is NULL
AND users.creation_ts < ?
AND uip.last_seen/1000 > ?
AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
GROUP BY users.name, platform, users.creation_ts
) u GROUP BY platform
"""
results = {}
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
for row in txn:
if row[0] == "unknown":
pass
results[row[0]] = row[1]
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT users.name, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
INNER JOIN (
SELECT
user_id,
last_seen
FROM user_ips
) uip
ON users.name = uip.user_id
AND appservice_id is NULL
AND users.creation_ts < ?
AND uip.last_seen/1000 > ?
AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
GROUP BY users.name, users.creation_ts
) u
"""
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
count, = txn.fetchone()
results["all"] = count
return results
return self.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
Returns millisecond unixtime for start of UTC day.
"""
now = time.gmtime()
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
def generate_user_daily_visits(self):
"""
Generates daily visit data for use in cohort/retention analysis
"""
def _generate_user_daily_visits(txn):
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
now = self.clock.time_msec()
sql = """
INSERT INTO user_daily_visits (user_id, device_id, timestamp)
SELECT u.user_id, u.device_id, ?
FROM user_ips AS u
LEFT JOIN (
SELECT user_id, device_id, timestamp FROM user_daily_visits
WHERE timestamp = ?
) udv
ON u.user_id = udv.user_id AND u.device_id=udv.device_id
INNER JOIN users ON users.name=u.user_id
WHERE last_seen > ? AND last_seen <= ?
AND udv.timestamp IS NULL AND users.is_guest=0
AND users.appservice_id IS NULL
GROUP BY u.user_id, u.device_id
"""
# This means that the day has rolled over but there could still
# be entries from the previous day. There is an edge case
# where if the user logs in at 23:59 and overwrites their
# last_seen at 00:01 then they will not be counted in the
# previous day's stats - it is important that the query is run
# often to minimise this case.
if today_start > self._last_user_visit_update:
yesterday_start = today_start - a_day_in_milliseconds
txn.execute(
sql,
(
yesterday_start,
yesterday_start,
self._last_user_visit_update,
today_start,
),
)
self._last_user_visit_update = today_start
txn.execute(
sql, (today_start, today_start, self._last_user_visit_update, now)
)
# Update _last_user_visit_update to now. The reason to do this
# rather than just clamping to the beginning of the day is to limit
# the size of the join - meaning that the query can be run more
# frequently
self._last_user_visit_update = now
return self.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
def get_users(self):
"""Function to retrieve a list of users in the users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self._simple_select_list(
table="users",
keyvalues={},
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="get_users",
)
@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
"""Function to retrieve a paginated list of users from the
users table. This will return a JSON object, which contains
a list of users and the total number of users in the users table.
Args:
order (str): the column to order the results by
start (int): start number to begin the query from
limit (int): number of rows to retrieve
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
users = yield self.runInteraction(
"get_users_paginate",
self._simple_select_list_paginate_txn,
table="users",
keyvalues={"is_guest": False},
orderby=order,
start=start,
limit=limit,
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
)
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
retval = {"users": users, "total": count}
return retval
def search_users(self, term):
"""Function to search the users table for one or more users
matching the given term.
Args:
term (str): search term
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self._simple_search_list(
table="users",
term=term,
col="name",
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="search_users",
)
def are_all_users_on_domain(txn, database_engine, domain):
sql = database_engine.convert_param_style(
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
)
pat = "%:" + domain
txn.execute(sql, (pat,))
num_not_matching = txn.fetchall()[0][0]
if num_not_matching == 0:
return True
return False
self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)

File diff suppressed because it is too large.


@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Optional
from canonicaljson import json
@ -22,7 +23,6 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@ -74,7 +74,7 @@ class BackgroundUpdatePerformance(object):
return float(self.total_item_count) / float(self.total_duration_ms)
class BackgroundUpdateStore(SQLBaseStore):
class BackgroundUpdater(object):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
@ -86,30 +86,34 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, db_conn, hs):
super(BackgroundUpdateStore, self).__init__(db_conn, hs)
def __init__(self, hs, database):
self._clock = hs.get_clock()
self.db = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._all_done = False
def start_doing_background_updates(self):
run_as_background_process("background_updates", self._run_background_updates)
run_as_background_process("background_updates", self.run_background_updates)
@defer.inlineCallbacks
def _run_background_updates(self):
async def run_background_updates(self, sleep=True):
logger.info("Starting background schema updates")
while True:
yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
if sleep:
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
result = yield self.do_next_background_update(
result = await self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except Exception:
logger.exception("Error doing update")
else:
if result is None:
if result:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
@ -117,30 +121,29 @@ class BackgroundUpdateStore(SQLBaseStore):
self._all_done = True
return None
@defer.inlineCallbacks
def has_completed_background_updates(self):
async def has_completed_background_updates(self) -> bool:
"""Check if all the background updates have completed
Returns:
Deferred[bool]: True if all background updates have completed
True if all background updates have completed
"""
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
return True
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
# obviously, if we are currently processing an update, we're not done.
if self._current_background_update:
return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
updates = yield self._simple_select_onecol(
updates = await self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
desc="check_background_updates",
desc="has_completed_background_updates",
)
if not updates:
self._all_done = True
@ -148,42 +151,79 @@ class BackgroundUpdateStore(SQLBaseStore):
return False
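Because a worker may not run the updates itself but still needs to wait for them (see the comment above), a caller might poll this method. A minimal usage sketch, assuming `updater` is the database's BackgroundUpdater and `clock` is the homeserver clock:

async def wait_for_background_updates(updater, clock):
    # Poll until every pending background update has been applied.
    while not await updater.has_completed_background_updates():
        await clock.sleep(0.1)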
@defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
async def has_completed_background_update(self, update_name) -> bool:
"""Check if the given background update has finished running.
"""
if self._all_done:
return True
if update_name == self._current_background_update:
return False
update_exists = await self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="1",
desc="has_completed_background_update",
allow_none=True,
)
return not update_exists
async def do_next_background_update(self, desired_duration_ms: float) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
Args:
desired_duration_ms(float): How long we want to spend
updating.
Returns:
A deferred that completes once some amount of work is done.
The deferred will have a value of None if there is currently
no more work to do.
True if we have finished running all the background updates, otherwise False
"""
if not self._background_update_queue:
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name", "depends_on"),
def get_background_updates_txn(txn):
txn.execute(
"""
SELECT update_name, depends_on FROM background_updates
ORDER BY ordering, update_name
"""
)
in_flight = set(update["update_name"] for update in updates)
for update in updates:
if update["depends_on"] not in in_flight:
self._background_update_queue.append(update["update_name"])
return self.db.cursor_to_dict(txn)
if not self._background_update_queue:
# no work left to do
return None
if not self._current_background_update:
all_pending_updates = await self.db.runInteraction(
"background_updates", get_background_updates_txn,
)
if not all_pending_updates:
# no work left to do
return True
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
# find the first update which isn't dependent on another one in the queue.
pending = {update["update_name"] for update in all_pending_updates}
for upd in all_pending_updates:
depends_on = upd["depends_on"]
if not depends_on or depends_on not in pending:
break
logger.info(
"Not starting on bg update %s until %s is done",
upd["update_name"],
depends_on,
)
else:
# if we get to the end of that for loop, there is a problem
raise Exception(
"Unable to find a background update which doesn't depend on "
"another: dependency cycle?"
)
res = yield self._do_background_update(update_name, desired_duration_ms)
return res
self._current_background_update = upd["update_name"]
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
await self._do_background_update(desired_duration_ms)
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@ -203,7 +243,7 @@ class BackgroundUpdateStore(SQLBaseStore):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = yield self._simple_select_one_onecol(
progress_json = await self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@ -212,13 +252,13 @@ class BackgroundUpdateStore(SQLBaseStore):
progress = json.loads(progress_json)
time_start = self._clock.time_msec()
items_updated = yield update_handler(progress, batch_size)
items_updated = await update_handler(progress, batch_size)
time_stop = self._clock.time_msec()
duration_ms = time_stop - time_start
logger.info(
"Updating %r. Updated %r items in %rms."
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name,
items_updated,
@ -241,7 +281,9 @@ class BackgroundUpdateStore(SQLBaseStore):
* A dict of the current progress
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
The handler should return a deferred or coroutine which returns an integer count
of items updated.
The handler is responsible for updating the progress of the update.
Args:
@ -357,7 +399,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.debug("[SQL] %s", sql)
c.execute(sql)
if isinstance(self.database_engine, engines.PostgresEngine):
if isinstance(self.db.engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
@ -368,33 +410,12 @@ class BackgroundUpdateStore(SQLBaseStore):
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
yield self.runWithConnection(runner)
yield self.db.runWithConnection(runner)
yield self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)
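For comparison with the index helper above, a hedged sketch of registering a hand-written update under the contract described earlier (the handler receives the current progress dict and a batch size, returns how many items it processed, and manages its own progress). The update name and the _do_one_batch helper are hypothetical, and `self` stands for whichever store/updater the handler lives on:

async def _populate_example_column(progress, batch_size):
    last_row = progress.get("last_row", 0)
    # _do_one_batch is a hypothetical helper: process up to batch_size rows
    # after last_row and return how many were handled.
    rows_done = await self._do_one_batch(last_row, batch_size)
    if rows_done == 0:
        # Nothing left to do: remove the update from the background_updates table.
        await self._end_background_update("populate_example_column")
        return 1
    await self._background_update_progress(
        "populate_example_column", {"last_row": last_row + rows_done}
    )
    return rows_done

self.register_background_update_handler(
    "populate_example_column", _populate_example_column
)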
def start_background_update(self, update_name, progress):
"""Starts a background update running.
Args:
update_name: The update to set running.
progress: The initial state of the progress of the update.
Returns:
A deferred that completes once the task has been added to the
queue.
"""
# Clear the background update queue so that we will pick up the new
# task on the next iteration of do_background_update.
self._background_update_queue = []
progress_json = json.dumps(progress)
return self._simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json},
)
def _end_background_update(self, update_name):
"""Removes a completed background update task from the queue.
@ -403,13 +424,31 @@ class BackgroundUpdateStore(SQLBaseStore):
Returns:
A deferred that completes once the task is removed.
"""
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
return self._simple_delete_one(
if update_name != self._current_background_update:
raise Exception(
"Cannot end background update %s which isn't currently running"
% update_name
)
self._current_background_update = None
return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update
Args:
update_name: The name of the background update task
progress: The progress of the update.
"""
return self.db.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
progress,
)
def _background_update_progress_txn(self, txn, update_name, progress):
"""Update the progress of a background update
@ -421,7 +460,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = json.dumps(progress)
self._simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},


@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.data_stores.main.events import PersistEventsStore
from synapse.storage.data_stores.state import StateGroupDataStore
from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger(__name__)
class DataStores(object):
"""The various data stores.
These are low level interfaces to physical databases.
Attributes:
main (DataStore)
"""
def __init__(self, main_store_class, hs):
# Note we pass in the main store class here as workers use a different main
# store.
self.databases = []
self.main = None
self.state = None
self.persist_events = None
for database_config in hs.config.database.databases:
db_name = database_config.name
engine = create_engine(database_config.config)
with make_conn(database_config, engine) as db_conn:
logger.info("Preparing database %r...", db_name)
engine.check_database(db_conn)
prepare_database(
db_conn, engine, hs.config, data_stores=database_config.data_stores,
)
database = Database(hs, database_config, engine)
if "main" in database_config.data_stores:
logger.info("Starting 'main' data store")
# Sanity check we don't try and configure the main store on
# multiple databases.
if self.main:
raise Exception("'main' data store already configured")
self.main = main_store_class(database, db_conn, hs)
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
self.persist_events = PersistEventsStore(
hs, database, self.main
)
if "state" in database_config.data_stores:
logger.info("Starting 'state' data store")
# Sanity check we don't try and configure the state store on
# multiple databases.
if self.state:
raise Exception("'state' data store already configured")
self.state = StateGroupDataStore(database, db_conn, hs)
db_conn.commit()
self.databases.append(database)
logger.info("Database %r prepared", db_name)
# Sanity check that we have actually configured all the required stores.
if not self.main:
raise Exception("No 'main' data store configured")
if not self.state:
raise Exception("No 'state' data store configured")
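A short sketch of how this constructor is expected to be used, assuming a homeserver whose configuration provides at least one physical database hosting both data stores (the construction itself happens during HomeServer setup, which is not part of this diff):

# Hedged sketch; `hs` is an initialised HomeServer.
stores = DataStores(DataStore, hs)   # workers pass their own main store class
stores.main                          # "main" data store, always required
stores.state                         # "state" (state group) data store
stores.databases                     # one Database object per physical database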


@ -0,0 +1,592 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import calendar
import logging
import time
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
from .e2e_room_keys import EndToEndRoomKeyStore
from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
from .purge_events import PurgeEventsStore
from .push_rule import PushRuleStore
from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
logger = logging.getLogger(__name__)
class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
RegistrationStore,
StreamStore,
ProfileStore,
PresenceStore,
TransactionStore,
DirectoryStore,
KeyStore,
StateStore,
SignatureStore,
ApplicationServiceStore,
PurgeEventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
ReceiptsStore,
EndToEndKeyStore,
EndToEndRoomKeyStore,
SearchStore,
TagsStore,
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
CensorEventsStore,
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_max_stream_id", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
)
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
instance_name="master",
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
)
else:
self._cache_id_gen = None
super(DataStore, self).__init__(database, db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache",
min_presence_val,
prefilled_cache=presence_cache_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
min_device_inbox_id,
prefilled_cache=device_inbox_prefill,
)
# The federation outbox and the local device inbox use the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache",
min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max
)
self._user_signature_stream_cache = StreamChangeCache(
"UserSignatureStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache",
min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
sql = (
"SELECT user_id, state, last_active_ts, last_federation_update_ts,"
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.db.cursor_to_dict(txn)
txn.close()
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
def count_monthly_users(self):
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
from the mau figure in synapse.storage.monthly_active_users which,
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.db.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn, time_from):
"""
Returns number of users seen in the past time_from period
"""
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (time_from,))
(count,) = txn.fetchone()
return count
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
Returns counts globally as well as broken down by platform.
"""
def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
sql = """
SELECT platform, COALESCE(count(*), 0) FROM (
SELECT
users.name, platform, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
INNER JOIN (
SELECT
user_id,
last_seen,
CASE
WHEN user_agent LIKE '%%Android%%' THEN 'android'
WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
ELSE 'unknown'
END
AS platform
FROM user_ips
) uip
ON users.name = uip.user_id
AND users.appservice_id is NULL
AND users.creation_ts < ?
AND uip.last_seen/1000 > ?
AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
GROUP BY users.name, platform, users.creation_ts
) u GROUP BY platform
"""
results = {}
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
for row in txn:
if row[0] == "unknown":
pass
results[row[0]] = row[1]
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT users.name, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
INNER JOIN (
SELECT
user_id,
last_seen
FROM user_ips
) uip
ON users.name = uip.user_id
AND appservice_id is NULL
AND users.creation_ts < ?
AND uip.last_seen/1000 > ?
AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
GROUP BY users.name, users.creation_ts
) u
"""
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
(count,) = txn.fetchone()
results["all"] = count
return results
return self.db.runInteraction("count_r30_users", _count_r30_users)
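To make the three criteria in the docstring concrete, a small sketch of the per-user predicate that the SQL above encodes (timestamps are in seconds; the function and values are illustrative, not part of this change):

THIRTY_DAYS_IN_SECS = 86400 * 30

def is_r30_retained(creation_ts, last_seen_secs, now_secs):
    thirty_days_ago = now_secs - THIRTY_DAYS_IN_SECS
    return (
        creation_ts < thirty_days_ago          # account created more than 30 days ago
        and last_seen_secs > thirty_days_ago   # active within the last 30 days
        and last_seen_secs - creation_ts > THIRTY_DAYS_IN_SECS  # more than 30 days apart
    )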
def _get_start_of_day(self):
"""
Returns millisecond unixtime for start of UTC day.
"""
now = time.gmtime()
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
def generate_user_daily_visits(self):
"""
Generates daily visit data for use in cohort/retention analysis
"""
def _generate_user_daily_visits(txn):
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
now = self.clock.time_msec()
sql = """
INSERT INTO user_daily_visits (user_id, device_id, timestamp)
SELECT u.user_id, u.device_id, ?
FROM user_ips AS u
LEFT JOIN (
SELECT user_id, device_id, timestamp FROM user_daily_visits
WHERE timestamp = ?
) udv
ON u.user_id = udv.user_id AND u.device_id=udv.device_id
INNER JOIN users ON users.name=u.user_id
WHERE last_seen > ? AND last_seen <= ?
AND udv.timestamp IS NULL AND users.is_guest=0
AND users.appservice_id IS NULL
GROUP BY u.user_id, u.device_id
"""
# This means that the day has rolled over but there could still
# be entries from the previous day. There is an edge case
# where if the user logs in at 23:59 and overwrites their
# last_seen at 00:01 then they will not be counted in the
# previous day's stats - it is important that the query is run
# often to minimise this case.
if today_start > self._last_user_visit_update:
yesterday_start = today_start - a_day_in_milliseconds
txn.execute(
sql,
(
yesterday_start,
yesterday_start,
self._last_user_visit_update,
today_start,
),
)
self._last_user_visit_update = today_start
txn.execute(
sql, (today_start, today_start, self._last_user_visit_update, now)
)
# Update _last_user_visit_update to now. The reason to do this
# rather than just clamping to the beginning of the day is to limit
# the size of the join - meaning that the query can be run more
# frequently
self._last_user_visit_update = now
return self.db.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
def get_users(self):
"""Function to retrieve a list of users in the users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.db.simple_select_list(
table="users",
keyvalues={},
retcols=[
"name",
"password_hash",
"is_guest",
"admin",
"user_type",
"deactivated",
],
desc="get_users",
)
def get_users_paginate(
self, start, limit, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from the
users table. This will return a JSON list of users and the
total number of users matching the filter criteria.
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
name (string): filter for user names
guests (bool): whether to include guest users
deactivated (bool): whether to include deactivated users
Returns:
defer.Deferred: resolves to list[dict[str, Any]], int
"""
def get_users_paginate_txn(txn):
filters = []
args = []
if name:
filters.append("name LIKE ?")
args.append("%" + name + "%")
if not guests:
filters.append("is_guest = 0")
if not deactivated:
filters.append("deactivated = 0")
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
txn.execute(sql, args)
count = txn.fetchone()[0]
args = [self.hs.config.server_name] + args + [limit, start]
sql = """
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
txn.execute(sql, args)
users = self.db.cursor_to_dict(txn)
return users, count
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
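A hedged usage sketch for the method above, assuming `store` is the main data store and the caller is an inlineCallbacks function (argument values are illustrative):

users, total = yield store.get_users_paginate(
    start=0, limit=10, name="ali", guests=False, deactivated=False
)
# `users` is a list of dicts (name, user_type, is_guest, admin, deactivated,
# displayname, avatar_url); `total` is the number of users matching the filters.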
def search_users(self, term):
"""Function to search the users table for one or more users
matching the given term.
Args:
term (str): search term
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.db.simple_search_list(
table="users",
term=term,
col="name",
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="search_users",
)
def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
"""Called before upgrading an existing database to check that it is broadly sane
compared with the configuration.
"""
domain = config.server_name
sql = database_engine.convert_param_style(
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
)
pat = "%:" + domain
cur.execute(sql, (pat,))
num_not_matching = cur.fetchall()[0][0]
if num_not_matching == 0:
return
raise Exception(
"Found users in database not native to %s!\n"
"You cannot change a synapse server_name after it has been configured"
% (domain,)
)
__all__ = ["DataStore", "check_database_before_upgrade"]


@ -16,12 +16,14 @@
import abc
import logging
from typing import List, Tuple
from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -38,13 +40,13 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
@ -67,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
rows = self._simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@ -78,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
rows = self._simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@ -92,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
return self.runInteraction(
return self.db.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@ -102,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
Deferred: A dict
"""
result = yield self._simple_select_one_onecol(
result = yield self.db.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@ -127,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
rows = self._simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@ -138,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
return self.runInteraction(
return self.db.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@ -156,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
content_json = self._simple_select_one_onecol_txn(
content_json = self.db.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@ -170,45 +172,68 @@ class AccountDataWorkerStore(SQLBaseStore):
return json.loads(content_json) if content_json else None
return self.runInteraction(
return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
def get_all_updated_account_data(
self, last_global_id, last_room_id, current_id, limit
):
"""Get all the client account_data that has changed on the server
Args:
last_global_id(int): The position to fetch from for top level data
last_room_id(int): The position to fetch from for per room data
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
room_id string, type string, and content string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
async def get_updated_global_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get the global account_data that has changed, for the account_data stream
def get_updated_account_data_txn(txn):
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns:
A list of tuples of stream_id int, user_id string,
and type string.
"""
if last_id == current_id:
return []
def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type, content"
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_global_id, current_id, limit))
global_results = txn.fetchall()
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return await self.db.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
)
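A hedged sketch of the consumer side, assuming `store` is the main data store and the two stream tokens come from the replication machinery (names are illustrative):

rows = await store.get_updated_global_account_data(
    last_id=from_token, current_id=to_token, limit=100
)
for stream_id, user_id, account_data_type in rows:
    # Each row is (stream_id, user_id, account_data_type), per the docstring above.
    handle_account_data_change(stream_id, user_id, account_data_type)  # hypothetical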
async def get_updated_room_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get the room account_data that has changed, for the account_data stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns:
A list of tuples of stream_id int, user_id string,
room_id string and type string.
"""
if last_id == current_id:
return []
def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type, content"
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
return global_results, room_results
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
return await self.db.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
@ -250,9 +275,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
return {}, {}
return defer.succeed(({}, {}))
return self.runInteraction(
return self.db.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@ -270,12 +295,18 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
db_conn,
"account_data_max_stream_id",
"stream_id",
extra_tables=[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
)
super(AccountDataStore, self).__init__(db_conn, hs)
super(AccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@ -300,9 +331,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so _simple_upsert will
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
yield self._simple_upsert(
yield self.db.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@ -346,9 +377,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so _simple_upsert will retry if
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
yield self._simple_upsert(
yield self.db.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@ -362,6 +393,10 @@ class AccountDataStore(AccountDataWorkerStore):
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
#
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
@ -380,6 +415,10 @@ class AccountDataStore(AccountDataWorkerStore):
next_id(int): The revision to advance to.
"""
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
@ -388,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.runInteraction("update_account_data_max_stream_id", _update)
return self.db.runInteraction("update_account_data_max_stream_id", _update)

View file

@ -22,20 +22,20 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage.events_worker import EventsWorkerStore
from ._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
# We precompie a regex constructed from all the regexes that the AS's
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
regex.pattern
for service in services_cache
for regex in service.get_exlusive_user_regexes()
for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
@ -49,13 +49,13 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
@ -134,8 +134,8 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
results = yield self._simple_select_list(
"application_services_state", dict(state=state), ["as_id"]
results = yield self.db.simple_select_list(
"application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@ -156,9 +156,9 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
result = yield self._simple_select_one(
result = yield self.db.simple_select_one(
"application_services_state",
dict(as_id=service.id),
{"as_id": service.id},
["state"],
allow_none=True,
desc="get_appservice_state",
@ -176,8 +176,8 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
return self._simple_upsert(
"application_services_state", dict(as_id=service.id), dict(state=state)
return self.db.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
def create_appservice_txn(self, service, events):
@ -217,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.runInteraction("create_appservice_txn", _create_appservice_txn)
return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@ -250,19 +250,23 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
"application_services_state",
dict(as_id=service.id),
dict(last_txn=txn_id),
{"as_id": service.id},
{"last_txn": txn_id},
)
# Delete txn
self._simple_delete_txn(
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
self.db.simple_delete_txn(
txn,
"application_services_txns",
{"txn_id": txn_id, "as_id": service.id},
)
return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
return self.db.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@ -284,7 +288,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return None
@ -292,7 +296,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry
entry = yield self.runInteraction(
entry = yield self.db.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
@ -322,7 +326,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
return self.db.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@ -351,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
upper_bound, event_ids = yield self.db.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)

View file

@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
# This is a special cache name we use to batch multiple invalidations of caches
# based on the current state when notifying workers over replication.
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
):
"""Fetches cache invalidation rows between the two given IDs written
by the given instance. Returns at most `limit` rows.
"""
if last_id == current_id:
return []
def get_all_updated_caches_txn(txn):
# We purposefully don't bound the query by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = """
SELECT stream_id, cache_func, keys, invalidation_ts
FROM cache_invalidation_stream_by_instance
WHERE stream_id > ? AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
for row in rows:
self._process_event_stream_row(token, row)
elif stream_name == "backfill":
for row in rows:
self._invalidate_caches_for_event(
-token,
row.event_id,
row.room_id,
row.type,
row.state_key,
row.redacts,
row.relates_to,
backfilled=True,
)
elif stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row):
data = row.data
if row.type == EventsStreamEventRow.TypeId:
self._invalidate_caches_for_event(
token,
data.event_id,
data.room_id,
data.type,
data.state_key,
data.redacts,
data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
def _invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
self._invalidate_get_event_cache(redacts)
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
cache_func = getattr(self, cache_name, None)
if not cache_func:
return
cache_func.invalidate(keys)
await self.db.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
keys,
)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(self, txn, cache_func):
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
txn.call_after(cache_func.invalidate_all)
self._send_invalidation_to_replication(txn, cache_func.__name__, None)
def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
"""Special case invalidation of caches based on current state.
We special case this so that we can batch the cache invalidations into a
single replication poke.
Args:
txn
room_id (str): Room where state changed
members_changed (iterable[str]): The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
if members_changed:
# We need to be careful that the size of the `members_changed` list
# isn't so large that it causes problems sending over replication, so we
# send them in chunks.
# Max line length is 16K, and max user ID length is 255, so 50 should
# be safe.
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
self._send_invalidation_to_replication(
txn, CURRENT_STATE_CACHE_NAME, keys
)
else:
# if no members changed, we still need to invalidate the other caches.
self._send_invalidation_to_replication(
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)
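The chunking above keeps any single replication line comfortably under the 16K limit. A standalone sketch of the same idea, where batch_iter is a local stand-in for synapse.util.iterutils.batch_iter and all identifiers are made up:

import itertools
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Yield successive tuples of at most `size` items.
    it = iter(iterable)
    return iter(lambda: tuple(itertools.islice(it, size)), ())

room_id = "!room:example.org"
members_changed = ["@user%d:example.org" % i for i in range(120)]

for chunk in batch_iter(members_changed, 50):
    # Each batch is prefixed with the room ID, mirroring the keys built above.
    keys = list(itertools.chain([room_id], chunk))
    print(len(keys))  # 51, 51, 21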
def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
):
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
Args:
txn
cache_name
keys: Entry to invalidate. If None, will invalidate all.
"""
if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
raise Exception(
"Can't stream invalidate all with magic current state cache"
)
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
keys = list(keys)
self.db.simple_insert_txn(
txn,
table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token(instance_name)
else:
return 0

View file

@ -0,0 +1,208 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.internet import defer
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
from synapse.storage.data_stores.main.events import encode_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
def __init__(self, database: Database, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
def _censor_redactions():
return run_as_background_process(
"_censor_redactions", self._censor_redactions
)
if self.hs.config.redaction_retention_period is not None:
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
async def _censor_redactions(self):
"""Censors all redactions older than the configured period that haven't
been censored yet.
By censor we mean update the event_json table with the redacted event.
"""
if self.hs.config.redaction_retention_period is None:
return
if not (
await self.db.updates.has_completed_background_update(
"redactions_have_censored_ts_idx"
)
):
# We don't want to run this until the appropriate index has been
# created.
return
before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
# We fetch all redactions that:
# 1. point to an event we have,
# 2. have a received_ts from before the cut off, and
# 3. we haven't yet censored.
#
# This is limited to 100 events to ensure that we don't try and do too
# much at once. We'll get called again so this should eventually catch
# up.
sql = """
SELECT redactions.event_id, redacts FROM redactions
LEFT JOIN events AS original_event ON (
redacts = original_event.event_id
)
WHERE NOT have_censored
AND redactions.received_ts <= ?
ORDER BY redactions.received_ts ASC
LIMIT ?
"""
rows = await self.db.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
)
updates = []
for redaction_id, event_id in rows:
redaction_event = await self.get_event(redaction_id, allow_none=True)
original_event = await self.get_event(
event_id, allow_rejected=True, allow_none=True
)
# The SQL above ensures that we have both the redaction and
# original event, so if the `get_event` calls return None it
# means that the redaction wasn't allowed. Either way we know that
# the result won't change so we mark the fact that we've checked.
if (
redaction_event
and original_event
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = encode_json(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
)
else:
# Redaction wasn't allowed
pruned_json = None
updates.append((redaction_id, event_id, pruned_json))
def _update_censor_txn(txn):
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
self.db.simple_update_one_txn(
txn,
table="redactions",
keyvalues={"event_id": redaction_id},
updatevalues={"have_censored": True},
)
await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON.
Args:
txn (LoggingTransaction): The database transaction.
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
"""
self.db.simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
updatevalues={"json": pruned_json},
)
@defer.inlineCallbacks
def expire_event(self, event_id):
"""Retrieve and expire an event that has expired, and delete its associated
expiry timestamp. If the event can't be retrieved, delete its associated
timestamp so we don't try to expire it again in the future.
Args:
event_id (str): The ID of the event to delete.
"""
# Try to retrieve the event's content from the database or the event cache.
event = yield self.get_event(event_id)
def delete_expired_event_txn(txn):
# Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id)
if not event:
# If we can't find the event, log a warning and delete the expiry date
# from the database so that we don't try to expire it again in the
# future.
logger.warning(
"Can't expire event %s because we don't have it.", event_id
)
return
# Prune the event's dict then convert it to JSON.
pruned_json = encode_json(
prune_event_dict(event.room_version, event.get_dict())
)
# Update the event_json table to replace the event's JSON with the pruned
# JSON.
self._censor_event_txn(txn, event.event_id, pruned_json)
# We need to invalidate the event cache entry for this event because we
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
txn, "_get_event_cache", (event.event_id,)
)
yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the
actual event.
Args:
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
"""
return self.db.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)

View file

@ -19,11 +19,10 @@ from six import iteritems
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR
from . import background_updates
from ._base import Cache
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@ -33,46 +32,41 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only",
table="user_ips",
columns=["last_seen"],
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
@ -81,18 +75,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
# Drop the old non-unique index
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
# Update the last seen info in devices.
self.db.updates.register_background_update_handler(
"devices_last_seen", self._devices_last_seen_update
)
@defer.inlineCallbacks
@ -102,8 +91,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update("user_ips_drop_nonunique_index")
yield self.db.runWithConnection(f)
yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
return 1
@defer.inlineCallbacks
@ -117,9 +106,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
yield self.runInteraction("user_ips_analyze", user_ips_analyze)
yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
yield self._end_background_update("user_ips_analyze")
yield self.db.updates._end_background_update("user_ips_analyze")
return 1
@ -151,7 +140,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.runInteraction(
end_last_seen = yield self.db.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@ -282,17 +271,115 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
yield self.runInteraction("user_ips_dups_remove", remove)
yield self.db.runInteraction("user_ips_dups_remove", remove)
if last:
yield self._end_background_update("user_ips_remove_dupes")
yield self.db.updates._end_background_update("user_ips_remove_dupes")
return batch_size
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
where_clause, where_args = make_tuple_comparison_clause(
self.database_engine,
[("user_id", last_user_id), ("device_id", last_device_id)],
)
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self.db.updates._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.db.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self.db.updates._end_background_update("devices_last_seen")
return updated
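The (user_id, device_id) cursor above is a keyset-pagination pattern. This sketch shows the portable expanded form such a clause can take; the real make_tuple_comparison_clause may instead emit a native "(a, b) > (?, ?)" comparison on engines that support it, so treat this as illustrative:

from typing import List, Tuple

def tuple_comparison_clause(keys: List[Tuple[str, str]]) -> Tuple[str, list]:
    # keys: [(column, last_value), ...] ordered most- to least-significant.
    clauses = []
    args: list = []
    for i, (col, _) in enumerate(keys):
        eq_cols = keys[:i]
        parts = ["%s = ?" % c for c, _ in eq_cols] + ["%s > ?" % col]
        clauses.append("(" + " AND ".join(parts) + ")")
        args.extend([v for _, v in eq_cols] + [keys[i][1]])
    return "(" + " OR ".join(clauses) + ")", args

clause, args = tuple_comparison_clause([("user_id", "@a:hs"), ("device_id", "DEV1")])
print(clause)  # ((user_id > ?) OR (user_id = ? AND device_id > ?))
print(args)    # ['@a:hs', '@a:hs', 'DEV1']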
class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
super(ClientIpStore, self).__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
@ -314,23 +401,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
if not self.db.is_running():
return
def update():
to_update = self._batch_row_update
self._batch_row_update = {}
return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
to_update = self._batch_row_update
self._batch_row_update = {}
return run_as_background_process("update_client_ips", update)
return self.db.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or (
if "user_ips" in self.db._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
@ -339,7 +425,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
self._simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@ -354,6 +440,23 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
},
lock=False,
)
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
)
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
@ -372,19 +475,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
keys giving the column names
"""
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
user_id,
device_id,
retcols=(
"user_id",
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
),
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
res = yield self.db.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
@ -403,42 +501,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
}
return ret
@classmethod
def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
if device_id is None:
where_clauses.append("user_id = ?")
bindings.extend((user_id,))
else:
where_clauses.append("(user_id = ? AND device_id = ?)")
bindings.extend((user_id, device_id))
if not where_clauses:
return []
inner_select = (
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
"WHERE %(where)s "
"GROUP BY user_id, device_id"
) % {"where": " OR ".join(where_clauses)}
sql = (
"SELECT %(retcols)s FROM user_ips "
"JOIN (%(inner_select)s) ips ON"
" user_ips.last_seen = ips.mls AND"
" user_ips.user_id = ips.user_id AND"
" (user_ips.device_id = ips.device_id OR"
" (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
" )"
) % {
"retcols": ",".join("user_ips." + c for c in retcols),
"inner_select": inner_select,
}
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
user_id = user.to_string()
@ -450,7 +512,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
rows = yield self._simple_select_list(
rows = yield self.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@ -461,7 +523,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
return list(
return [
{
"access_token": access_token,
"ip": ip,
@ -469,4 +531,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
]
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period.
"""
if self.user_ips_max_age is None:
# Nothing to do
return
if not await self.db.updates.has_completed_background_update(
"devices_last_seen"
):
# Only start pruning if we have finished populating the devices
# last seen info.
return
# We do a slightly funky SQL delete to ensure we don't try and delete
# too much at once (as the table may be very large from before we
# started pruning).
#
# This works by finding the max last_seen that is less than the given
# time, but has no more than N rows before it, deleting all rows with
# a lesser last_seen time. (We COALESCE so that the sub-SELECT always
# returns exactly one row).
sql = """
DELETE FROM user_ips
WHERE last_seen <= (
SELECT COALESCE(MAX(last_seen), -1)
FROM (
SELECT last_seen FROM user_ips
WHERE last_seen <= ?
ORDER BY last_seen ASC
LIMIT 5000
) AS u
)
"""
timestamp = self.clock.time_msec() - self.user_ips_max_age
def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))
await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)

View file

@ -20,8 +20,8 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.runInteraction(
return self.db.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
count = yield self.db.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.runInteraction(
return self.db.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@ -203,28 +203,92 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
return self.db.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
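A quick illustration of the bound computed above, with made-up stream positions:

last_pos, current_pos, limit = 100, 500, 50
upper_pos = min(current_pos, last_pos + limit)  # 150
# The queries then select rows with 100 < stream_id <= 150, so every row for a
# given stream_id in that window is returned together, even when one stream_id
# has many rows -- which a plain "LIMIT ?" could split across batches.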
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.db.runWithConnection(reindex_txn)
yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: Database, db_conn, hs):
super(DeviceInboxStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
@ -274,7 +338,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
yield self.db.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@ -294,7 +358,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
already_inserted = self._simple_select_one_txn(
already_inserted = self.db.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@ -306,7 +370,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Add an entry for this message_id so that we know we've processed
# it.
self._simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@ -324,7 +388,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
yield self.db.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@ -347,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
@ -358,15 +422,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
else:
if not devices:
continue
sql = (
"SELECT device_id FROM devices"
" WHERE user_id = ? AND device_id IN ("
+ ",".join("?" * len(devices))
+ ")"
clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", devices
)
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
txn.execute(sql, [user_id] + list(args))
for row in txn:
# Only insert into the local inbox if the device exists on
# this server
@ -391,60 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
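The make_in_list_sql_clause call above replaces the hand-built "IN (?,?,...)" string it removes. A sketch of the simple placeholder-list form such a helper can produce (on PostgreSQL the real helper may emit "column = ANY(?)" with an array argument instead); the identifiers are illustrative:

def in_list_sql_clause(column, values):
    values = list(values)
    placeholders = ",".join("?" for _ in values)
    return "%s IN (%s)" % (column, placeholders), values

clause, args = in_list_sql_clause("device_id", ["DEV1", "DEV2", "DEV3"])
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
print(sql)   # ... WHERE user_id = ? AND device_id IN (?,?,?)
print(args)  # ['DEV1', 'DEV2', 'DEV3']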

View file

@ -14,14 +14,14 @@
# limitations under the License.
from collections import namedtuple
from typing import Optional
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
@ -37,7 +37,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
"""
room_id = yield self._simple_select_one_onecol(
room_id = yield self.db.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@ -48,7 +48,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
servers = yield self._simple_select_onecol(
servers = yield self.db.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@ -61,7 +61,7 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@ -70,7 +70,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
return self.db.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@ -94,7 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
"""
def alias_txn(txn):
self._simple_insert_txn(
self.db.simple_insert_txn(
txn,
"room_aliases",
{
@ -104,7 +104,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
self._simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@ -118,7 +118,9 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
ret = yield self.runInteraction("create_room_alias_association", alias_txn)
ret = yield self.db.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
@ -127,7 +129,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
room_id = yield self.db.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
@ -158,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
):
"""Repoint all of the aliases for a given room, to a different room.
Args:
old_room_id:
new_room_id:
creator: The user to record as the creator of the new mapping.
If None, the creator will be left unchanged.
"""
def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
txn.execute(sql, (new_room_id, creator, old_room_id))
update_creator_sql = ""
sql_params = (new_room_id, old_room_id)
if creator:
update_creator_sql = ", creator = ?"
sql_params = (new_room_id, creator, old_room_id)
sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
update_creator_sql,
)
txn.execute(sql, sql_params)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,)
)
@ -169,6 +190,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
return self.runInteraction(
return self.db.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
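For reference, the two UPDATE statements the branch above can produce, with made-up identifiers:

new_room_id, old_room_id, creator = "!new:example.org", "!old:example.org", "@admin:example.org"

# creator omitted: the creator column is left unchanged.
sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % ("",)
params = (new_room_id, old_room_id)

# creator supplied: the creator column is repointed as well.
sql_with_creator = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
    ", creator = ?",
)
params_with_creator = (new_room_id, creator, old_room_id)

print(sql)               # UPDATE room_aliases SET room_id = ?  WHERE room_id = ?
print(sql_with_creator)  # UPDATE room_aliases SET room_id = ? , creator = ? WHERE room_id = ?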

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,55 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from ._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_room_key(self, user_id, version, room_id, session_id):
"""Get the encrypted E2E room key for a given session from a given
backup version of room_keys. We only store the 'best' room key for a given
session at a given time, as determined by the handler.
Args:
user_id(str): the user whose backup we're querying
version(str): the version ID of the backup for the set of keys we're querying
room_id(str): the ID of the room whose keys we're querying.
This is a bit redundant as it's implied by the session_id, but
we include it for consistency with the rest of the API.
session_id(str): the session whose room_key we're querying.
Returns:
A deferred dict giving the session_data and message metadata for
this room key.
"""
row = yield self._simple_select_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
"version": version,
"room_id": room_id,
"session_id": session_id,
},
retcols=(
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
desc="get_e2e_room_key",
)
row["session_data"] = json.loads(row["session_data"])
return row
@defer.inlineCallbacks
def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
"""Replaces or inserts the encrypted E2E room key for a given session in
a given backup
def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
user_id(str): the user whose backup we're setting
@ -79,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
yield self._simple_upsert(
yield self.db.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@ -87,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"room_id": room_id,
"session_id": session_id,
},
values={
updatevalues={
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
"session_data": json.dumps(room_key["session_data"]),
},
lock=False,
desc="update_e2e_room_key",
)
log_kv(
{
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
"room_key": room_key,
}
@defer.inlineCallbacks
def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.
Args:
user_id (str): the user whose backup we're adding to
version (str): the version ID of the backup for the set of keys we're adding to
room_keys (iterable[(str, str, dict)]): the keys to add, in the form
(roomID, sessionID, keyData)
"""
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
{
"user_id": user_id,
"version": version,
"room_id": room_id,
"session_id": session_id,
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
"session_data": json.dumps(room_key["session_data"]),
}
)
log_kv(
{
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
"room_key": room_key,
}
)
yield self.db.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
@ -111,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room, or a given session.
Args:
user_id(str): the user whose backup we're querying
version(str): the version ID of the backup for the set of keys we're querying
room_id(str): Optional. the ID of the room whose keys we're querying, if any.
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup for the set of keys we're querying
room_id (str): Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
session_id(str): Optional. the session whose room_key we're querying, if any.
session_id (str): Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
@ -136,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
rows = yield self._simple_select_list(
rows = yield self.db.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@ -157,12 +146,102 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
"is_verified": row["is_verified"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": json.loads(row["session_data"]),
}
return sessions
def get_e2e_room_keys_multi(self, user_id, version, room_keys):
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
getting all the keys in a backup version, all the keys for a room, or a
specific key.
Args:
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup we're querying about
room_keys (dict[str, dict[str, iterable[str]]]): a map from
room ID -> {"session": [session ids]} indicating the session IDs
that we want to query
Returns:
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
"""
return self.db.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
version,
room_keys,
)
@staticmethod
def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
if not room_keys:
return {}
where_clauses = []
params = [user_id, version]
for room_id, room in room_keys.items():
sessions = list(room["sessions"])
if not sessions:
continue
params.append(room_id)
params.extend(sessions)
where_clauses.append(
"(room_id = ? AND session_id IN (%s))"
% (",".join(["?" for _ in sessions]),)
)
# check if we're actually querying something
if not where_clauses:
return {}
sql = """
SELECT room_id, session_id, first_message_index, forwarded_count,
is_verified, session_data
FROM e2e_room_keys
WHERE user_id = ? AND version = ? AND (%s)
""" % (
" OR ".join(where_clauses)
)
txn.execute(sql, params)
ret = {}
for row in txn:
room_id = row[0]
session_id = row[1]
ret.setdefault(room_id, {})
ret[room_id][session_id] = {
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
"session_data": json.loads(row[5]),
}
return ret
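A worked example of the clause assembly above for a small, made-up room_keys mapping:

room_keys = {
    "!roomA:example.org": {"sessions": ["s1", "s2"]},
    "!roomB:example.org": {"sessions": ["s3"]},
}

where_clauses = []
params = ["@user:example.org", "1"]  # user_id, version
for room_id, room in room_keys.items():
    sessions = list(room["sessions"])
    if not sessions:
        continue
    params.append(room_id)
    params.extend(sessions)
    where_clauses.append(
        "(room_id = ? AND session_id IN (%s))" % (",".join("?" for _ in sessions),)
    )

sql = """
    SELECT room_id, session_id, first_message_index, forwarded_count,
        is_verified, session_data
    FROM e2e_room_keys
    WHERE user_id = ? AND version = ? AND (%s)
""" % (
    " OR ".join(where_clauses),
)
print(sql)     # ... AND ((room_id = ? AND session_id IN (?,?)) OR (room_id = ? AND session_id IN (?)))
print(params)  # ['@user:example.org', '1', '!roomA:example.org', 's1', 's2', '!roomB:example.org', 's3']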
def count_e2e_room_keys(self, user_id, version):
"""Get the number of keys in a backup version.
Args:
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup we're querying about
"""
return self.db.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
@trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
@ -189,7 +268,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
yield self._simple_delete(
yield self.db.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@ -220,6 +299,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version(str)
algorithm(str)
auth_data(object): opaque dict supplied by the client
etag(int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn):
@ -233,17 +313,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
result = self._simple_select_one_txn(
result = self.db.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data"),
retcols=("version", "algorithm", "auth_data", "etag"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return self.runInteraction(
return self.db.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@ -271,7 +353,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1)
self._simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
@ -284,26 +366,38 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
return self.runInteraction(
return self.db.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@trace
def update_e2e_room_keys_version(self, user_id, version, info):
def update_e2e_room_keys_version(
self, user_id, version, info=None, version_etag=None
):
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
info(dict): the new backup version info to store
info (dict): the new backup version info to store. If None, then
the backup version info is not updated
version_etag (Optional[int]): etag of the keys in the backup. If
None, then the etag is not updated
"""
updatevalues = {}
return self._simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues={"auth_data": json.dumps(info["auth_data"])},
desc="update_e2e_room_keys_version",
)
if info is not None and "auth_data" in info:
updatevalues["auth_data"] = json.dumps(info["auth_data"])
if version_etag is not None:
updatevalues["etag"] = version_etag
if updatevalues:
return self.db.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
@trace
def delete_e2e_room_keys_version(self, user_id, version=None):
@ -322,16 +416,24 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _delete_e2e_room_keys_version_txn(txn):
if version is None:
this_version = self._get_current_version(txn, user_id)
if this_version is None:
raise StoreError(404, "No current backup version")
else:
this_version = version
return self._simple_update_one_txn(
self.db.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
return self.db.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
return self.runInteraction(
return self.db.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)

View file

@ -0,0 +1,721 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from six import iteritems
from canonicaljson import encode_canonical_json, json
from twisted.enterprise.adbapi import Connection
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
set_tag("query_list", query_list)
if not query_list:
return {}
results = yield self.db.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
for user_id, device_keys in iteritems(results):
rv[user_id] = {}
for device_id, device_info in iteritems(device_keys):
r = db_to_json(device_info.pop("key_json"))
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info:
for sig_user_id, sigs in device_info["signatures"].items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
rv[user_id][device_id] = r
return rv
@trace
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
):
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
query_clauses = []
query_params = []
signature_query_clauses = []
signature_query_params = []
if include_all_devices is False:
include_deleted_devices = False
if include_deleted_devices:
deleted_devices = set(query_list)
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)
if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)
signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)
query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)
sql = (
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s AND NOT d.hidden"
) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses),
)
txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)
result = {}
for row in rows:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
# get signatures on the device
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)
txn.execute(signature_sql, signature_query_params)
rows = self.db.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]
target_user_result = result.get(target_user_id)
if not target_user_result:
continue
target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
target_device_signatures = target_device_result.setdefault("signatures", {})
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
log_kv(result)
return result
@defer.inlineCallbacks
def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
"""Retrieve a number of one-time keys for a user
Args:
user_id(str): id of user to get keys for
device_id(str): id of device to get keys for
key_ids(list[str]): list of key ids (excluding algorithm) to
retrieve
Returns:
deferred resolving to Dict[(str, str), str]: map from (algorithm,
key_id) to json string for key
"""
rows = yield self.db.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
)
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
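# Sketch of the mapping produced above, with hypothetical values: keys are
# (algorithm, key_id) tuples and values are the raw JSON strings stored in
# e2e_one_time_keys_json.
example_one_time_keys = {
    ("signed_curve25519", "AAAAHQ"): '{"key": "base64+public+key"}',
}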
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
"""Insert some new one time keys for a device. Errors if any of the
keys already exist.
Args:
user_id(str): id of user to get keys for
device_id(str): id of device to get keys for
time_now(long): insertion time to record (ms since epoch)
new_keys(iterable[(str, str, str)]): keys to add - each a tuple of
(algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.db.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
values=[
{
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
"ts_added_ms": time_now,
"key_json": json_bytes,
}
for algorithm, key_id, json_bytes in new_keys
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
yield self.db.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
" GROUP BY algorithm"
)
txn.execute(sql, (user_id, device_id))
result = {}
for algorithm, key_count in txn:
result[algorithm] = key_count
return result
return self.db.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@defer.inlineCallbacks
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
Args:
user_id (str): the user whose key is being requested
key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the self-signing key will be included in the result
Returns:
dict of the key data or None if not found
"""
res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
user_keys = res.get(user_id)
if not user_keys:
return None
return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
raise NotImplementedError()
@cachedList(
cached_method_name="_get_bare_e2e_cross_signing_keys",
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
user_ids (list[str]): the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
"""
return self.db.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self, txn: Connection, user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_ids (list[str]): the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, their user
ID will not be in the dict.
"""
result = {}
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
txn.database_engine, "k.user_id", user_chunk
)
sql = (
"""
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
WHERE
"""
+ clause
)
txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = json.loads(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
return result
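# Standalone sketch (assumed minimal schema, hypothetical data) of why the
# MAX(stream_id) GROUP BY join above returns only the most recently uploaded
# key per (user_id, keytype); runnable with the standard library only.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE e2e_cross_signing_keys "
    "(user_id TEXT, keytype TEXT, keydata TEXT, stream_id INTEGER)"
)
conn.executemany(
    "INSERT INTO e2e_cross_signing_keys VALUES (?, ?, ?, ?)",
    [
        ("@alice:example.com", "master", '{"v": 1}', 1),
        ("@alice:example.com", "master", '{"v": 2}', 5),  # newer upload wins
        ("@alice:example.com", "self_signing", '{"v": 1}', 2),
    ],
)
latest = conn.execute(
    """
    SELECT user_id, keytype, k.keydata
    FROM e2e_cross_signing_keys k
    INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                FROM e2e_cross_signing_keys
                GROUP BY user_id, keytype) s
    USING (user_id, stream_id, keytype)
    """
).fetchall()
# latest has one row per (user_id, keytype); the master key row carries
# '{"v": 2}' because stream_id 5 is the latest upload.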
def _get_e2e_cross_signing_signatures_txn(
self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
key data. This dict will be modified to add signatures.
from_user_id (str): fetch the signatures made by this user
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. The return value will be the same as the keys argument,
with the modifications included.
"""
# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
devices = {}
for user_id, user_info in keys.items():
if user_info is None:
continue
for key_type, key in user_info.items():
device_id = None
for k in key["keys"].values():
device_id = k
devices[(user_id, device_id)] = key_type
for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
WHERE user_id = ?
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
user_info = keys[target_user_id] = keys[target_user_id].copy()
target_user_key = user_info[key_type] = user_info[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"]
else:
signatures[from_user_id] = {key_id: row["signature"]}
else:
target_user_key["signatures"] = {
from_user_id: {key_id: row["signature"]}
}
return keys
@defer.inlineCallbacks
def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: str = None
) -> defer.Deferred:
"""Returns the cross-signing keys for a set of users.
Args:
user_ids (list[str]): the users whose keys are being requested
from_user_id (str): if specified, signatures made by this user on
the self-signing keys will be included in the result
Returns:
Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
key data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
"""
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
result = yield self.db.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
)
return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
from_key (int): the stream ID to start at (exclusive)
to_key (int): the stream ID to end at (inclusive)
Returns:
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
return self.db.execute(
"get_all_user_signature_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
old_key_json = self.db.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
@trace
def _claim_e2e_one_time_keys(txn):
sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
result = {}
delete = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" AND key_id = ?"
)
for user_id, device_id, algorithm, key_id in delete:
log_kv(
{
"message": "Executing claim e2e_one_time_keys transaction on database."
}
)
txn.execute(sql, (user_id, device_id, algorithm, key_id))
log_kv({"message": "finished executing and invalidating cache"})
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return result
return self.db.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
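# Sketch of the result shape built above, with hypothetical IDs: each claimed
# key is keyed by "<algorithm>:<key_id>" and maps to the stored JSON string.
example_claimed_keys = {
    "@alice:example.com": {
        "ALICEDEVICE": {"signed_curve25519:AAAAHQ": '{"key": "base64+key"}'},
    }
}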
def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
"message": "Deleting keys for device",
"device_id": device_id,
"user_id": user_id,
}
)
self.db.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self.db.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return self.db.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user to set the signing key for
key_type (str): the type of key that is being set: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
"""
# the 'key' dict will look something like:
# {
# "user_id": "@alice:example.com",
# "usage": ["self_signing"],
# "keys": {
# "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
# },
# "signatures": {
# "@alice:example.com": {
# "ed25519:base64+master+public+key": "base64+signature"
# }
# }
# }
# The "keys" property must only have one entry, which will be the public
# key, so we just grab the first value in there
pubkey = next(iter(key["keys"].values()))
# The cross-signing keys need to occupy the same namespace as devices,
# since signatures are identified by device ID. So add an entry to the
# device table to make sure that we don't have a collision with device
# IDs.
# We only need to do this for local users, since remote servers should be
# responsible for checking this for their own users.
if self.hs.is_mine_id(user_id):
self.db.simple_insert_txn(
txn,
"devices",
values={
"user_id": user_id,
"device_id": pubkey,
"display_name": key_type + " signing key",
"hidden": True,
},
)
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
self.db.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
"user_id": user_id,
"keytype": key_type,
"keydata": json.dumps(key),
"stream_id": stream_id,
},
)
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
user_id (str): the user to set the user-signing key for
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
return self.db.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
user_id,
key_type,
key,
)
def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures.
Args:
user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add
"""
return self.db.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
"user_id": user_id,
"key_id": item.signing_key_id,
"target_user_id": item.target_user_id,
"target_device_id": item.target_device_id,
"signature": item.signature,
}
for item in signatures
],
"add_e2e_signing_key",
)

View file

@ -12,22 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import random
from typing import Dict, List, Optional, Set, Tuple
from six.moves import range
from six.moves.queue import Empty, PriorityQueue
from unpaddedbase64 import encode_base64
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@ -47,37 +47,55 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
def get_auth_chain_ids(self, event_ids, include_given=False):
def get_auth_chain_ids(
self,
event_ids: List[str],
include_given: bool = False,
ignore_events: Optional[Set[str]] = None,
):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids (list): state events
include_given (bool): include the given events in result
event_ids: state events
include_given: include the given events in result
ignore_events: Set of events to exclude from the returned auth
chain. This is useful if the caller will just discard the
given events anyway, and saves us from figuring out their auth
chains if not required.
Returns:
list of event_ids
"""
return self.runInteraction(
"get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
return self.db.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
ignore_events,
)
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
if ignore_events is None:
ignore_events = set()
if include_given:
results = set(event_ids)
else:
results = set()
base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
new_front = set()
front_list = list(front)
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks:
txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
new_front.update([r[0] for r in txn])
for chunk in batch_iter(front, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", chunk
)
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
new_front -= ignore_events
new_front -= results
front = new_front
@ -85,13 +103,170 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
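# Rough sketch of the kind of clause the batching above relies on. This is an
# assumed simplification: the real make_in_list_sql_clause helper also has an
# engine-specific path (e.g. for Postgres), which is omitted here.
def sketch_in_list_sql_clause(column, iterable):
    values = list(iterable)
    placeholders = ",".join("?" for _ in values)
    return "%s IN (%s)" % (column, placeholders), values

clause, args = sketch_in_list_sql_clause("event_id", ["$a", "$b", "$c"])
# clause == "event_id IN (?,?,?)", args == ["$a", "$b", "$c"]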
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
This is equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
Returns:
Deferred[Set[str]]
"""
return self.db.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
#
# The idea here is to basically walk the auth graph of each state set in
# tandem, keeping track of which auth events are reachable by each state
# set. If we reach an auth event we've already visited (via a different
# state set) then we mark that auth event and all ancestors as reachable
# by the state set. This requires that we keep track of the auth chains
# in memory.
#
# Doing it in such a way means that we can stop early if all auth
# events we're currently walking are reachable by all state sets.
#
# *Note*: We can't stop walking an event's auth chain if it is reachable
# by all state sets. This is because other auth chains we're walking
# might be reachable only via the original auth chain. For example,
# given the following auth chain:
#
# A -> C -> D -> E
# / /
# B -´---------´
#
# and state sets {A} and {B} then walking the auth chains of A and B
# would immediately show that C is reachable by both. However, if we
# stopped at C then we'd only reach E via the auth chain of B and so E
# would erroneously get included in the returned difference.
#
# The other thing that we do is limit the number of auth chains we walk
# at once, due to practical limits (i.e. we can only query the database
# with a limited set of parameters). We pick the auth chains we walk
# each iteration based on their depth, in the hope that events with a
# lower depth are likely reachable by those with higher depths.
#
# We could use any ordering that we believe would give a rough
# topological ordering, e.g. origin server timestamp. If the ordering
# chosen is not topological then the algorithm still produces the right
# result, but perhaps a bit more inefficiently. This is why it is safe
# to use "depth" here.
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Dict from events in auth chains to which sets *cannot* reach them.
# I.e. if the set is empty then all sets can reach the event.
event_to_missing_sets = {
event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
for event_id in initial_events
}
# The sorted list of events whose auth chains we should walk.
search = [] # type: List[Tuple[int, str]]
# We need to get the depth of the initial events for sorting purposes.
sql = """
SELECT depth, event_id FROM events
WHERE %s
"""
# the list can be huge, so let's avoid looking them all up in one massive
# query.
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
# I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed
search.extend(txn.fetchall())
# sort by depth
search.sort()
# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]
base_sql = """
SELECT a.event_id, auth_id, depth
FROM event_auth AS a
INNER JOIN events AS e ON (e.event_id = a.auth_id)
WHERE
"""
while search:
# Check whether all our current walks are reachable by all state
# sets. If so we can bail.
if all(not event_to_missing_sets[eid] for _, eid in search):
break
# Fetch the auth events, and their depths, of the last N events we're
# currently walking
search, chunk = search[:-100], search[-100:]
clause, args = make_in_list_sql_clause(
txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
)
txn.execute(base_sql + clause, args)
for event_id, auth_event_id, auth_event_depth in txn:
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
sets = event_to_missing_sets.get(auth_event_id)
if sets is None:
# First time we're seeing this event, so we add it to the
# queue of things to fetch.
search.append((auth_event_depth, auth_event_id))
# Assume that this event is unreachable from any of the
# state sets until proven otherwise
sets = event_to_missing_sets[auth_event_id] = set(
range(len(state_sets))
)
else:
# We've previously seen this event, so look up its auth
# events and recursively mark all ancestors as reachable
# by the current event's state set.
a_ids = event_to_auth_events.get(auth_event_id)
while a_ids:
new_aids = set()
for a_id in a_ids:
event_to_missing_sets[a_id].intersection_update(
event_to_missing_sets[event_id]
)
b = event_to_auth_events.get(a_id)
if b:
new_aids.update(b)
a_ids = new_aids
# Mark that the auth event is reachable by the appropriate sets.
sets.intersection_update(event_to_missing_sets[event_id])
search.sort()
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
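# A naive, in-memory sketch (illustrative only, not the optimised walk above)
# of what "auth chain difference" means, assuming a toy auth graph mapping
# each event to its auth events: compute the full auth chain of every state
# set and keep the events that are not reachable from all of them.
def naive_auth_chain_difference(auth_graph, state_sets):
    def full_chain(events):
        seen = set(events)
        stack = list(events)
        while stack:
            for auth_id in auth_graph.get(stack.pop(), ()):
                if auth_id not in seen:
                    seen.add(auth_id)
                    stack.append(auth_id)
        return seen

    chains = [full_chain(s) for s in state_sets]
    return set().union(*chains) - set.intersection(*chains)

# The example from the comment above (A -> C -> D -> E, with B -> C and B -> E):
toy_graph = {"A": {"C"}, "B": {"C", "E"}, "C": {"D"}, "D": {"E"}}
assert naive_auth_chain_difference(toy_graph, [{"A"}, {"B"}]) == {"A", "B"}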
def get_oldest_events_in_room(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@ -122,7 +297,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns
Deferred[int]
"""
rows = yield self._simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@ -136,15 +311,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
return self.db.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
@defer.inlineCallbacks
def get_prev_events_for_room(self, room_id):
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
@ -155,47 +329,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
room_id (str): room_id
Returns:
Deferred[list[(str, dict[str, str], int)]]
for each event, a tuple of (event_id, hashes, depth)
where *hashes* is a map from algorithm to hash.
"""
res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
if len(res) > 10:
# Sort by reverse depth, so we point to the most recent.
res.sort(key=lambda a: -a[2])
Deferred[List[str]]: the event ids of the forward extremities
# we use half of the limit for the actual most recent events, and
# the other half to randomly point to some of the older events, to
# make sure that we don't completely ignore the older events.
res = res[0:5] + random.sample(res[5:], 5)
return res
def get_latest_event_ids_and_hashes_in_room(self, room_id):
"""
Gets the current forward extremities in the given room
Args:
room_id (str): room_id
Returns:
Deferred[list[(str, dict[str, str], int)]]
for each event, a tuple of (event_id, hashes, depth)
where *hashes* is a map from algorithm to hash.
"""
return self.runInteraction(
"get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room,
room_id,
return self.db.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
def get_rooms_with_many_extremities(self, min_count, limit):
def _get_prev_events_for_room_txn(self, txn, room_id: str):
# we just use the 10 newest events. Older events will become
# prev_events of future events.
sql = """
SELECT e.event_id FROM event_forward_extremities AS f
INNER JOIN events AS e USING (event_id)
WHERE f.room_id = ?
ORDER BY e.depth DESC
LIMIT 10
"""
txn.execute(sql, (room_id,))
return [row[0] for row in txn]
def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
"""Get the top rooms with at least N extremities.
Args:
min_count (int): The minimum number of extremities
limit (int): The maximum number of rooms to return.
room_id_filter (iterable[str]): room_ids to exclude from the results
Returns:
Deferred[list]: At most `limit` room IDs that have at least
@ -203,60 +367,49 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
def _get_rooms_with_many_extremities_txn(txn):
where_clause = "1=1"
if room_id_filter:
where_clause = "room_id NOT IN (%s)" % (
",".join("?" for _ in room_id_filter),
)
sql = """
SELECT room_id FROM event_forward_extremities
WHERE %s
GROUP BY room_id
HAVING count(*) > ?
ORDER BY count(*) DESC
LIMIT ?
"""
""" % (
where_clause,
)
txn.execute(sql, (min_count, limit))
query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
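# Sketch of the parameter layout used above, with hypothetical values: the
# room_id_filter placeholders come first in the SQL, followed by min_count
# and limit, so the arguments are chained in the same order.
import itertools

example_args = list(itertools.chain(["!ignored:example.com"], [5, 10]))
# example_args == ["!ignored:example.com", 5, 10]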
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
return self.db.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
sql = (
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
"ON e.event_id = f.event_id "
"AND e.room_id = f.room_id "
"WHERE f.room_id = ?"
)
txn.execute(sql, (room_id,))
results = []
for event_id, depth in txn.fetchall():
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
}
results.append((event_id, prev_hashes, depth))
return results
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
return self.runInteraction(
return self.db.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
min_depth = self.db.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@ -322,7 +475,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
@ -337,7 +490,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int)
"""
return (
self.runInteraction(
self.db.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
@ -349,9 +502,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug(
"_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
)
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
event_results = set()
@ -370,7 +521,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
depth = self._simple_select_one_onecol_txn(
depth = self.db.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@ -402,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = yield self.runInteraction(
ids = yield self.db.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@ -432,7 +583,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
query, (room_id, event_id, False, limit - len(event_results))
)
new_results = set(t[0] for t in txn) - seen_events
new_results = {t[0] for t in txn} - seen_events
new_front |= new_results
seen_events |= new_results
@ -455,7 +606,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[list[str]]
"""
rows = yield self._simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@ -478,10 +629,10 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, db_conn, hs):
super(EventFederationStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(EventFederationStore, self).__init__(database, db_conn, hs)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
@ -489,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore):
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
if min_depth and depth >= min_depth:
return
self._simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
values={"min_depth": depth},
)
def _handle_mult_prev_events(self, txn, events):
"""
For the given event, update the event edges table and forward and
backward extremities tables.
"""
self._simple_insert_many_txn(
txn,
table="event_edges",
values=[
{
"event_id": ev.event_id,
"prev_event_id": e_id,
"room_id": ev.room_id,
"is_state": False,
}
for ev in events
for e_id in ev.prev_event_ids()
],
)
self._update_backward_extremeties(txn, events)
def _update_backward_extremeties(self, txn, events):
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
" )"
" AND NOT EXISTS ("
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
" AND outlier = ?"
" )"
)
txn.executemany(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events
for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier()
],
)
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id)
for ev in events
if not ev.internal_metadata.is_outlier()
],
)
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
@ -591,13 +659,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
self.runInteraction,
self.db.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@ -641,17 +709,17 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id,
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
return batch_size

View file

@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
@ -93,7 +94,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
ret = yield self.runInteraction(
ret = yield self.db.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
@ -177,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
return ret
@defer.inlineCallbacks
@ -230,7 +231,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@ -259,7 +260,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@ -332,7 +333,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@ -361,7 +362,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@ -411,7 +412,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
return self.runInteraction(
return self.db.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
@ -446,7 +447,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _add_push_actions_to_staging_txn(txn):
# We don't use _simple_insert_many here to avoid the overhead
# We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
sql = """
@ -463,7 +464,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
return self.runInteraction(
return self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@ -477,7 +478,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
res = yield self._simple_delete(
res = yield self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@ -494,7 +495,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
self.runInteraction,
self.db.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@ -530,7 +531,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
return self.runInteraction(
return self.db.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@ -612,21 +613,38 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, db_conn, hs):
super(EventPushActionsStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(EventPushActionsStore, self).__init__(database, db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
@ -639,69 +657,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000
)
def _set_push_actions_for_event_and_users_txn(
self, txn, events_and_contexts, all_events_and_contexts
):
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
Also ensures that all events in `all_events_and_contexts` are removed
from the push action staging area.
Args:
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
all_events_and_contexts (list[(EventBase, EventContext)]): all
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
"""
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight
)
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
FROM event_push_actions_staging
WHERE event_id = ?
"""
if events_and_contexts:
txn.executemany(
sql,
(
(
event.room_id,
event.internal_metadata.stream_ordering,
event.depth,
event.event_id,
)
for event, _ in events_and_contexts
),
)
for event, _ in events_and_contexts:
user_ids = self._simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
retcol="user_id",
)
for uid in user_ids:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid),
)
# Now we delete the staging area for *all* events that were being
# persisted.
txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@defer.inlineCallbacks
def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
@ -732,50 +687,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
push_actions = yield self.runInteraction("get_push_actions_for_user", f)
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
result = yield self.db.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id,),
)
txn.execute(
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
(room_id, event_id),
)
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
@ -835,7 +764,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
caught_up = yield self.runInteraction(
caught_up = yield self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
@ -849,7 +778,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@ -868,7 +797,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
(offset_stream_ordering,) = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
@ -885,7 +814,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@ -917,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
self._simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[

File diff suppressed because it is too large

View file

@ -21,29 +21,31 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
def __init__(self, db_conn, hs):
super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_contains_url_index",
index_name="event_contains_url_index",
table="events",
@ -54,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
@ -63,10 +65,39 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
psql_only=True,
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
)
self.db.updates.register_background_update_handler(
"redactions_received_ts", self._redactions_received_ts
)
# This index gets deleted in `event_fix_redactions_bytes` update
self.db.updates.register_background_index_update(
"event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts",
table="redactions",
columns=["redacts"],
where_clause="have_censored",
)
self.db.updates.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
self.db.updates.register_background_update_handler(
"event_store_labels", self._event_store_labels
)
self.db.updates.register_background_index_update(
"redactions_have_censored_ts_idx",
index_name="redactions_have_censored_ts",
table="redactions",
columns=["received_ts"],
where_clause="NOT have_censored",
)
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
@ -122,18 +153,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows),
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
yield self.db.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
@ -166,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
ev_rows = self._simple_select_many_txn(
ev_rows = self.db.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@ -199,18 +232,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows_to_update),
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
yield self.db.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
@ -308,12 +343,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
WHERE
prev_event_id IN (%s)
AND NOT events.outlier
""" % (
",".join("?" for _ in to_check),
NOT events.outlier
AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", to_check
)
txn.execute(sql, to_check)
txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph:
@ -342,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
to_delete.intersection_update(original_set)
deleted = self._simple_delete_many_txn(
deleted = self.db.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@ -358,7 +394,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if deleted:
# We now need to invalidate the caches of these rooms
rows = self._simple_select_many_txn(
rows = self.db.simple_select_many_txn(
txn,
table="events",
column="event_id",
@ -366,13 +402,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
keyvalues={},
retcols=("room_id",),
)
room_ids = set(row["room_id"] for row in rows)
room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self._simple_delete_many_txn(
self.db.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@ -382,18 +418,172 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(original_set)
num_handled = yield self.runInteraction(
num_handled = yield self.db.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
yield self.db.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
yield self.runInteraction(
yield self.db.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
@defer.inlineCallbacks
def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn):
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
WHERE event_id > ?
ORDER BY event_id ASC
LIMIT ?
"""
txn.execute(sql, (last_event_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
(upper_event_id,) = rows[-1]
# Update the redactions with the received_ts.
#
# Note: Not all events have an associated received_ts, so we
# fall back to using origin_server_ts. If we for some reason don't
# have an origin_server_ts, let's just use the current timestamp.
#
# We don't want to leave it null, as then we'll never try and
# censor those redactions.
sql = """
UPDATE redactions
SET received_ts = (
SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
WHERE events.event_id = redactions.event_id
)
WHERE ? <= event_id AND event_id <= ?
"""
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
self.db.updates._background_update_progress_txn(
txn, "redactions_received_ts", {"last_event_id": upper_event_id}
)
return len(rows)
count = yield self.db.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
yield self.db.updates._end_background_update("redactions_received_ts")
return count
@defer.inlineCallbacks
def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
def _event_fix_redactions_bytes_txn(txn):
# This update is quite fast due to new index.
txn.execute(
"""
UPDATE event_json
SET
json = convert_from(json::bytea, 'utf8')
FROM redactions
WHERE
redactions.have_censored
AND event_json.event_id = redactions.redacts
AND json NOT LIKE '{%';
"""
)
txn.execute("DROP INDEX redactions_censored_redacts")
yield self.db.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
yield self.db.updates._end_background_update("event_fix_redactions_bytes")
return 1
@defer.inlineCallbacks
def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
def _event_store_labels_txn(txn):
txn.execute(
"""
SELECT event_id, json FROM event_json
LEFT JOIN event_labels USING (event_id)
WHERE event_id > ? AND label IS NULL
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
nbrows = 0
last_row_event_id = ""
for (event_id, event_json_raw) in results:
try:
event_json = json.loads(event_json_raw)
self.db.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
{
"event_id": event_id,
"label": label,
"room_id": event_json["room_id"],
"topological_ordering": event_json["depth"],
}
for label in event_json["content"].get(
EventContentFields.LABELS, []
)
if isinstance(label, str)
],
)
except Exception as e:
logger.warning(
"Unable to load event %s (no labels will be imported): %s",
event_id,
e,
)
nbrows += 1
last_row_event_id = event_id
self.db.updates._background_update_progress_txn(
txn, "event_store_labels", {"last_event_id": last_row_event_id}
)
return nbrows
num_rows = yield self.db.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
yield self.db.updates._end_background_update("event_store_labels")
return num_rows

View file

@ -17,26 +17,35 @@ from __future__ import division
import itertools
import logging
import threading
from collections import namedtuple
from typing import List, Optional, Tuple
from canonicaljson import json
from constantly import NamedConstant, Names
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersions,
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@ -53,7 +62,64 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventRedactBehaviour(Names):
"""
What to do when retrieving a redacted event from the database.
"""
AS_IS = NamedConstant()
REDACT = NamedConstant()
BLOCK = NamedConstant()
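# A minimal standalone sketch (names prefixed with "_" are illustrative, not
# Synapse's) of how the three behaviours are applied when a cached entry has a
# redacted copy available; it mirrors the branching further down in
# get_events_as_list.
from constantly import NamedConstant, Names

class _Behaviour(Names):
    AS_IS = NamedConstant()
    REDACT = NamedConstant()
    BLOCK = NamedConstant()

def _pick_event(original, redacted_copy, behaviour):
    """Return the event to hand back, or None to withhold it entirely."""
    if redacted_copy is None:
        return original  # event has not been redacted
    if behaviour == _Behaviour.BLOCK:
        return None
    if behaviour == _Behaviour.REDACT:
        return redacted_copy
    return original  # AS_IS

assert _pick_event("full", "pruned", _Behaviour.REDACT) == "pruned"
assert _pick_event("full", "pruned", _Behaviour.BLOCK) is None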
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
extra_tables=[("local_invites", "stream_id")],
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
# Another process is in charge of persisting events and generating
# stream IDs: rely on the replication streams to let us know which
# IDs we can process.
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
self._get_event_cache = Cache(
"*getEvent*",
keylen=3,
max_entries=hs.config.caches.event_cache_size,
apply_cache_factor_from_config=False,
)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
elif stream_name == "backfill":
self._backfill_id_gen.advance(-token)
super().process_replication_rows(stream_name, instance_name, token, rows)
def get_received_ts(self, event_id):
"""Get received_ts (when it was persisted) for the event.
@ -66,7 +132,7 @@ class EventsWorkerStore(SQLBaseStore):
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
"""
return self._simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
@ -105,32 +171,41 @@ class EventsWorkerStore(SQLBaseStore):
return ts
return self.runInteraction(
return self.db.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@defer.inlineCallbacks
def get_event(
self,
event_id,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
allow_none=False,
check_room_id=None,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
event_id: The event_id of the event to fetch
redact_behaviour: Determine what to do with a redacted event. Possible values:
* AS_IS - Return the full event body with no redacted content
* REDACT - Return the event but with a redacted body
* BLOCK - Do not return redacted events (behave as per allow_none
if the event is redacted)
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
allow_rejected: If True, return rejected events. Otherwise,
behave as per allow_none.
allow_none: If True, return None if no event found, if
False throw a NotFoundError
check_room_id (str|None): if not None, check the room of the found event.
check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
@ -141,7 +216,7 @@ class EventsWorkerStore(SQLBaseStore):
events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@ -160,27 +235,34 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
event_ids: The event_ids of the events to fetch
redact_behaviour: Determine what to do with a redacted event. Possible
values:
* AS_IS - Return the full event body with no redacted content
* REDACT - Return the event but with a redacted body
* BLOCK - Do not return redacted events (omit them from the response)
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_rejected: If True, return rejected events. Otherwise,
omits rejected events from the response.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self.get_events_as_list(
event_ids,
check_redacted=check_redacted,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@ -190,21 +272,29 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events_as_list(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Unknown events will be omitted from the response.
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
event_ids: The event_ids of the events to fetch
redact_behaviour: Determine what to do with a redacted event. Possible values:
* AS_IS - Return the full event body with no redacted content
* REDACT - Return the event but with a redacted body
* BLOCK - Do not return redacted events (omit them from the response)
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_rejected: If True, return rejected events. Otherwise,
omits rejected events from the response.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
@ -238,6 +328,20 @@ class EventsWorkerStore(SQLBaseStore):
# we have to recheck auth now.
if not allow_rejected and entry.event.type == EventTypes.Redaction:
if entry.event.redacts is None:
# A redacted redaction doesn't have a `redacts` key, in
# which case let's just withhold the event.
#
# Note: Most of the time, if the redaction has been
# redacted we still have the un-redacted event in the DB
# and so we'll still see the `redacts` key. However, this
# isn't always true e.g. if we have censored the event.
logger.debug(
"Withholding redaction event %s as we don't have redacts key",
event_id,
)
continue
redacted_event_id = entry.event.redacts
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
@ -292,10 +396,14 @@ class EventsWorkerStore(SQLBaseStore):
# Update the cache to save doing the checks again.
entry.event.internal_metadata.recheck_redaction = False
if check_redacted and entry.redacted_event:
event = entry.redacted_event
else:
event = entry.event
event = entry.event
if entry.redacted_event:
if redact_behaviour == EventRedactBehaviour.BLOCK:
# Skip this event
continue
elif redact_behaviour == EventRedactBehaviour.REDACT:
event = entry.redacted_event
events.append(event)
@ -319,9 +427,14 @@ class EventsWorkerStore(SQLBaseStore):
If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response.
Args:
event_ids (Iterable[str]): The event_ids of the events to fetch
allow_rejected (bool): Whether to include rejected events
allow_rejected (bool): Whether to include rejected events. If False,
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
@ -334,7 +447,7 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
log_ctx = LoggingContext.current_context()
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _get_events_from_db is also responsible for turning db rows
@ -422,11 +535,11 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
events_to_fetch = set(
events_to_fetch = {
event_id for events, _ in event_list for event_id in events
)
}
row_dict = self._new_transaction(
row_dict = self.db.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
@ -456,9 +569,13 @@ class EventsWorkerStore(SQLBaseStore):
Returned events will be added to the cache for future lookups.
Unknown events are omitted from the response.
Args:
event_ids (Iterable[str]): The event_ids of the events to fetch
allow_rejected (bool): Whether to include rejected events
allow_rejected (bool): Whether to include rejected events. If False,
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
@ -504,15 +621,56 @@ class EventsWorkerStore(SQLBaseStore):
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
original_ev = event_type_from_format_version(format_version)(
room_version_id = row["room_version_id"]
if not room_version_id:
# this should only happen for out-of-band membership events
if not internal_metadata.get("out_of_band_membership"):
logger.warning(
"Room %s for event %s is unknown", d["room_id"], event_id
)
continue
# take a wild stab at the room version based on the event format
if format_version == EventFormatVersions.V1:
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
room_version = RoomVersions.V3
else:
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
logger.error(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
room_version_id,
)
continue
if room_version.event_format != format_version:
logger.error(
"Event %s in room %s with version %s has wrong format: "
"expected %s, was %s",
event_id,
d["room_id"],
room_version_id,
room_version.event_format,
format_version,
)
continue
original_ev = make_event_from_dict(
event_dict=d,
room_version=room_version,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
event_map[event_id] = original_ev
# finally, we can decide whether each one nededs redacting, and build
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
@ -558,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
"fetch_events", self.runWithConnection, self._do_fetch
"fetch_events", self.db.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events: %s", len(events), events)
@ -585,6 +743,12 @@ class EventsWorkerStore(SQLBaseStore):
of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1).
* room_version_id (str|None): The version of the room which contains the event.
Hopefully one of RoomVersions.
For historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
* rejected_reason (str|None): if the event was rejected, the reason
why.
@ -600,19 +764,24 @@ class EventsWorkerStore(SQLBaseStore):
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
sql = (
"SELECT "
" e.event_id, "
" e.internal_metadata,"
" e.json,"
" e.format_version, "
" rej.reason "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(evs)),)
sql = """\
SELECT
e.event_id,
e.internal_metadata,
e.json,
e.format_version,
r.room_version,
rej.reason
FROM event_json as e
LEFT JOIN rooms r USING (room_id)
LEFT JOIN rejections as rej USING (event_id)
WHERE """
txn.execute(sql, evs)
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
)
txn.execute(sql + clause, args)
for row in txn:
event_id = row[0]
@ -621,16 +790,17 @@ class EventsWorkerStore(SQLBaseStore):
"internal_metadata": row[1],
"json": row[2],
"format_version": row[3],
"rejected_reason": row[4],
"room_version_id": row[4],
"rejected_reason": row[5],
"redactions": [],
}
# check for redactions
redactions_sql = (
"SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
) % (",".join(["?"] * len(evs)),)
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
txn.execute(redactions_sql, evs)
clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
txn.execute(redactions_sql + clause, args)
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
@ -715,7 +885,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self._simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@ -724,7 +894,7 @@ class EventsWorkerStore(SQLBaseStore):
desc="have_events_in_timeline",
)
return set(r["event_id"] for r in rows)
return {r["event_id"] for r in rows}
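# A simplified stand-in (an assumption about its shape, not Synapse's actual
# implementation) for what make_in_list_sql_clause produces on SQLite-style
# engines: an "IN (?, ?, ...)" fragment plus the matching argument list. On
# Postgres, Synapse can instead emit a single "column = ANY(?)" array clause.
def _make_in_list_clause(column, values):
    values = list(values)
    placeholders = ",".join("?" * len(values))
    return "%s IN (%s)" % (column, placeholders), values

clause, args = _make_in_list_clause("e.event_id", ["$a", "$b", "$c"])
assert clause == "e.event_id IN (?,?,?)"
assert args == ["$a", "$b", "$c"]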
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
@ -739,52 +909,21 @@ class EventsWorkerStore(SQLBaseStore):
results = set()
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
",".join("?" * len(chunk)),
sql = "SELECT event_id FROM events as e WHERE "
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", chunk
)
txn.execute(sql, chunk)
txn.execute(sql + clause, args)
for (event_id,) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
return results
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
Args:
event_ids (list[str])
Returns:
Deferred[dict[str, str|None):
Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps
to None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
yield self.db.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction("get_seen_events_with_rejections", f)
return results
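# A standalone sketch of the chunking idiom above: iter() with a sentinel keeps
# slicing 100 items off the iterator until an empty list comes back.
import itertools

def _chunks_of(iterable, size=100):
    it = iter(iterable)
    return iter(lambda: list(itertools.islice(it, size)), [])

assert [len(chunk) for chunk in _chunks_of(range(250))] == [100, 100, 50]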
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
@ -810,7 +949,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
return self.runInteraction(
return self.db.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
@ -835,7 +974,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
return self.runInteraction(
return self.db.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
@ -862,3 +1001,343 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
"""Returns new events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns: Deferred[List[Tuple]]
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
"""
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_ex_outlier_stream_rows(self, last_id, current_id):
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
Returns: Deferred[List[Tuple]]
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
"""
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
return self.db.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
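# A tiny illustration of the sign convention used for the backfill stream:
# backfilled events are persisted with negative stream orderings, and both the
# query bounds and the selected column are negated above, so that replication
# consumers only ever see positive, increasing token values.
stored_orderings = [-1, -2, -3]            # as persisted for backfilled events
replication_tokens = [-o for o in stored_orderings]
assert replication_tokens == [1, 2, 3]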
async def get_all_updated_current_state_deltas(
self, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
from_token: The previous stream token. Updates from this stream id will
be excluded.
to_token: The current stream token (ie the upper limit). Updates up to this
stream id will be included (modulo the 'limit' param)
target_row_count: The number of rows to try to return. If more rows are
available, we will set 'limited' in the result. In the event of a large
batch, we may return more rows than this.
Returns:
A triplet `(updates, new_last_token, limited)`, where:
* `updates` is a list of database tuples.
* `new_last_token` is the new position in stream.
* `limited` is whether there are more updates to fetch.
"""
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, target_row_count))
return txn.fetchall()
def get_deltas_for_stream_id_txn(txn, stream_id):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
return txn.fetchall()
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
rows = await self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
) # type: List[Tuple]
# if we've got fewer rows than the limit, we're good
if len(rows) < target_row_count:
return rows, to_token, False
# we hit the limit, so reduce the upper limit so that we exclude the stream id
# of the last row in the result.
assert rows[-1][0] <= to_token
to_token = rows[-1][0] - 1
# search backwards through the list for the point to truncate
for idx in range(len(rows) - 1, 0, -1):
if rows[idx - 1][0] <= to_token:
return rows[:idx], to_token, True
# bother. We didn't get a full set of changes for even a single
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
to_token += 1
rows = await self.db.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
return rows, to_token, True
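# A plain-Python sketch (hypothetical data) of the truncation rule above: if the
# row limit is hit, drop back to the last stream id for which we are sure we
# have every row, so callers never see a partially-delivered stream id.
def _truncate_batch(rows, to_token, target_row_count):
    """rows are (stream_id, ...) tuples ordered by stream_id ascending."""
    if len(rows) < target_row_count:
        return rows, to_token, False
    to_token = rows[-1][0] - 1
    for idx in range(len(rows) - 1, 0, -1):
        if rows[idx - 1][0] <= to_token:
            return rows[:idx], to_token, True
    # every row shared a single stream id; the caller re-fetches just that id
    return [], to_token, True

rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]
assert _truncate_batch(rows, to_token=10, target_row_count=4) == (
    [(1, "a"), (2, "b"), (2, "c")],
    2,
    True,
)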
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
last_backfill_id,
last_forward_id,
current_backfill_id,
current_forward_id,
limit,
):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
have_backfill_events = last_backfill_id != current_backfill_id
have_forward_events = last_forward_id != current_forward_id
if not have_backfill_events and not have_forward_events:
return defer.succeed(AllNewEventsResult([], [], [], [], []))
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
forward_ex_outliers = []
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else:
new_backfill_events = []
backward_ex_outliers = []
return AllNewEventsResult(
new_forward_events,
new_backfill_events,
forward_ex_outliers,
backward_ex_outliers,
)
return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
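# Python compares these tuples lexicographically, so topological ordering is
# compared first and stream ordering only breaks ties:
assert (3, 10) > (3, 5)
assert (4, 1) > (3, 99)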
@cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id):
res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
allow_none=True,
)
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_next_event_to_expire(self):
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
Returns: Deferred[Optional[Tuple[str, int]]]
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
"""
def get_next_event_to_expire_txn(txn):
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
ORDER BY expiry_ts ASC LIMIT 1
"""
)
return txn.fetchone()
return self.db.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
[
"new_forward_events",
"new_backfill_events",
"forward_ex_outliers",
"backward_ex_outliers",
],
)

View file

@ -16,10 +16,9 @@
from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
@ -31,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = yield self._simple_select_one_onecol(
def_json = yield self.db.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@ -51,12 +50,12 @@ class FilteringStore(SQLBaseStore):
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
)
txn.execute(sql, (user_localpart, def_json))
txn.execute(sql, (user_localpart, bytearray(def_json)))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
@ -68,8 +67,8 @@ class FilteringStore(SQLBaseStore):
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, def_json))
txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id
return self.runInteraction("add_user_filter", _do_txn)
return self.db.runInteraction("add_user_filter", _do_txn)

View file

@ -0,0 +1,208 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
db_binary_type = memoryview
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@cached()
def _get_server_verify_key(self, server_name_and_key_id):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
)
def get_server_verify_keys(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, str]]):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
def _get_keys(txn, batch):
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
server_name, key_id, key_bytes, ts_valid_until_ms = row
if ts_valid_until_ms is None:
# Old keys may be stored with a ts_valid_until_ms of null,
# in which case we treat this as if it was set to `0`, i.e.
# it won't match key requests that define a minimum
# `ts_valid_until_ms`.
ts_valid_until_ms = 0
res = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys
return self.db.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
Args:
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_values = []
value_values = []
invalidations = []
for server_name, key_id, fetch_result in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
def _invalidate(res):
f = self._get_server_verify_key.invalidate
for i in invalidations:
f((i,))
return res
return self.db.runInteraction(
"store_server_verify_keys",
self.db.simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
).addCallback(_invalidate)
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
):
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
server_name (str): The name of the server.
key_id (str): The identifier of the key this JSON is for.
from_server (str): The server this JSON was fetched from.
ts_now_ms (int): The time now in milliseconds.
ts_expires_ms (int): The time when this JSON stops being valid.
key_json_bytes (bytes): The encoded JSON.
"""
return self.db.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
},
values={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
"key_json": db_binary_type(key_json_bytes),
},
desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
Dict mapping (server_name, key_id, source) triplets to lists of dicts
"""
def _get_server_keys_json_txn(txn):
results = {}
for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name}
if key_id is not None:
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
rows = self.db.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
)
results[(server_name, key_id, from_server)] = rows
return results
return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)

View file

@ -12,16 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(MediaRepositoryBackgroundUpdateStore, self).__init__(
database, db_conn, hs
)
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
index_name="local_media_repository_url_idx",
table="local_media_repository",
@ -29,12 +30,19 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause="url_cache IS NOT NULL",
)
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, database: Database, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
return self._simple_select_one(
return self.db.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@ -59,7 +67,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
user_id,
url_cache=None,
):
return self._simple_insert(
return self.db.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@ -119,12 +127,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
)
return self.runInteraction("get_url_cache", get_url_cache_txn)
return self.db.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
return self._simple_insert(
return self.db.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@ -139,7 +147,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
return self._simple_select_list(
return self.db.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@ -161,7 +169,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
return self._simple_insert(
return self.db.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@ -175,7 +183,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
return self._simple_select_one(
return self.db.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@ -200,7 +208,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
upload_name,
filesystem_id,
):
return self._simple_insert(
return self.db.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@ -245,10 +253,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
return self.db.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
return self.db.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@ -273,7 +283,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
return self._simple_insert(
return self.db.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@ -295,24 +305,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
return self._execute(
"get_remote_media_before", self.cursor_to_dict, sql, before_ts
return self.db.execute(
"get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
self._simple_delete_txn(
self.db.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
self._simple_delete_txn(
self.db.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
@ -326,18 +336,20 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
return self.db.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
def delete_url_cache(self, media_ids):
async def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
return
sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@ -351,23 +363,23 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
def delete_url_cache_media(self, media_ids):
async def delete_url_cache_media(self, media_ids):
if len(media_ids) == 0:
return
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction(
return await self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)

View file

@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from collections import Counter
from twisted.internet import defer
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.database import Database
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""Functions to pull various metrics from the DB, for e.g. phone home
stats and prometheus metrics.
"""
def __init__(self, database: Database, db_conn, hs):
super().__init__(database, db_conn, hs)
# Collect metrics on the number of forward extremities that exist.
# Counter of number of extremities to count
self._current_forward_extremities_amount = (
Counter()
) # type: typing.Counter[int]
BucketCollector(
"synapse_forward_extremities",
lambda: self._current_forward_extremities_amount,
buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
)
# Read the extrems every 60 minutes
def read_forward_extremities():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"read_forward_extremities", self._read_forward_extremities
)
hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
async def _read_forward_extremities(self):
def fetch(txn):
txn.execute(
"""
select count(*) c from event_forward_extremities
group by room_id
"""
)
return txn.fetchall()
res = await self.db.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res])
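# A small illustration (hypothetical data) of the Counter built above: the
# GROUP BY returns one extremity count per room, and Counter turns that into a
# histogram of "how many rooms have N forward extremities" for the bucket
# collector.
from collections import Counter

per_room_counts = [1, 1, 3, 1, 7]  # one row per room from the GROUP BY
assert Counter(per_room_counts) == Counter({1: 3, 3: 1, 7: 1})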
@defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""
def _count_messages(txn):
sql = """
SELECT COALESCE(COUNT(*), 0) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
return count
ret = yield self.db.runInteraction("count_messages", _count_messages)
return ret
@defer.inlineCallbacks
def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
sql = """
SELECT COALESCE(COUNT(*), 0) FROM events
WHERE type = 'm.room.message'
AND sender LIKE ?
AND stream_ordering > ?
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
return count
ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
return ret
@defer.inlineCallbacks
def count_daily_active_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
return count
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret

View file

@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
@ -27,119 +28,11 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersStore(SQLBaseStore):
def __init__(self, dbconn, hs):
super(MonthlyActiveUsersStore, self).__init__(None, hs)
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
self.reserved_users = ()
# Do not add more reserved users than the total allowable number
self._new_transaction(
dbconn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
reserved_user_list = []
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
Returns:
Deferred[]
"""
def _reap_users(txn):
# Purge stale users
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
questionmarks = "?" * len(self.reserved_users)
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
",".join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args)
if self.hs.config.limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
# sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
# Must be greater than zero for postgres
safe_guard = safe_guard if safe_guard > 0 else 0
query_args = [safe_guard]
base_sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
)
"""
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
",".join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args)
yield self.runInteraction("reap_monthly_active_users", _reap_users)
# It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self.user_last_seen_monthly_active.invalidate_all()
self.get_monthly_active_count.invalidate_all()
@cached(num_args=0)
def get_monthly_active_count(self):
@ -151,29 +44,211 @@ class MonthlyActiveUsersStore(SQLBaseStore):
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
count, = txn.fetchone()
(count,) = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
return self.db.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def get_registered_reserved_users_count(self):
"""Of the reserved threepids defined in config, how many are associated
with registered users?
@cached(num_args=0)
def get_monthly_active_count_by_service(self):
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table,
`config.track_appservice_user_ips` must be set to `true` for this
method to return anything other than native matrix users.
Returns:
Defered[int]: Number of real reserved users
Deferred[dict]: dict mapping each app_service_id (or "native") to the
number of monthly active users associated with it.
"""
count = 0
for tp in self.hs.config.mau_limits_reserved_threepids:
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
def _count_users_by_service(txn):
sql = """
SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
FROM monthly_active_users
LEFT JOIN users ON monthly_active_users.user_id=users.name
GROUP BY appservice_id;
"""
txn.execute(sql)
result = txn.fetchall()
return dict(result)
return self.db.runInteraction("count_users_by_service", _count_users_by_service)
async def get_registered_reserved_users(self) -> List[str]:
"""Of the reserved threepids defined in config, retrieve those that are associated
with registered users
Returns:
User IDs of actual users that are reserved
"""
users = []
for tp in self.hs.config.mau_limits_reserved_threepids[
: self.hs.config.max_mau_value
]:
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
count = count + 1
return count
users.append(user_id)
return users
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to check
Return:
Deferred[int]: timestamp at which the user was last seen, or None if never seen
"""
return self.db.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
allow_none=True,
desc="user_last_seen_monthly_active",
)
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_stats_only = hs.config.mau_stats_only
self._max_mau_value = hs.config.max_mau_value
# Do not add more reserved users than the total allowable number
# cur = LoggingTransaction(
self.db.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
#
# - shouldn't there already be an entry for each reserved user (at least
# if they have been active recently)?
#
# - if it's important that the timestamp is kept up to date, why do we only
# run this at startup?
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
# We do this manually here to avoid hitting #6791
self.db.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
values={"timestamp": int(self._clock.time_msec())},
)
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
def _reap_users(txn, reserved_users):
"""
Args:
reserved_users (tuple): reserved users to preserve
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
in_clause, in_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", reserved_users
)
txn.execute(
"DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
% (in_clause,),
[thirty_days_ago] + in_clause_args,
)
if self._limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
# Sqlite requires 'LIMIT -1 OFFSET ?' (the LIMIT must be present),
# while Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no single way to write the query
# that both engines support.
# Limit must be >= 0 for postgres
num_of_non_reserved_users_to_remove = max(
self._max_mau_value - len(reserved_users), 0
)
# It is important to filter reserved users twice to guard
# against the case where the reserved user is present in the
# SELECT, meaning that a legitimate mau is deleted.
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
WHERE NOT %s
ORDER BY timestamp DESC
LIMIT ?
)
AND NOT %s
""" % (
in_clause,
in_clause,
)
query_args = (
in_clause_args
+ [num_of_non_reserved_users_to_remove]
+ in_clause_args
)
txn.execute(sql, query_args)
# It seems poor to invalidate the whole cache. Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self._invalidate_all_cache_and_stream(
txn, self.user_last_seen_monthly_active
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
reserved_users = await self.get_registered_reserved_users()
await self.db.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
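# Illustrative sketch (not part of the diff): with two reserved users and a MAU
# limit of 100, the second DELETE above renders roughly as follows (placeholder
# form of make_in_list_sql_clause; the exact clause depends on the engine):
_example_reap_sql = """
    DELETE FROM monthly_active_users
    WHERE user_id NOT IN (
        SELECT user_id FROM monthly_active_users
        WHERE NOT user_id IN (?, ?)
        ORDER BY timestamp DESC
        LIMIT ?
    )
    AND NOT user_id IN (?, ?)
"""
_example_reap_args = ["@reserved1:hs", "@reserved2:hs", 98, "@reserved1:hs", "@reserved2:hs"]
# i.e. keep the 98 most recently active non-reserved users (plus all reserved
# users) and delete everyone else; the outer "AND NOT ..." guarantees reserved
# users are never deleted even if they also appear in the inner SELECT.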
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
@ -182,6 +257,9 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id (str): user to add/update
Returns:
Deferred
"""
# Support users are never included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@ -195,27 +273,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
yield self.runInteraction(
yield self.db.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
user_in_mau = self.user_last_seen_monthly_active.cache.get(
(user_id,), None, update_metrics=False
)
if user_in_mau is None:
self.get_monthly_active_count.invalidate(())
self.user_last_seen_monthly_active.invalidate((user_id,))
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
Note that, after calling this method, it will generally be necessary
to invalidate the caches on user_last_seen_monthly_active and
get_monthly_active_count. We can't do that here, because we are running
in a database thread rather than the main thread, and we can't call
txn.call_after because txn may not be a LoggingTransaction.
We consciously do not call is_support_txn from this method because it
is not possible to cache the response. is_support_txn will be false in
almost all cases, so it seems reasonable to call it only for
@ -239,33 +303,22 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
is_insert = self._simple_upsert_txn(
is_insert = self.db.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
values={"timestamp": int(self._clock.time_msec())},
)
return is_insert
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
"""
return self._simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
allow_none=True,
desc="user_last_seen_monthly_active",
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
self._invalidate_cache_and_stream(
txn, self.get_monthly_active_count_by_service, ()
)
self._invalidate_cache_and_stream(
txn, self.user_last_seen_monthly_active, (user_id,)
)
return is_insert
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
@ -275,7 +328,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id(str): the user_id to query
"""
if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only:
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = yield self.is_guest(user_id)
if is_guest:
@ -296,11 +349,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# In the case where mau_stats_only is True and limit_usage_by_mau is
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
if not self.hs.config.limit_usage_by_mau:
if not self._limit_usage_by_mau:
yield self.upsert_monthly_active_user(user_id)
else:
count = yield self.get_monthly_active_count()
if count < self.hs.config.max_mau_value:
if count < self._max_mau_value:
yield self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id)

View file

@ -1,9 +1,9 @@
from ._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
return self._simple_insert(
return self.db.simple_insert(
table="open_id_tokens",
values={
"token": token,
@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
return self.db.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)

View file

@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
@defer.inlineCallbacks
def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
with stream_ordering_manager as stream_orderings:
yield self.db.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
presence_states,
)
return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
)
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
self.db.simple_insert_many_txn(
txn,
table="presence_stream",
values=[
{
"stream_id": stream_id,
"user_id": state.user_id,
"state": state.state,
"last_active_ts": state.last_active_ts,
"last_federation_update_ts": state.last_federation_update_ts,
"last_user_sync_ts": state.last_user_sync_ts,
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
for stream_id, state in zip(stream_orderings, presence_states)
],
)
# Delete old rows to stop database from getting really big
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
for states in batch_iter(presence_states, 50):
clause, args = make_in_list_sql_clause(
self.database_engine, "user_id", [s.user_id for s in states]
)
txn.execute(sql + clause, [stream_id] + list(args))
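# Illustrative sketch (not part of the diff): for one batch of three users the
# DELETE built above expands roughly to the following (placeholder form; the
# Postgres engine may render the clause differently):
_example_presence_delete_sql = (
    "DELETE FROM presence_stream WHERE stream_id < ? AND user_id IN (?, ?, ?)"
)
_example_presence_delete_args = [1234, "@a:hs", "@b:hs", "@c:hs"]
# batch_iter caps each IN list at 50 users, so a large presence update turns
# into several moderate DELETEs rather than one enormous statement.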
def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn
)
@cached()
def _get_presence_for_user(self, user_id):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids):
rows = yield self.db.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
keyvalues={},
retcols=(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
desc="get_presence_for_users",
)
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return {row["user_id"]: UserPresenceState(**row) for row in rows}
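# Illustrative sketch (not part of the diff): callers get a dict keyed by user
# ID, so a batched lookup might look like this (user IDs are hypothetical;
# `defer` is already imported at the top of this module).
@defer.inlineCallbacks
def _log_presence(store, user_ids=("@alice:hs", "@bob:hs")):
    states = yield store.get_presence_for_users(user_ids)
    for user_id, state in states.items():
        # Each value is a UserPresenceState with .state, .last_active_ts, etc.
        print(user_id, state.state, state.currently_active)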
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self.db.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="allow_presence_visible",
or_ignore=True,
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self.db.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="disallow_presence_visible",
)

View file

@ -16,16 +16,15 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.roommember import ProfileInfo
from ._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
profile = yield self._simple_select_one(
profile = yield self.db.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@ -43,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_displayname(self, user_localpart):
return self._simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
@ -51,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_avatar_url(self, user_localpart):
return self._simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@ -59,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_from_remote_profile_cache(self, user_id):
return self._simple_select_one(
return self.db.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@ -68,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
)
def create_profile(self, user_localpart):
return self._simple_insert(
return self.db.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
return self._simple_update_one(
return self.db.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
@ -81,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
return self._simple_update_one(
return self.db.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@ -96,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
return self._simple_upsert(
return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@ -108,10 +107,10 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
return self._simple_update(
return self.db.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
updatevalues={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
@ -126,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
yield self._simple_delete(
yield self.db.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@ -145,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
txn.execute(sql, (last_checked,))
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
return self.runInteraction(
return self.db.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@ -156,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
res = yield self._simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@ -167,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
if res:
return True
res = yield self._simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",

View file

@ -0,0 +1,399 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
def purge_history(self, room_id, token, delete_local_events):
"""Deletes room history before a certain point
Args:
room_id (str):
token (str): A topological token to delete events before
delete_local_events (bool):
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
Returns:
Deferred[set[int]]: The set of state groups that are referenced by
deleted events.
"""
return self.db.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
token,
delete_local_events,
)
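# Illustrative sketch (not part of the diff): purging everything in a room
# before a given topological position. The room ID and the
# "t<topological>-<stream>" token format are assumptions for the example.
from twisted.internet import defer


@defer.inlineCallbacks
def _purge_old_history(store):
    state_groups = yield store.purge_history(
        room_id="!example:hs",
        token="t100-2000",  # delete events before topological ordering 100
        delete_local_events=False,  # if True, local events are deleted too
    )
    # The returned state groups are the ones referenced by the deleted events;
    # what to do with them is up to the caller.
    print("state groups referenced by deleted events:", state_groups)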
def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
token = RoomStreamToken.parse(token_str)
# Tables that should be pruned:
# event_auth
# event_backward_extremities
# event_edges
# event_forward_extremities
# event_json
# event_push_actions
# event_reference_hashes
# event_search
# event_to_state_groups
# events
# rejections
# room_depth
# state_groups
# state_groups_state
# we will build a temporary table listing the events so that we don't
# have to keep shovelling the list back and forth across the
# connection. Annoyingly the python sqlite driver commits the
# transaction on CREATE, so let's do this first.
#
# furthermore, we might already have the table from a previous (failed)
# purge attempt, so let's drop the table first.
txn.execute("DROP TABLE IF EXISTS events_to_purge")
txn.execute(
"CREATE TEMPORARY TABLE events_to_purge ("
" event_id TEXT NOT NULL,"
" should_delete BOOLEAN NOT NULL"
")"
)
# First ensure that we're not about to delete all the forward extremities
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
"ON e.event_id = f.event_id "
"AND e.room_id = f.room_id "
"WHERE f.room_id = ?",
(room_id,),
)
rows = txn.fetchall()
max_depth = max(row[1] for row in rows)
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backward extremities)
raise SynapseError(
400, "topological_ordering is greater than forward extremeties"
)
logger.info("[purge] looking for events to delete")
should_delete_expr = "state_key IS NULL"
should_delete_params = () # type: Tuple[Any, ...]
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
# We include the parameter twice since we use the expression twice
should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
should_delete_params += (room_id, token.topological)
# Note that we insert events that are outliers and aren't going to be
# deleted, as nothing will happen to them.
txn.execute(
"INSERT INTO events_to_purge"
" SELECT event_id, %s"
" FROM events AS e LEFT JOIN state_events USING (event_id)"
" WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
% (should_delete_expr, should_delete_expr),
should_delete_params,
)
# We create the indices *after* insertion as that's a lot faster.
# create an index on should_delete because later we'll be looking for
# the should_delete / shouldn't_delete subsets
txn.execute(
"CREATE INDEX events_to_purge_should_delete"
" ON events_to_purge(should_delete)"
)
# We do joins against events_to_purge for e.g. calculating state
# groups to purge, etc., so lets make an index.
txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
logger.info(
"[purge] found %i events before cutoff, of which %i can be deleted",
len(event_rows),
sum(1 for e in event_rows if e[1]),
)
logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremities by finding
# events to be purged that are pointed to by events we're not going to
# purge.
txn.execute(
"SELECT DISTINCT e.event_id FROM events_to_purge AS e"
" INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
" LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
" WHERE ep2.event_id IS NULL"
)
new_backwards_extrems = txn.fetchall()
logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
)
# Update backward extremities
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],
)
logger.info("[purge] finding state groups referenced by deleted events")
# Get all state groups that are referenced by events that are to be
# deleted.
txn.execute(
"""
SELECT DISTINCT state_group FROM events_to_purge
INNER JOIN event_to_state_groups USING (event_id)
"""
)
referenced_state_groups = {sg for sg, in txn}
logger.info(
"[purge] found %i referenced state groups", len(referenced_state_groups)
)
logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
"DELETE FROM event_to_state_groups "
"WHERE event_id IN (SELECT event_id from events_to_purge)"
)
for event_id, _ in event_rows:
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
# Delete all remote non-state events
for table in (
"events",
"event_json",
"event_auth",
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
"event_search",
"rejections",
):
logger.info("[purge] removing events from %s", table)
txn.execute(
"DELETE FROM %s WHERE event_id IN ("
" SELECT event_id FROM events_to_purge WHERE should_delete"
")" % (table,)
)
# event_push_actions lacks an index on event_id, and has one on
# (room_id, event_id) instead.
for table in ("event_push_actions",):
logger.info("[purge] removing events from %s", table)
txn.execute(
"DELETE FROM %s WHERE room_id = ? AND event_id IN ("
" SELECT event_id FROM events_to_purge WHERE should_delete"
")" % (table,),
(room_id,),
)
# Mark all state and own events as outliers
logger.info("[purge] marking remaining events as outliers")
txn.execute(
"UPDATE events SET outlier = ?"
" WHERE event_id IN ("
" SELECT event_id FROM events_to_purge "
" WHERE NOT should_delete"
")",
(True,),
)
# synapse tries to take out an exclusive lock on room_depth whenever it
# persists events (because upsert), and once we run this update, we
# will block that for the rest of our transaction.
#
# So, let's stick it at the end so that we don't block event
# persistence.
#
# We do this by calculating the minimum depth of the backwards
# extremities. However, the events in event_backward_extremities
# are ones we don't have yet so we need to look at the events that
# point to it via event_edges table.
txn.execute(
"""
SELECT COALESCE(MIN(depth), 0)
FROM event_backward_extremities AS eb
INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
INNER JOIN events AS e ON e.event_id = eg.event_id
WHERE eb.room_id = ?
""",
(room_id,),
)
(min_depth,) = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
(min_depth, room_id),
)
# finally, drop the temp table. this will commit the txn in sqlite,
# so make sure to keep this actually last.
txn.execute("DROP TABLE events_to_purge")
logger.info("[purge] done")
return referenced_state_groups
def purge_room(self, room_id):
"""Deletes all record of a room
Args:
room_id (str)
Returns:
Deferred[List[int]]: The list of state groups to delete.
"""
return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
"""
SELECT DISTINCT state_group FROM events
INNER JOIN event_to_state_groups USING(event_id)
WHERE events.room_id = ?
""",
(room_id,),
)
state_groups = [row[0] for row in txn]
# Now we delete tables which lack an index on room_id but have one on event_id
for table in (
"event_auth",
"event_edges",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
"event_to_state_groups",
"redactions",
"rejections",
"state_events",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute(
"""
DELETE FROM %s WHERE event_id IN (
SELECT event_id FROM events WHERE room_id=?
)
"""
% (table,),
(room_id,),
)
# and finally, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
"event_push_actions",
"event_search",
"events",
"group_rooms",
"public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
"room_depth",
"room_memberships",
"room_stats_state",
"room_stats_current",
"room_stats_historical",
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
"users_in_public_rooms",
"users_who_share_private_rooms",
# no useful index, but let's clear them anyway
"appservice_room_list",
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
"local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
# Other tables we do NOT need to clear out:
#
# - blocked_rooms
# This is important, to make sure that we don't accidentally rejoin a blocked
# room after it was purged
#
# - user_directory
# This has a room_id column, but it is unused
#
# Other tables that we might want to consider clearing out include:
#
# - event_reports
# Given that these are intended for abuse management my initial
# inclination is to leave them in place.
#
# - current_state_delta_stream
# - ex_outlier_stream
# - room_tags_revisions
# The problem with these is that they are largeish and there is no room_id
# index on them. In any case we should be clearing out 'stream' tables
# periodically anyway (#5888)
# TODO: we could probably usefully do a bunch of cache invalidation here
logger.info("[purge] done")
return state_groups

View file

@ -0,0 +1,729 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Union
from canonicaljson import json
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
from synapse.storage.database import Database
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
rule["default"] = False
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule["rule_id"]
if rule_id in enabled_map:
if rule.get("enabled", True) != bool(enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule["enabled"] = bool(enabled_map[rule_id])
rules[i] = rule
return rules
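# Illustrative sketch (not part of the diff): what _load_rules consumes and
# produces. The row values below are hypothetical.
_example_rawrules = [
    {
        "rule_id": "my.silence.rule",
        "priority_class": 5,
        "priority": 0,
        "conditions": '[{"kind": "event_match", "key": "room_id", "pattern": "!a:hs"}]',
        "actions": '["dont_notify"]',
    }
]
_example_enabled_map = {"my.silence.rule": False}
# _load_rules decodes the JSON columns, marks the row with "default": False,
# merges it into the built-in base rules via list_with_base_rules, and then
# applies the enabled map, so this rule comes back with "enabled": False.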
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
EventsWorkerStore,
SQLBaseStore,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, database: Database, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
) # type: Union[ChainedIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
"push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
push_rules_id,
prefilled_cache=push_rules_prefill,
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self.db.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
),
desc="get_push_rules_enabled_for_user",
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
rules = _load_rules(rows, enabled_map)
return rules
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self.db.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
(count,) = txn.fetchone()
return bool(count)
return self.db.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@cachedList(
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
rows = yield self.db.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
retcols=("*",),
desc="bulk_get_push_rules",
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
results.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
@defer.inlineCallbacks
def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Copy a single push rule from one room to another for a specific user.
Args:
new_room_id (str): ID of the new room.
user_id (str): ID of user the push rule belongs to.
rule (Dict): A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
# Change room id in each condition
for condition in rule.get("conditions", []):
if condition.get("key") == "room_id":
condition["pattern"] = new_room_id
# Add the rule for the new room
yield self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
conditions=rule["conditions"],
actions=rule["actions"],
)
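# Illustrative sketch (not part of the diff): the rule_id rewrite above keeps
# everything before the final "/" and swaps in the new room ID. The rule_id
# format shown is hypothetical.
_old_rule_id = "global/room/!old:hs"
_new_rule_id = "/".join(_old_rule_id.split("/")[:-1]) + "/" + "!new:hs"
assert _new_rule_id == "global/room/!new:hs"
# Any room_id conditions in the rule body are rewritten too, so the copied rule
# fires in the new room rather than the old one.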
@defer.inlineCallbacks
def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id
):
"""Copy all of the push rules from one room to another for a specific
user.
Args:
old_room_id (str): ID of the old room.
new_room_id (str): ID of the new room.
user_id (str): ID of user to copy push rules for.
"""
# Retrieve push rules for this user
user_push_rules = yield self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield context.get_current_state_ids()
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
return result
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(
self, room_id, state_group, current_state_ids, cache_context, event=None
):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread counts are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
users_in_room = yield self._get_joined_users_from_context(
room_id,
state_group,
current_state_ids,
on_invalidate=cache_context.invalidate,
event=event,
)
# We ignore app service users for now. This is so that we don't fill
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
local_users_in_room = {
u
for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
}
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate
)
user_ids = {
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
return rules_by_user
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
rows = yield self.db.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
)
for row in rows:
enabled = bool(row["enabled"])
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results
def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
if last_id == current_id:
return defer.succeed([])
def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
" op, priority_class, priority, conditions, actions"
" FROM push_rules_stream"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
self,
user_id,
rule_id,
priority_class,
conditions,
actions,
before=None,
after=None,
):
conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
yield self.db.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
)
else:
yield self.db.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
)
def _add_push_rule_relative_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
relative_to_rule = before or after
res = self.db.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
retcols=["priority_class", "priority"],
allow_none=True,
)
if not res:
raise RuleNotFoundException(
"before/after rule not found: %s" % (relative_to_rule,)
)
base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
if base_priority_class != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
if before:
# Higher priority rules are executed first, so adding a rule before
# a rule means giving it a higher priority than that rule.
new_rule_priority = base_rule_priority + 1
else:
# We increment the priority of the existing rules to make space for
# the new rule. Therefore if we want this rule to appear after
# an existing rule we give it the priority of the existing rule,
# and then increment the priority of the existing rule.
new_rule_priority = base_rule_priority
sql = (
"UPDATE push_rules SET priority = priority + 1"
" WHERE user_name = ? AND priority_class = ? AND priority >= ?"
)
txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn(
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
new_rule_priority,
conditions_json,
actions_json,
)
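# Illustrative worked example (not part of the diff) of the relative-priority
# logic above. Suppose the reference rule sits at priority 10 in its class:
#
#   before=<ref>: new_rule_priority = 11; rules at priority >= 11 are bumped up
#                 by one and the new rule lands at 11, so it is evaluated
#                 before the reference rule.
#   after=<ref>:  new_rule_priority = 10; the reference rule (and anything
#                 above it) is bumped to 11+, the new rule takes 10, so it is
#                 evaluated after the reference rule.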
def _add_push_rule_highest_priority_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM push_rules"
" WHERE user_name = ? and priority_class = ?"
)
txn.execute(sql, (user_id, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
new_prio = 0
if how_many > 0:
new_prio = highest_prio + 1
self._upsert_push_rule_txn(
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
new_prio,
conditions_json,
actions_json,
)
def _upsert_push_rule_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
conditions_json,
actions_json,
update_stream=True,
):
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
sql = (
"UPDATE push_rules"
" SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
" WHERE user_name = ? AND rule_id = ?"
)
txn.execute(
sql,
(priority_class, priority, conditions_json, actions_json, user_id, rule_id),
)
if txn.rowcount == 0:
# We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next()
self.db.simple_insert_txn(
txn,
table="push_rules",
values={
"id": push_rule_id,
"user_name": user_id,
"rule_id": rule_id,
"priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
},
)
if update_stream:
self._insert_push_rules_update_txn(
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ADD",
data={
"priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
},
)
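# Illustrative sketch (not part of the diff): the UPDATE-then-INSERT pattern
# used above, in isolation. Table and column names are placeholders; the
# pattern is only race-free because callers lock the table first.
def _update_then_insert(txn, table, key, value):
    # `table` is trusted (hard-coded by the caller), not user input.
    txn.execute(
        "UPDATE %s SET value = ? WHERE key = ?" % (table,), (value, key)
    )
    if txn.rowcount == 0:
        # No existing row matched, so insert a fresh one instead.
        txn.execute(
            "INSERT INTO %s (key, value) VALUES (?, ?)" % (table,), (key, value)
        )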
@defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self.db.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.db.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.db.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
enabled,
)
def _set_push_rule_enabled_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
self.db.simple_upsert_txn(
txn,
"push_rules_enable",
{"user_name": user_id, "rule_id": rule_id},
{"enabled": 1 if enabled else 0},
{"id": new_id},
)
self._insert_push_rules_update_txn(
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ENABLE" if enabled else "DISABLE",
)
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
actions_json = json.dumps(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
priority_class = -1
priority = 1
self._upsert_push_rule_txn(
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
"[]",
actions_json,
update_stream=False,
)
else:
self.db.simple_update_one_txn(
txn,
"push_rules",
{"user_name": user_id, "rule_id": rule_id},
{"actions": actions_json},
)
self._insert_push_rules_update_txn(
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ACTIONS",
data={"actions": actions_json},
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.db.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
event_stream_ordering,
)
def _insert_push_rules_update_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
):
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": op,
}
if data is not None:
values.update(data)
self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
def get_max_push_rules_stream_id(self):
return self.get_push_rules_stream_token()[0]

View file

@ -15,52 +15,42 @@
# limitations under the License.
import logging
import six
from typing import Iterable, Iterator
from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
if six.PY2:
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
for r in rows:
dataJson = r["data"]
r["data"] = None
try:
if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r["data"] = json.loads(dataJson)
except Exception as e:
logger.warn(
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
dataJson,
e.args[0],
)
pass
continue
if isinstance(r["pushkey"], db_binary_type):
r["pushkey"] = str(r["pushkey"]).decode("UTF8")
return rows
yield r
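# Illustrative sketch (not part of the diff): because this is now a generator
# that `continue`s past bad rows, pushers with undecodable data are dropped
# rather than returned with data=None. The rows below are hypothetical.
_example_pusher_rows = [
    {"id": 1, "pushkey": "key1", "data": '{"url": "https://push.example.com"}'},
    {"id": 2, "pushkey": "key2", "data": "not json"},
]
# list(store._decode_pushers_rows(_example_pusher_rows)) would contain only the
# first row, with its "data" field decoded into a dict.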
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
ret = yield self._simple_select_one_onecol(
ret = yield self.db.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@ -73,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self._simple_select_list(
ret = yield self.db.simple_select_list(
"pushers",
keyvalues,
[
@ -101,11 +91,11 @@ class PusherWorkerStore(SQLBaseStore):
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
rows = yield self.runInteraction("get_all_pushers", get_pushers)
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
@ -135,7 +125,7 @@ class PusherWorkerStore(SQLBaseStore):
return updated, deleted
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@ -178,7 +168,7 @@ class PusherWorkerStore(SQLBaseStore):
return results
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@ -194,7 +184,7 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self._simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@ -207,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore):
return result
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
yield self.db.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering_and_success(
self, app_id, pushkey, user_id, last_stream_ordering, last_success
):
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
app_id (str)
pushkey (str)
last_stream_ordering (int)
last_success (int)
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
updated = yield self.db.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
"last_stream_ordering": last_stream_ordering,
"last_success": last_success,
},
desc="update_pusher_last_stream_ordering_and_success",
)
return bool(updated)
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self.db.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
res = yield self.db.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room",
)
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = {
"last_sent_ts": row["last_sent_ts"],
"throttle_ms": row["throttle_ms"],
}
return params_by_room
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
yield self.db.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
desc="set_throttle_params",
lock=False,
)
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
@ -230,8 +298,8 @@ class PusherStore(PusherWorkerStore):
):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
yield self._simple_upsert(
# (app_id, pushkey, user_name) so simple_upsert will retry
yield self.db.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@ -241,7 +309,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
"data": encode_canonical_json(data),
"data": bytearray(encode_canonical_json(data)),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
@ -256,7 +324,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since the user might not have had a pusher before
yield self.runInteraction(
yield self.db.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
@ -270,7 +338,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
self._simple_delete_one_txn(
self.db.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@ -279,7 +347,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
self._simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@ -291,82 +359,4 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
yield self._simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering_and_success(
self, app_id, pushkey, user_id, last_stream_ordering, last_success
):
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
app_id (str)
pushkey (str)
last_stream_ordering (int)
last_success (int)
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
updated = yield self._simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
"last_stream_ordering": last_stream_ordering,
"last_success": last_success,
},
desc="update_pusher_last_stream_ordering_and_success",
)
return bool(updated)
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self._simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
res = yield self._simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room",
)
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = {
"last_sent_ts": row["last_sent_ts"],
"throttle_ms": row["throttle_ms"],
}
return params_by_room
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so _simple_upsert will retry
yield self._simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
desc="set_throttle_params",
lock=False,
)
yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)

View file

@ -21,12 +21,12 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__)
@ -39,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@ -58,11 +58,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
return set(r["user_id"] for r in receipts)
return {r["user_id"] for r in receipts}
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
return self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@ -71,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self._simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@ -85,7 +85,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self._simple_select_list(
rows = yield self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@ -109,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
return {
row[0]: {
"event_id": row[1],
@ -188,11 +188,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
return rows
rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@ -217,28 +217,32 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % (",".join(["?"] * len(room_ids)))
args = list(room_ids)
args.extend([from_key, to_key])
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
txn.execute(sql, args)
txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
) % (",".join(["?"] * len(room_ids)))
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ? AND
"""
args = list(room_ids)
args.append(to_key)
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
txn.execute(sql, args)
txn.execute(sql + clause, [to_key] + list(args))
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
txn_results = yield self.db.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results = {}
for row in txn_results:
@ -279,9 +283,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
return (r[0:5] + (json.loads(r[5]),) for r in txn)
return [r[0:5] + (json.loads(r[5]),) for r in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
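# Illustrative sketch only: the hunks above replace hand-built "room_id IN (%s)"
# placeholder strings with make_in_list_sql_clause(), which returns a (clause, args)
# pair. The minimal approximation below covers only the portable IN-list branch (on
# PostgreSQL the real helper can emit "col = ANY(?)" instead); the helper name and
# return shape are taken from the calls above, everything else is assumed.
import sqlite3


def make_in_list_sql_clause_sketch(column, iterable):
    """Return ("column IN (?, ?, ...)", [values]) for use in a WHERE clause."""
    values = list(iterable)
    return "%s IN (%s)" % (column, ", ".join("?" for _ in values)), values


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE receipts_linearized (room_id TEXT, stream_id INTEGER)")
conn.executemany(
    "INSERT INTO receipts_linearized VALUES (?, ?)",
    [("!a:hs", 1), ("!b:hs", 2), ("!c:hs", 3)],
)
clause, args = make_in_list_sql_clause_sketch("room_id", ["!a:hs", "!b:hs"])
sql = "SELECT room_id, stream_id FROM receipts_linearized WHERE stream_id <= ? AND " + clause
print(conn.execute(sql, [2] + args).fetchall())  # rows for !a:hs and !b:hs only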
@@ -312,14 +316,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
super(ReceiptsStore, self).__init__(db_conn, hs)
super(ReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -334,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
res = self._simple_select_one_txn(
res = self.db.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -387,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
self._simple_delete_txn(
self.db.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -395,19 +399,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
"receipt_type": receipt_type,
"user_id": user_id,
},
)
self._simple_insert_txn(
txn,
table="receipts_linearized",
values={
"stream_id": stream_id,
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
)
if receipt_type == "m.read" and stream_ordering is not None:
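# Illustrative sketch only: the hunk above collapses a delete-then-insert pair into a
# single simple_upsert_txn() call, leaning on the unique constraint over
# (room_id, receipt_type, user_id) so no extra locking is needed. Below is a minimal
# standalone version of the same idea using sqlite3's native upsert (requires SQLite
# 3.24+); the trimmed-down schema here is an assumption, not the real table definition.
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE receipts_linearized (
        stream_id INTEGER,
        room_id TEXT,
        receipt_type TEXT,
        user_id TEXT,
        event_id TEXT,
        data TEXT,
        UNIQUE (room_id, receipt_type, user_id)
    )
    """
)


def upsert_receipt(stream_id, room_id, receipt_type, user_id, event_id, data):
    conn.execute(
        """
        INSERT INTO receipts_linearized
            (stream_id, room_id, receipt_type, user_id, event_id, data)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT (room_id, receipt_type, user_id)
        DO UPDATE SET stream_id = excluded.stream_id,
                      event_id = excluded.event_id,
                      data = excluded.data
        """,
        (stream_id, room_id, receipt_type, user_id, event_id, json.dumps(data)),
    )


upsert_receipt(1, "!room:hs", "m.read", "@alice:hs", "$ev1", {})
upsert_receipt(2, "!room:hs", "m.read", "@alice:hs", "$ev2", {"ts": 123})
# Only one row survives, holding the latest receipt: [(2, '$ev2')]
print(conn.execute("SELECT stream_id, event_id FROM receipts_linearized").fetchall())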
@@ -433,26 +432,32 @@ class ReceiptsStore(ReceiptsWorkerStore):
# we need to map the given points in graph form to a single linearized event.
# TODO: Make this better.
def graph_to_linear(txn):
query = (
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
")"
) % (",".join(["?"] * len(event_ids)))
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
txn.execute(query, [room_id] + event_ids)
sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.runInteraction(
linearized_event_id = yield self.db.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
event_ts = yield self.runInteraction(
event_ts = yield self.db.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -481,7 +486,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.runInteraction(
return self.db.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -507,7 +512,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
self._simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -516,7 +521,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
self._simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="receipts_graph",
values={

View file

@@ -15,25 +15,14 @@
import logging
from ._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
self._simple_insert_txn(
txn,
table="rejections",
values={
"event_id": event_id,
"reason": reason,
"last_check": self._clock.time_msec(),
},
)
def get_rejection_reason(self, event_id):
return self._simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},

View file

@@ -0,0 +1,327 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import attr
from synapse.api.constants import RelationTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
def get_relations_for_event(
self,
event_id,
relation_type=None,
event_type=None,
aggregation_key=None,
limit=5,
direction="b",
from_token=None,
to_token=None,
):
"""Get a list of relations for an event, ordered by topological ordering.
Args:
event_id (str): Fetch events that relate to this event ID.
relation_type (str|None): Only fetch events with this relation
type, if given.
event_type (str|None): Only fetch events with this event type, if
given.
aggregation_key (str|None): Only fetch events with this aggregation
key, if given.
limit (int): Only fetch the most recent `limit` events.
direction (str): Whether to fetch the most recent first (`"b"`) or
the oldest first (`"f"`).
from_token (RelationPaginationToken|None): Fetch rows from the given
token, or from the start if None.
to_token (RelationPaginationToken|None): Fetch rows up to the given
token, or up to the end if None.
Returns:
Deferred[PaginationChunk]: List of event IDs that match relations
requested. The rows are of the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
where_args = [event_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
where_args.append(relation_type)
if event_type is not None:
where_clause.append("type = ?")
where_args.append(event_type)
if aggregation_key:
where_clause.append("aggregation_key = ?")
where_args.append(aggregation_key)
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
engine=self.database_engine,
)
if pagination_clause:
where_clause.append(pagination_clause)
if direction == "b":
order = "DESC"
else:
order = "ASC"
sql = """
SELECT event_id, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)
def _get_recent_references_for_event_txn(txn):
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
last_stream_id = None
events = []
for row in txn:
events.append({"event_id": row[0]})
last_topo_id = row[1]
last_stream_id = row[2]
next_batch = None
if len(events) > limit and last_topo_id and last_stream_id:
next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.db.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
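# Illustrative sketch only: get_relations_for_event() above asks the database for
# `limit + 1` rows and emits a next_batch token only when that extra row came back,
# i.e. when another page definitely exists; the token is taken from the last fetched
# row, mirroring the loop above. A plain-Python sketch of the same pattern -- the
# token here is just a (topological, stream) tuple, not Synapse's
# RelationPaginationToken.

def paginate(rows, limit):
    """rows: list of (event_id, topological_ordering, stream_ordering), newest first."""
    fetched = rows[: limit + 1]  # the query above uses LIMIT ? with limit + 1
    next_batch = None
    if len(fetched) > limit:
        _, topo, stream = fetched[-1]  # last fetched row, as in the txn loop above
        next_batch = (topo, stream)
    return [event_id for event_id, _, _ in fetched[:limit]], next_batch


rows = [("$e3", 9, 30), ("$e2", 8, 20), ("$e1", 7, 10)]
print(paginate(rows, limit=2))  # (['$e3', '$e2'], (7, 10))
print(paginate(rows, limit=5))  # (['$e3', '$e2', '$e1'], None)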
@cached(tree=True)
def get_aggregation_groups_for_event(
self,
event_id,
event_type=None,
limit=5,
direction="b",
from_token=None,
to_token=None,
):
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get which reactions have happened on an event and
how many of each there are.
Args:
event_id (str): Fetch events that relate to this event ID.
event_type (str|None): Only fetch events with this event type, if
given.
limit (int): Only fetch the first `limit` groups.
direction (str): Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
from_token (AggregationPaginationToken|None): Fetch rows from the
given token, or from the start if None.
to_token (AggregationPaginationToken|None): Fetch rows up to the
given token, or up to the end if None.
Returns:
Deferred[PaginationChunk]: List of groups of annotations that
match. Each row is a dict with `type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args = [event_id, RelationTypes.ANNOTATION]
if event_type:
where_clause.append("type = ?")
where_args.append(event_type)
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
engine=self.database_engine,
)
if direction == "b":
order = "DESC"
else:
order = "ASC"
if having_clause:
having_clause = "HAVING " + having_clause
else:
having_clause = ""
sql = """
SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE {where_clause}
GROUP BY relation_type, type, aggregation_key
{having_clause}
ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
LIMIT ?
""".format(
where_clause=" AND ".join(where_clause),
order=order,
having_clause=having_clause,
)
def _get_aggregation_groups_for_event_txn(txn):
txn.execute(sql, where_args + [limit + 1])
next_batch = None
events = []
for row in txn:
events.append({"type": row[0], "key": row[1], "count": row[2]})
next_batch = AggregationPaginationToken(row[2], row[3])
if len(events) <= limit:
next_batch = None
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.db.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
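# Illustrative sketch only: the SQL above groups m.annotation relations by
# (type, aggregation_key) and counts DISTINCT senders, so the same user reacting twice
# with the same key is only counted once. The same aggregation in plain Python over a
# toy list of annotations (the dict shape here is an assumption for the sketch).
from collections import defaultdict

annotations = [
    {"type": "m.reaction", "key": "👍", "sender": "@alice:hs"},
    {"type": "m.reaction", "key": "👍", "sender": "@bob:hs"},
    {"type": "m.reaction", "key": "👍", "sender": "@alice:hs"},  # duplicate sender
    {"type": "m.reaction", "key": "🎉", "sender": "@alice:hs"},
]

senders_by_group = defaultdict(set)
for ann in annotations:
    senders_by_group[(ann["type"], ann["key"])].add(ann["sender"])

groups = [
    {"type": t, "key": k, "count": len(senders)}
    for (t, k), senders in senders_by_group.items()
]
groups.sort(key=lambda g: g["count"], reverse=True)  # ORDER BY COUNT(*) DESC
print(groups)
# [{'type': 'm.reaction', 'key': '👍', 'count': 2},
#  {'type': 'm.reaction', 'key': '🎉', 'count': 1}]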
@cachedInlineCallbacks()
def get_applicable_edit(self, event_id):
"""Get the most recent edit (if any) that has happened for the given
event.
Correctly handles checking whether edits were allowed to happen.
Args:
event_id (str): The original event ID
Returns:
Deferred[EventBase|None]: Returns the most recent edit, if any.
"""
# We only allow edits for `m.room.message` events that have the same sender
# and event type. We can't assert these things during regular event auth so
# we have to do the checks post hoc.
# Fetches latest edit that has the same type and sender as the
# original, and is an `m.room.message`.
sql = """
SELECT edit.event_id FROM events AS edit
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS original ON
original.event_id = relates_to_id
AND edit.type = original.type
AND edit.sender = original.sender
WHERE
relates_to_id = ?
AND relation_type = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn):
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
edit_id = yield self.db.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
)
if not edit_id:
return
edit_event = yield self.get_event(edit_id, allow_none=True)
return edit_event
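# Illustrative sketch only: the query above only honours an m.replace edit when it has
# the same sender and event type as the original and is an m.room.message, then takes
# the newest by origin_server_ts. The same post-hoc filtering in plain Python over toy
# event dicts (the dict shape is an assumption, not Synapse's EventBase).

def applicable_edit(original, candidate_edits):
    allowed = [
        e
        for e in candidate_edits
        if e["sender"] == original["sender"]
        and e["type"] == original["type"]
        and e["type"] == "m.room.message"
    ]
    if not allowed:
        return None
    # ORDER BY origin_server_ts DESC, event_id DESC LIMIT 1
    return max(allowed, key=lambda e: (e["origin_server_ts"], e["event_id"]))


original = {"event_id": "$orig", "type": "m.room.message", "sender": "@alice:hs"}
edits = [
    {"event_id": "$e1", "type": "m.room.message", "sender": "@alice:hs", "origin_server_ts": 10},
    {"event_id": "$e2", "type": "m.room.message", "sender": "@mallory:hs", "origin_server_ts": 99},
]
print(applicable_edit(original, edits)["event_id"])  # $e1 -- the impostor edit is ignored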
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
parent_id (str): The event being annotated
event_type (str): The event type of the annotation
aggregation_key (str): The aggregation key of the annotation
sender (str): The sender of the annotation
Returns:
Deferred[bool]
"""
sql = """
SELECT 1 FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND relation_type = ?
AND type = ?
AND sender = ?
AND aggregation_key = ?
LIMIT 1;
"""
def _get_if_user_has_annotated_event(txn):
txn.execute(
sql,
(
parent_id,
RelationTypes.ANNOTATION,
event_type,
sender,
aggregation_key,
),
)
return bool(txn.fetchone())
return self.db.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
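# Illustrative sketch only: the query above is an existence check -- has this sender
# already annotated the parent event with this (type, aggregation_key)? The same idea
# as a set-membership test (the tuple layout is an assumption for the sketch).
existing = {("$parent", "@alice:hs", "m.reaction", "👍")}


def has_user_annotated_event_sketch(parent_id, event_type, aggregation_key, sender):
    return (parent_id, sender, event_type, aggregation_key) in existing


print(has_user_annotated_event_sketch("$parent", "m.reaction", "👍", "@alice:hs"))  # True
print(has_user_annotated_event_sketch("$parent", "m.reaction", "👍", "@bob:hs"))    # False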
class RelationsStore(RelationsWorkerStore):
pass

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT

Some files were not shown because too many files have changed in this diff