Merge branch 'develop' into e2e_backups

This commit is contained in:
Hubert Chathi 2018-08-24 11:44:26 -04:00 committed by GitHub
commit 83caead95a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
232 changed files with 7197 additions and 4107 deletions

View file

@ -40,6 +40,7 @@ from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
@ -89,6 +90,7 @@ class DataStore(RoomMemberStore, RoomStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
):
def __init__(self, db_conn, hs):
@ -96,7 +98,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
@ -269,31 +270,6 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users)
def count_monthly_users(self):
"""Counts the number of users who used this homeserver in the last 30 days
This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau
Returns:
Defered[int]
"""
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone()
return count
return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-

View file

@ -1150,17 +1150,16 @@ class SQLBaseStore(object):
defer.returnValue(retval)
def get_user_count_txn(self, txn):
"""Get a total number of registerd users in the users list.
"""Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
defer.Deferred: resolves to int
int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
count = txn.fetchone()[0]
defer.returnValue(count)
return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):

View file

@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
@defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
if not now:
@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
yield self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@ -94,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
return
def update():
to_update = self._batch_row_update
self._batch_row_update = {}

View file

@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
@ -485,9 +485,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(chunk))
synapse.metrics.event_persisted_position.set(
chunk[-1][0].internal_metadata.stream_ordering,
)
if not backfilled:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
chunk[-1][0].internal_metadata.stream_ordering,
)
for event, context in chunk:
if context.app_service:
origin_type = "local"
@ -700,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
}
events_map = {ev.event_id: ev for ev, _ in events_context}
room_version = yield self.get_room_version(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups, events_map, get_events
room_id, room_version, state_groups, events_map, get_events
)
defer.returnValue((res.state, None))
@ -1430,88 +1437,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(event.event_id, event.redacts)
)
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self._simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
defer.returnValue(set(r["event_id"] for r in rows))
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
"""
results = set()
def have_seen_events_txn(txn, chunk):
sql = (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
% (",".join("?" * len(chunk)), )
)
txn.execute(sql, chunk)
for (event_id, ) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
[]):
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
Args:
event_ids (list[str])
Returns:
Deferred[dict[str, str|None):
Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps
to None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction("get_rejection_reasons", f)
@defer.inlineCallbacks
def count_daily_messages(self):
"""
@ -1988,7 +1913,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
max_depth = max(row[0] for row in rows)
if max_depth <= token.topological:
# We need to ensure we don't delete all the events from the datanase
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
raise SynapseError(

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self._simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
defer.returnValue(set(r["event_id"] for r in rows))
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
"""
results = set()
def have_seen_events_txn(txn, chunk):
sql = (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
% (",".join("?" * len(chunk)), )
)
txn.execute(sql, chunk)
for (event_id, ) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
[]):
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
Args:
event_ids (list[str])
Returns:
Deferred[dict[str, str|None):
Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps
to None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction("get_rejection_reasons", f)

View file

@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
# This means it is not necessary to update the table on every request
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersStore(SQLBaseStore):
def __init__(self, dbconn, hs):
super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock()
self.hs = hs
self.reserved_users = ()
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
# TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []
# Do not add more reserved users than the total allowable number
for tp in threepids[:self.hs.config.max_mau_value]:
user_id = yield store.get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
yield self.upsert_monthly_active_user(user_id)
reserved_user_list.append(user_id)
else:
logger.warning(
"mau limit reserved threepid %s not found in db" % tp
)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
"""
Cleans out monthly active user table to ensure that no stale
entries exist.
Returns:
Deferred[]
"""
def _reap_users(txn):
# Purge stale users
thirty_days_ago = (
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
)
query_args = [thirty_days_ago]
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
questionmarks = '?' * len(self.reserved_users)
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
','.join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args)
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
# sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
# Must be greater than zero for postgres
safe_guard = safe_guard if safe_guard > 0 else 0
query_args = [safe_guard]
base_sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
)
"""
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
','.join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args)
yield self.runInteraction("reap_monthly_active_users", _reap_users)
# It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self.user_last_seen_monthly_active.invalidate_all()
self.get_monthly_active_count.invalidate_all()
@cached(num_args=0)
def get_monthly_active_count(self):
"""Generates current count of monthly active users
Returns:
Defered[int]: Number of current monthly active users
"""
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
count, = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
"""
Updates or inserts monthly active user member
Arguments:
user_id (str): user to add/update
Deferred[bool]: True if a new entry was created, False if an
existing one was updated.
"""
is_insert = yield self._simple_upsert(
desc="upsert_monthly_active_user",
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
values={
"timestamp": int(self._clock.time_msec()),
},
lock=False,
)
if is_insert:
self.user_last_seen_monthly_active.invalidate((user_id,))
self.get_monthly_active_count.invalidate(())
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
"""
return(self._simple_select_one_onecol(
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
retcol="timestamp",
allow_none=True,
desc="user_last_seen_monthly_active",
))
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
Args:
user_id(str): the user_id to query
"""
if self.hs.config.limit_usage_by_mau:
is_trial = yield self.is_trial_user(user_id)
if is_trial:
# we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy
# to trade accuracy of timestamp in order to lighten load. This means
# We always insert new users (where MAU threshold has not been reached),
# but only update if we have not previously seen the user for
# LAST_SEEN_GRANULARITY ms
if last_seen_timestamp is None:
count = yield self.get_monthly_active_count()
if count < self.hs.config.max_mau_value:
yield self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 50
SCHEMA_VERSION = 51
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
class ProfileStore(ProfileWorkerStore):
def create_profile(self, user_localpart):
return self._simple_insert(
table="profiles",
@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore):
desc="set_profile_avatar_url",
)
class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.

View file

@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore):
retcols=[
"name", "password_hash", "is_guest",
"consent_version", "consent_server_notice_sent",
"appservice_id",
"appservice_id", "creation_ts",
],
allow_none=True,
desc="get_user_by_id",
)
@defer.inlineCallbacks
def is_trial_user(self, user_id):
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
user_id (str)
Returns:
Deferred[bool]
"""
info = yield self.get_user_by_id(user_id)
if not info:
defer.returnValue(False)
now = self.clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
defer.returnValue(is_trial)
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.

View file

@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
class RoomWorkerStore(SQLBaseStore):
def get_room(self, room_id):
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
Returns:
A namedtuple containing the room information, or an empty list.
"""
return self._simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
desc="get_room",
allow_none=True,
)
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
@ -170,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)
class RoomStore(RoomWorkerStore, SearchStore):
@ -215,22 +260,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
def get_room(self, room_id):
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
Returns:
A namedtuple containing the room information, or an empty list.
"""
return self._simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
desc="get_room",
allow_none=True,
)
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(

View file

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async import Linearizer
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii

View file

@ -0,0 +1,27 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- a table of monthly active users, for use where blocking based on mau limits
CREATE TABLE monthly_active_users (
user_id TEXT NOT NULL,
-- Last time we saw the user. Not guaranteed to be accurate due to rate limiting
-- on updates, Granularity of updates governed by
-- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY
-- Measured in ms since epoch.
timestamp BIGINT NOT NULL
);
CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id);
CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp);

View file

@ -21,15 +21,17 @@ from six.moves import range
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.caches import get_cache_factor_for, intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@ -46,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
class StateGroupWorkerStore(SQLBaseStore):
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
@ -57,9 +60,68 @@ class StateGroupWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
# on the state_group* tables.
#
# The point of using a DictionaryCache is that it can cache a subset
# of the state events for a given state group (i.e. a subset of the keys for a
# given dict which is an entry in the cache for a given state group ID).
#
# However, this poses problems when performing complicated queries
# on the store - for instance: "give me all the state for this group, but
# limit members to this subset of users", as DictionaryCache's API isn't
# rich enough to say "please cache any of these fields, apart from this subset".
# This is problematic when lazy loading members, which requires this behaviour,
# as without it the cache has no choice but to speculatively load all
# state events for the group, which negates the efficiency being sought.
#
# Rather than overcomplicating DictionaryCache's API, we instead split the
# state_group_cache into two halves - one for tracking non-member events,
# and the other for tracking member_events. This means that lazy loading
# queries can be made in a cache-friendly manner by querying both caches
# separately and then merging the result. So for the example above, you
# would query the members cache for a specific subset of state keys
# (which DictionaryCache will handle efficiently and fine) and the non-members
# cache for all state (which DictionaryCache will similarly handle fine)
# and then just merge the results together.
#
# We size the non-members cache to be smaller than the members cache as the
# vast majority of state in Matrix (today) is member events.
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000 * get_cache_factor_for("stateGroupCache")
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*",
500000 * get_cache_factor_for("stateGroupMembersCache")
)
@defer.inlineCallbacks
def get_room_version(self, room_id):
"""Get the room_version of a given room
Args:
room_id (str)
Returns:
Deferred[str]
Raises:
NotFoundError if the room is unknown
"""
# for now we do this by looking at the create event. We may want to cache this
# more intelligently in future.
state_ids = yield self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
if not create_id:
raise NotFoundError("Unknown room")
create_event = yield self.get_event(create_id)
defer.returnValue(create_event.content.get("room_version", "1"))
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
@ -89,6 +151,69 @@ class StateGroupWorkerStore(SQLBaseStore):
_get_current_state_ids_txn,
)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
room_id (str)
types (list[(Str, (Str|None))]): List of (type, state_key) tuples
which are used to filter the state fetched. `state_key` may be
None, which matches any `state_key`
filtered_types (list[Str]|None): List of types to apply the above filter to.
Returns:
deferred: dict of (type, state_key) -> event
"""
include_other_types = False if filtered_types is None else True
def _get_filtered_current_state_ids_txn(txn):
results = {}
sql = """SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ? %s"""
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
if types:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
(etype, state_key)
) if state_key is not None else (
"AND type = ?",
(etype,)
)
for etype, state_key in types
]
if include_other_types:
unique_types = set(filtered_types)
clause_to_args.append(
(
"AND type <> ? " * len(unique_types),
list(unique_types)
)
)
else:
# If types is None we fetch all the state, and so just use an
# empty where clause with no extra args.
clause_to_args = [("", [])]
for where_clause, where_args in clause_to_args:
args = [room_id]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
results[key] = event_id
return results
return self.runInteraction(
"get_filtered_current_state_ids",
_get_filtered_current_state_ids_txn,
)
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@ -185,7 +310,7 @@ class StateGroupWorkerStore(SQLBaseStore):
})
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types):
def _get_state_groups_from_groups(self, groups, types, members=None):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
@ -194,6 +319,9 @@ class StateGroupWorkerStore(SQLBaseStore):
types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned.
members (bool|None): If not None, then, in addition to any filtering
implied by types, the results are also filtered to only include
member events (if True), or to exclude member events (if False)
Returns:
dictionary state_group -> (dict of (type, state_key) -> event id)
@ -204,14 +332,14 @@ class StateGroupWorkerStore(SQLBaseStore):
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, types,
self._get_state_groups_from_groups_txn, chunk, types, members,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
self, txn, groups, types=None,
self, txn, groups, types=None, members=None,
):
results = {group: {} for group in groups}
@ -249,6 +377,11 @@ class StateGroupWorkerStore(SQLBaseStore):
%s
""")
if members is True:
sql += " AND type = '%s'" % (EventTypes.Member,)
elif members is False:
sql += " AND type <> '%s'" % (EventTypes.Member,)
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
@ -296,6 +429,11 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
where_clause = ""
if members is True:
where_clause += " AND type = '%s'" % EventTypes.Member
elif members is False:
where_clause += " AND type <> '%s'" % EventTypes.Member
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
@ -362,8 +500,7 @@ class StateGroupWorkerStore(SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
deferred: A list of dicts corresponding to the event_ids given.
The dicts are mappings from (type, state_key) -> state_events
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@ -391,7 +528,8 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
Get the state dicts corresponding to a list of events
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
@ -404,7 +542,7 @@ class StateGroupWorkerStore(SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
A deferred dict from event_id -> (type, state_key) -> state_event
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@ -490,10 +628,11 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types, filtered_types=None):
def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
cache(DictionaryCache): the state group cache to use
group(int): The state group to lookup
types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
@ -507,11 +646,11 @@ class StateGroupWorkerStore(SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
is_all, known_absent, state_dict_ids = cache.get(group)
type_to_key = {}
# tracks whether any of ourrequested types are missing from the cache
# tracks whether any of our requested types are missing from the cache
missing_types = False
for typ, state_key in types:
@ -558,7 +697,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if include(k[0], k[1])
}, got_all
def _get_all_state_from_cache(self, group):
def _get_all_state_from_cache(self, cache, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@ -566,9 +705,10 @@ class StateGroupWorkerStore(SQLBaseStore):
cache, if False we need to query the DB for the missing state.
Args:
cache(DictionaryCache): the state group cache to use
group: The state group to lookup
"""
is_all, _, state_dict_ids = self._state_group_cache.get(group)
is_all, _, state_dict_ids = cache.get(group)
return state_dict_ids, is_all
@ -591,6 +731,62 @@ class StateGroupWorkerStore(SQLBaseStore):
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
"""
if types is not None:
non_member_types = [t for t in types if t[0] != EventTypes.Member]
if filtered_types is not None and EventTypes.Member not in filtered_types:
# we want all of the membership events
member_types = None
else:
member_types = [t for t in types if t[0] == EventTypes.Member]
else:
non_member_types = None
member_types = None
non_member_state = yield self._get_state_for_groups_using_cache(
groups, self._state_group_cache, non_member_types, filtered_types,
)
# XXX: we could skip this entirely if member_types is []
member_state = yield self._get_state_for_groups_using_cache(
# we set filtered_types=None as member_state only ever contain members.
groups, self._state_group_members_cache, member_types, None,
)
state = non_member_state
for group in groups:
state[group].update(member_state[group])
defer.returnValue(state)
@defer.inlineCallbacks
def _get_state_for_groups_using_cache(
self, groups, cache, types=None, filtered_types=None
):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
cache (DictionaryCache): the cache of group ids to state dicts which
we will pass through - either the normal state cache or the specific
members state cache.
types (None|iterable[(str, None|str)]):
indicates the state type/keys required. If None, the whole
state is fetched and returned.
Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
@ -602,7 +798,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, got_all = self._get_some_state_from_cache(
group, types, filtered_types
cache, group, types, filtered_types
)
results[group] = state_dict_ids
@ -611,7 +807,7 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
for group in set(groups):
state_dict_ids, got_all = self._get_all_state_from_cache(
group
cache, group
)
results[group] = state_dict_ids
@ -620,8 +816,8 @@ class StateGroupWorkerStore(SQLBaseStore):
missing_groups.append(group)
if missing_groups:
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
# Okay, so we have some missing_types, let's fetch them.
cache_seq_num = cache.sequence
# the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type,
@ -635,7 +831,7 @@ class StateGroupWorkerStore(SQLBaseStore):
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types_to_fetch
missing_groups, types_to_fetch, cache == self._state_group_members_cache,
)
for group, group_state_dict in iteritems(group_to_state_dict):
@ -655,7 +851,7 @@ class StateGroupWorkerStore(SQLBaseStore):
# update the cache with all the things we fetched from the
# database.
self._state_group_cache.update(
cache.update(
cache_seq_num,
key=group,
value=group_state_dict,
@ -757,15 +953,33 @@ class StateGroupWorkerStore(SQLBaseStore):
],
)
# Prefill the state group cache with this group.
# Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
current_member_state_ids = {
s: ev
for (s, ev) in iteritems(current_state_ids)
if s[0] == EventTypes.Member
}
txn.call_after(
self._state_group_members_cache.update,
self._state_group_members_cache.sequence,
key=state_group,
value=dict(current_member_state_ids),
)
current_non_member_state_ids = {
s: ev
for (s, ev) in iteritems(current_state_ids)
if s[0] != EventTypes.Member
}
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
value=dict(current_non_member_state_ids),
)
return state_group

View file

@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
events and a token pointing to the start of the returned
events.
The events returned are in ascending order.
@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
_EventDictReturn and a token pointing to the start of the returned
events.
The events returned are in ascending order.