mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-28 22:36:16 -05:00
Run black on the rest of the storage module (#4996)
This commit is contained in:
parent
3039d61baf
commit
7efd1d87c2
1
changelog.d/4996.misc
Normal file
1
changelog.d/4996.misc
Normal file
@ -0,0 +1 @@
|
||||
Run `black` on the remainder of `synapse/storage/`.
|
@ -61,48 +61,60 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataStore(RoomMemberStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore,
|
||||
PresenceStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
EventsStore,
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
RejectionsStore,
|
||||
FilteringStore,
|
||||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
EndToEndRoomKeyStore,
|
||||
SearchStore,
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
ClientIpStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
UserErasureStore,
|
||||
MonthlyActiveUsersStore,
|
||||
):
|
||||
|
||||
class DataStore(
|
||||
RoomMemberStore,
|
||||
RoomStore,
|
||||
RegistrationStore,
|
||||
StreamStore,
|
||||
ProfileStore,
|
||||
PresenceStore,
|
||||
TransactionStore,
|
||||
DirectoryStore,
|
||||
KeyStore,
|
||||
StateStore,
|
||||
SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
EventsStore,
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
RejectionsStore,
|
||||
FilteringStore,
|
||||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
EndToEndRoomKeyStore,
|
||||
SearchStore,
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
ClientIpStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
UserErasureStore,
|
||||
MonthlyActiveUsersStore,
|
||||
):
|
||||
def __init__(self, db_conn, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
extra_tables=[("local_invites", "stream_id")]
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
extra_tables=[("local_invites", "stream_id")],
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering", step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
self._presence_id_gen = StreamIdGenerator(
|
||||
db_conn, "presence_stream", "stream_id"
|
||||
@ -114,7 +126,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
db_conn, "public_room_list_stream", "stream_id"
|
||||
)
|
||||
self._device_list_id_gen = StreamIdGenerator(
|
||||
db_conn, "device_lists_stream", "stream_id",
|
||||
db_conn, "device_lists_stream", "stream_id"
|
||||
)
|
||||
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
@ -125,16 +137,15 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
|
||||
)
|
||||
self._pushers_id_gen = StreamIdGenerator(
|
||||
db_conn, "pushers", "id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||
)
|
||||
self._group_updates_id_gen = StreamIdGenerator(
|
||||
db_conn, "local_group_updates", "stream_id",
|
||||
db_conn, "local_group_updates", "stream_id"
|
||||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = StreamIdGenerator(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
db_conn, "cache_invalidation_stream", "stream_id"
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
@ -142,72 +153,82 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
presence_cache_prefill, min_presence_val = self._get_cache_dict(
|
||||
db_conn, "presence_stream",
|
||||
db_conn,
|
||||
"presence_stream",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._presence_id_gen.get_current_token(),
|
||||
)
|
||||
self.presence_stream_cache = StreamChangeCache(
|
||||
"PresenceStreamChangeCache", min_presence_val,
|
||||
prefilled_cache=presence_cache_prefill
|
||||
"PresenceStreamChangeCache",
|
||||
min_presence_val,
|
||||
prefilled_cache=presence_cache_prefill,
|
||||
)
|
||||
|
||||
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
||||
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
|
||||
db_conn, "device_inbox",
|
||||
db_conn,
|
||||
"device_inbox",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=max_device_inbox_id,
|
||||
limit=1000,
|
||||
)
|
||||
self._device_inbox_stream_cache = StreamChangeCache(
|
||||
"DeviceInboxStreamChangeCache", min_device_inbox_id,
|
||||
"DeviceInboxStreamChangeCache",
|
||||
min_device_inbox_id,
|
||||
prefilled_cache=device_inbox_prefill,
|
||||
)
|
||||
# The federation outbox and the local device inbox uses the same
|
||||
# stream_id generator.
|
||||
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
|
||||
db_conn, "device_federation_outbox",
|
||||
db_conn,
|
||||
"device_federation_outbox",
|
||||
entity_column="destination",
|
||||
stream_column="stream_id",
|
||||
max_value=max_device_inbox_id,
|
||||
limit=1000,
|
||||
)
|
||||
self._device_federation_outbox_stream_cache = StreamChangeCache(
|
||||
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
|
||||
"DeviceFederationOutboxStreamChangeCache",
|
||||
min_device_outbox_id,
|
||||
prefilled_cache=device_outbox_prefill,
|
||||
)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
self._device_list_stream_cache = StreamChangeCache(
|
||||
"DeviceListStreamChangeCache", device_list_max,
|
||||
"DeviceListStreamChangeCache", device_list_max
|
||||
)
|
||||
self._device_list_federation_stream_cache = StreamChangeCache(
|
||||
"DeviceListFederationStreamChangeCache", device_list_max,
|
||||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
|
||||
db_conn, "current_state_delta_stream",
|
||||
db_conn,
|
||||
"current_state_delta_stream",
|
||||
entity_column="room_id",
|
||||
stream_column="stream_id",
|
||||
max_value=events_max, # As we share the stream id with events token
|
||||
limit=1000,
|
||||
)
|
||||
self._curr_state_delta_stream_cache = StreamChangeCache(
|
||||
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
|
||||
"_curr_state_delta_stream_cache",
|
||||
min_curr_state_delta_id,
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
|
||||
db_conn, "local_group_updates",
|
||||
db_conn,
|
||||
"local_group_updates",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._group_updates_id_gen.get_current_token(),
|
||||
limit=1000,
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache", min_group_updates_id,
|
||||
"_group_updates_stream_cache",
|
||||
min_group_updates_id,
|
||||
prefilled_cache=_group_updates_prefill,
|
||||
)
|
||||
|
||||
@ -250,6 +271,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"""
|
||||
Counts the number of users who used this homeserver in the last 24 hours.
|
||||
"""
|
||||
|
||||
def _count_users(txn):
|
||||
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
|
||||
|
||||
@ -277,6 +299,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
Returns counts globaly for a given user as well as breaking
|
||||
by platform
|
||||
"""
|
||||
|
||||
def _count_r30_users(txn):
|
||||
thirty_days_in_secs = 86400 * 30
|
||||
now = int(self._clock.time())
|
||||
@ -313,8 +336,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"""
|
||||
|
||||
results = {}
|
||||
txn.execute(sql, (thirty_days_ago_in_secs,
|
||||
thirty_days_ago_in_secs))
|
||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||
|
||||
for row in txn:
|
||||
if row[0] == 'unknown':
|
||||
@ -341,8 +363,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
) u
|
||||
"""
|
||||
|
||||
txn.execute(sql, (thirty_days_ago_in_secs,
|
||||
thirty_days_ago_in_secs))
|
||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||
|
||||
count, = txn.fetchone()
|
||||
results['all'] = count
|
||||
@ -356,15 +377,14 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
Returns millisecond unixtime for start of UTC day.
|
||||
"""
|
||||
now = time.gmtime()
|
||||
today_start = calendar.timegm((
|
||||
now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0,
|
||||
))
|
||||
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
|
||||
return today_start * 1000
|
||||
|
||||
def generate_user_daily_visits(self):
|
||||
"""
|
||||
Generates daily visit data for use in cohort/ retention analysis
|
||||
"""
|
||||
|
||||
def _generate_user_daily_visits(txn):
|
||||
logger.info("Calling _generate_user_daily_visits")
|
||||
today_start = self._get_start_of_day()
|
||||
@ -395,25 +415,29 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
# often to minimise this case.
|
||||
if today_start > self._last_user_visit_update:
|
||||
yesterday_start = today_start - a_day_in_milliseconds
|
||||
txn.execute(sql, (
|
||||
yesterday_start, yesterday_start,
|
||||
self._last_user_visit_update, today_start
|
||||
))
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
yesterday_start,
|
||||
yesterday_start,
|
||||
self._last_user_visit_update,
|
||||
today_start,
|
||||
),
|
||||
)
|
||||
self._last_user_visit_update = today_start
|
||||
|
||||
txn.execute(sql, (
|
||||
today_start, today_start,
|
||||
self._last_user_visit_update,
|
||||
now
|
||||
))
|
||||
txn.execute(
|
||||
sql, (today_start, today_start, self._last_user_visit_update, now)
|
||||
)
|
||||
# Update _last_user_visit_update to now. The reason to do this
|
||||
# rather just clamping to the beginning of the day is to limit
|
||||
# the size of the join - meaning that the query can be run more
|
||||
# frequently
|
||||
self._last_user_visit_update = now
|
||||
|
||||
return self.runInteraction("generate_user_daily_visits",
|
||||
_generate_user_daily_visits)
|
||||
return self.runInteraction(
|
||||
"generate_user_daily_visits", _generate_user_daily_visits
|
||||
)
|
||||
|
||||
def get_users(self):
|
||||
"""Function to reterive a list of users in users table.
|
||||
@ -425,12 +449,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
return self._simple_select_list(
|
||||
table="users",
|
||||
keyvalues={},
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin"
|
||||
],
|
||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||
desc="get_users",
|
||||
)
|
||||
|
||||
@ -451,20 +470,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
i_limit = (int)(limit)
|
||||
return self.get_user_list_paginate(
|
||||
table="users",
|
||||
keyvalues={
|
||||
"is_guest": is_guest
|
||||
},
|
||||
pagevalues=[
|
||||
order,
|
||||
i_limit,
|
||||
i_start
|
||||
],
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin"
|
||||
],
|
||||
keyvalues={"is_guest": is_guest},
|
||||
pagevalues=[order, i_limit, i_start],
|
||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||
desc="get_users_paginate",
|
||||
)
|
||||
|
||||
@ -482,12 +490,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
table="users",
|
||||
term=term,
|
||||
col="name",
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin"
|
||||
],
|
||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||
desc="search_users",
|
||||
)
|
||||
|
||||
|
@ -41,7 +41,7 @@ try:
|
||||
MAX_TXN_ID = sys.maxint - 1
|
||||
except AttributeError:
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2**63 - 1
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
@ -76,12 +76,18 @@ class LoggingTransaction(object):
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method."""
|
||||
|
||||
__slots__ = [
|
||||
"txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
|
||||
"txn",
|
||||
"name",
|
||||
"database_engine",
|
||||
"after_callbacks",
|
||||
"exception_callbacks",
|
||||
]
|
||||
|
||||
def __init__(self, txn, name, database_engine, after_callbacks,
|
||||
exception_callbacks):
|
||||
def __init__(
|
||||
self, txn, name, database_engine, after_callbacks, exception_callbacks
|
||||
):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
@ -110,6 +116,7 @@ class LoggingTransaction(object):
|
||||
def execute_batch(self, sql, args):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import execute_batch
|
||||
|
||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
||||
else:
|
||||
for val in args:
|
||||
@ -134,10 +141,7 @@ class LoggingTransaction(object):
|
||||
sql = self.database_engine.convert_param_style(sql)
|
||||
if args:
|
||||
try:
|
||||
sql_logger.debug(
|
||||
"[SQL values] {%s} %r",
|
||||
self.name, args[0]
|
||||
)
|
||||
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
|
||||
except Exception:
|
||||
# Don't let logging failures stop SQL from working
|
||||
pass
|
||||
@ -145,9 +149,7 @@ class LoggingTransaction(object):
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
return func(
|
||||
sql, *args
|
||||
)
|
||||
return func(sql, *args)
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
raise
|
||||
@ -176,11 +178,9 @@ class PerformanceCounters(object):
|
||||
counters = []
|
||||
for name, (count, cum_time) in iteritems(self.current_counters):
|
||||
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
||||
counters.append((
|
||||
(cum_time - prev_time) / interval_duration,
|
||||
count - prev_count,
|
||||
name
|
||||
))
|
||||
counters.append(
|
||||
((cum_time - prev_time) / interval_duration, count - prev_count, name)
|
||||
)
|
||||
|
||||
self.previous_counters = dict(self.current_counters)
|
||||
|
||||
@ -212,8 +212,9 @@ class SQLBaseStore(object):
|
||||
self._txn_perf_counters = PerformanceCounters()
|
||||
self._get_event_counters = PerformanceCounters()
|
||||
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
self._get_event_cache = Cache(
|
||||
"*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
|
||||
)
|
||||
|
||||
self._event_fetch_lock = threading.Condition()
|
||||
self._event_fetch_list = []
|
||||
@ -239,7 +240,7 @@ class SQLBaseStore(object):
|
||||
0.0,
|
||||
run_as_background_process,
|
||||
"upsert_safety_check",
|
||||
self._check_safe_to_upsert
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -271,7 +272,7 @@ class SQLBaseStore(object):
|
||||
15.0,
|
||||
run_as_background_process,
|
||||
"upsert_safety_check",
|
||||
self._check_safe_to_upsert
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
def start_profiling(self):
|
||||
@ -298,13 +299,16 @@ class SQLBaseStore(object):
|
||||
|
||||
perf_logger.info(
|
||||
"Total database time: %.3f%% {%s} {%s}",
|
||||
ratio * 100, top_three_counters, top_3_event_counters
|
||||
ratio * 100,
|
||||
top_three_counters,
|
||||
top_3_event_counters,
|
||||
)
|
||||
|
||||
self._clock.looping_call(loop, 10000)
|
||||
|
||||
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
|
||||
func, *args, **kwargs):
|
||||
def _new_transaction(
|
||||
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
|
||||
):
|
||||
start = time.time()
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
@ -312,7 +316,7 @@ class SQLBaseStore(object):
|
||||
# growing really large.
|
||||
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
|
||||
|
||||
name = "%s-%x" % (desc, txn_id, )
|
||||
name = "%s-%x" % (desc, txn_id)
|
||||
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
|
||||
@ -323,7 +327,10 @@ class SQLBaseStore(object):
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
txn = LoggingTransaction(
|
||||
txn, name, self.database_engine, after_callbacks,
|
||||
txn,
|
||||
name,
|
||||
self.database_engine,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
)
|
||||
r = func(txn, *args, **kwargs)
|
||||
@ -334,7 +341,10 @@ class SQLBaseStore(object):
|
||||
# transaction.
|
||||
logger.warning(
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name, exception_to_unicode(e), i, N
|
||||
name,
|
||||
exception_to_unicode(e),
|
||||
i,
|
||||
N,
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
@ -342,8 +352,7 @@ class SQLBaseStore(object):
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warning(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, exception_to_unicode(e1),
|
||||
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
|
||||
)
|
||||
continue
|
||||
raise
|
||||
@ -357,7 +366,8 @@ class SQLBaseStore(object):
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warning(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, exception_to_unicode(e1),
|
||||
name,
|
||||
exception_to_unicode(e1),
|
||||
)
|
||||
continue
|
||||
raise
|
||||
@ -396,16 +406,17 @@ class SQLBaseStore(object):
|
||||
exception_callbacks = []
|
||||
|
||||
if LoggingContext.current_context() == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Starting db txn '%s' from sentinel context",
|
||||
desc,
|
||||
)
|
||||
logger.warn("Starting db txn '%s' from sentinel context", desc)
|
||||
|
||||
try:
|
||||
result = yield self.runWithConnection(
|
||||
self._new_transaction,
|
||||
desc, after_callbacks, exception_callbacks, func,
|
||||
*args, **kwargs
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
@ -434,7 +445,7 @@ class SQLBaseStore(object):
|
||||
parent_context = LoggingContext.current_context()
|
||||
if parent_context == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Starting db connection from sentinel context: metrics will be lost",
|
||||
"Starting db connection from sentinel context: metrics will be lost"
|
||||
)
|
||||
parent_context = None
|
||||
|
||||
@ -453,9 +464,7 @@ class SQLBaseStore(object):
|
||||
return func(conn, *args, **kwargs)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@ -469,9 +478,7 @@ class SQLBaseStore(object):
|
||||
A list of dicts where the key is the column header.
|
||||
"""
|
||||
col_headers = list(intern(str(column[0])) for column in cursor.description)
|
||||
results = list(
|
||||
dict(zip(col_headers, row)) for row in cursor
|
||||
)
|
||||
results = list(dict(zip(col_headers, row)) for row in cursor)
|
||||
return results
|
||||
|
||||
def _execute(self, desc, decoder, query, *args):
|
||||
@ -485,6 +492,7 @@ class SQLBaseStore(object):
|
||||
Returns:
|
||||
The result of decoder(results)
|
||||
"""
|
||||
|
||||
def interaction(txn):
|
||||
txn.execute(query, args)
|
||||
if decoder:
|
||||
@ -498,8 +506,7 @@ class SQLBaseStore(object):
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_insert(self, table, values, or_ignore=False,
|
||||
desc="_simple_insert"):
|
||||
def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
@ -511,10 +518,7 @@ class SQLBaseStore(object):
|
||||
`or_ignore` is True
|
||||
"""
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_insert_txn, table, values,
|
||||
)
|
||||
yield self.runInteraction(desc, self._simple_insert_txn, table, values)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||
# a cursor after we receive an error from the db.
|
||||
@ -530,15 +534,13 @@ class SQLBaseStore(object):
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys),
|
||||
", ".join("?" for _ in keys)
|
||||
", ".join("?" for _ in keys),
|
||||
)
|
||||
|
||||
txn.execute(sql, vals)
|
||||
|
||||
def _simple_insert_many(self, table, values, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_insert_many_txn, table, values
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
|
||||
|
||||
@staticmethod
|
||||
def _simple_insert_many_txn(txn, table, values):
|
||||
@ -553,24 +555,18 @@ class SQLBaseStore(object):
|
||||
#
|
||||
# The sort is to ensure that we don't rely on dictionary iteration
|
||||
# order.
|
||||
keys, vals = zip(*[
|
||||
zip(
|
||||
*(sorted(i.items(), key=lambda kv: kv[0]))
|
||||
)
|
||||
for i in values
|
||||
if i
|
||||
])
|
||||
keys, vals = zip(
|
||||
*[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
|
||||
)
|
||||
|
||||
for k in keys:
|
||||
if k != keys[0]:
|
||||
raise RuntimeError(
|
||||
"All items must have the same keys"
|
||||
)
|
||||
raise RuntimeError("All items must have the same keys")
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys[0]),
|
||||
", ".join("?" for _ in keys[0])
|
||||
", ".join("?" for _ in keys[0]),
|
||||
)
|
||||
|
||||
txn.executemany(sql, vals)
|
||||
@ -583,7 +579,7 @@ class SQLBaseStore(object):
|
||||
values,
|
||||
insertion_values={},
|
||||
desc="_simple_upsert",
|
||||
lock=True
|
||||
lock=True,
|
||||
):
|
||||
"""
|
||||
|
||||
@ -635,13 +631,7 @@ class SQLBaseStore(object):
|
||||
)
|
||||
|
||||
def _simple_upsert_txn(
|
||||
self,
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values={},
|
||||
lock=True,
|
||||
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||
):
|
||||
"""
|
||||
Pick the UPSERT method which works best on the platform. Either the
|
||||
@ -665,11 +655,7 @@ class SQLBaseStore(object):
|
||||
and table not in self._unsafe_to_upsert_tables
|
||||
):
|
||||
return self._simple_upsert_txn_native_upsert(
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||
)
|
||||
else:
|
||||
return self._simple_upsert_txn_emulated(
|
||||
@ -714,7 +700,7 @@ class SQLBaseStore(object):
|
||||
# SELECT instead to see if it exists.
|
||||
sql = "SELECT 1 FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sqlargs = list(keyvalues.values())
|
||||
txn.execute(sql, sqlargs)
|
||||
@ -726,7 +712,7 @@ class SQLBaseStore(object):
|
||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in values),
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sqlargs = list(values.values()) + list(keyvalues.values())
|
||||
|
||||
@ -773,19 +759,14 @@ class SQLBaseStore(object):
|
||||
latter = "NOTHING"
|
||||
else:
|
||||
allvalues.update(values)
|
||||
latter = (
|
||||
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
)
|
||||
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
|
||||
sql = (
|
||||
"INSERT INTO %s (%s) VALUES (%s) "
|
||||
"ON CONFLICT (%s) DO %s"
|
||||
) % (
|
||||
sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues),
|
||||
", ".join(k for k in keyvalues),
|
||||
latter
|
||||
latter,
|
||||
)
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
|
||||
@ -870,8 +851,8 @@ class SQLBaseStore(object):
|
||||
latter = "NOTHING"
|
||||
value_values = [() for x in range(len(key_values))]
|
||||
else:
|
||||
latter = (
|
||||
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
|
||||
latter = "UPDATE SET " + ", ".join(
|
||||
k + "=EXCLUDED." + k for k in value_names
|
||||
)
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
|
||||
@ -889,8 +870,9 @@ class SQLBaseStore(object):
|
||||
|
||||
return txn.execute_batch(sql, args)
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
def _simple_select_one(
|
||||
self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
@ -903,14 +885,17 @@ class SQLBaseStore(object):
|
||||
statement returns no rows
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_txn,
|
||||
table, keyvalues, retcols, allow_none,
|
||||
desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol"):
|
||||
def _simple_select_one_onecol(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol",
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
@ -922,17 +907,18 @@ class SQLBaseStore(object):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_onecol_txn,
|
||||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
def _simple_select_one_onecol_txn(
|
||||
cls, txn, table, keyvalues, retcol, allow_none=False
|
||||
):
|
||||
ret = cls._simple_select_onecol_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcol=retcol,
|
||||
txn, table=table, keyvalues=keyvalues, retcol=retcol
|
||||
)
|
||||
|
||||
if ret:
|
||||
@ -945,12 +931,7 @@ class SQLBaseStore(object):
|
||||
|
||||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
}
|
||||
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
|
||||
|
||||
if keyvalues:
|
||||
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
|
||||
@ -960,8 +941,9 @@ class SQLBaseStore(object):
|
||||
|
||||
return [r[0] for r in txn]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
def _simple_select_onecol(
|
||||
self, table, keyvalues, retcol, desc="_simple_select_onecol"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
@ -974,13 +956,12 @@ class SQLBaseStore(object):
|
||||
Deferred: Results in a list
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
desc, self._simple_select_onecol_txn, table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols,
|
||||
desc="_simple_select_list"):
|
||||
def _simple_select_list(
|
||||
self, table, keyvalues, retcols, desc="_simple_select_list"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@ -994,9 +975,7 @@ class SQLBaseStore(object):
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
desc, self._simple_select_list_txn, table, keyvalues, retcols
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1016,22 +995,26 @@ class SQLBaseStore(object):
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
else:
|
||||
sql = "SELECT %s FROM %s" % (
|
||||
", ".join(retcols),
|
||||
table
|
||||
)
|
||||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||
txn.execute(sql)
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
||||
keyvalues={}, desc="_simple_select_many_batch",
|
||||
batch_size=100):
|
||||
def _simple_select_many_batch(
|
||||
self,
|
||||
table,
|
||||
column,
|
||||
iterable,
|
||||
retcols,
|
||||
keyvalues={},
|
||||
desc="_simple_select_many_batch",
|
||||
batch_size=100,
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
|
||||
it_list = list(iterable)
|
||||
|
||||
chunks = [
|
||||
it_list[i:i + batch_size]
|
||||
for i in range(0, len(it_list), batch_size)
|
||||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
||||
]
|
||||
for chunk in chunks:
|
||||
rows = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_many_txn,
|
||||
table, column, chunk, keyvalues, retcols
|
||||
table,
|
||||
column,
|
||||
chunk,
|
||||
keyvalues,
|
||||
retcols,
|
||||
)
|
||||
|
||||
results.extend(rows)
|
||||
@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
|
||||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in iteritems(keyvalues):
|
||||
@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
|
||||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
||||
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
desc, self._simple_update_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
list(updatevalues.values()) + list(keyvalues.values())
|
||||
)
|
||||
txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
|
||||
|
||||
return txn.rowcount
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
desc="_simple_update_one"):
|
||||
def _simple_update_one(
|
||||
self, table, keyvalues, updatevalues, desc="_simple_update_one"
|
||||
):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_one_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
desc, self._simple_update_one_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
@staticmethod
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
txn.execute(select_sql, list(keyvalues.values()))
|
||||
@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_one_txn, table, keyvalues
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_one_txn(txn, table, keyvalues):
|
||||
@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
|
||||
"""
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
def _simple_delete(self, table, keyvalues, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_txn, table, keyvalues
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_txn(txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
return txn.execute(sql, list(keyvalues.values()))
|
||||
@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
|
||||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in iteritems(keyvalues):
|
||||
@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
|
||||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
||||
return txn.execute(sql, values)
|
||||
|
||||
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
|
||||
max_value, limit=100000):
|
||||
def _get_cache_dict(
|
||||
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
|
||||
):
|
||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||
# It doesn't really matter how many we get, the StreamChangeCache will
|
||||
# do the right thing to ensure it respects the max size of cache.
|
||||
@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
|
||||
txn = db_conn.cursor()
|
||||
txn.execute(sql, (int(max_value),))
|
||||
|
||||
cache = {
|
||||
row[0]: int(row[1])
|
||||
for row in txn
|
||||
}
|
||||
cache = {row[0]: int(row[1]) for row in txn}
|
||||
|
||||
txn.close()
|
||||
|
||||
@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
|
||||
# be safe.
|
||||
for chunk in batch_iter(members_changed, 50):
|
||||
keys = itertools.chain([room_id], chunk)
|
||||
self._send_invalidation_to_replication(
|
||||
txn, _CURRENT_STATE_CACHE_NAME, keys,
|
||||
)
|
||||
self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
|
||||
|
||||
def _invalidate_state_caches(self, room_id, members_changed):
|
||||
"""Invalidates caches that are based on the current state, but does
|
||||
@ -1356,22 +1317,12 @@ class SQLBaseStore(object):
|
||||
changed
|
||||
"""
|
||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||
self._attempt_to_invalidate_cache(
|
||||
"is_host_joined", (room_id, host,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"was_host_joined", (room_id, host,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
||||
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_users_in_room", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_room_summary", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_current_state_ids", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
|
||||
|
||||
def _attempt_to_invalidate_cache(self, cache_name, key):
|
||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||
@ -1419,7 +1370,7 @@ class SQLBaseStore(object):
|
||||
"cache_func": cache_name,
|
||||
"keys": list(keys),
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
@ -1435,11 +1386,10 @@ class SQLBaseStore(object):
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit,))
|
||||
txn.execute(sql, (last_id, limit))
|
||||
return txn.fetchall()
|
||||
return self.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
@ -1447,8 +1397,9 @@ class SQLBaseStore(object):
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
|
||||
desc="_simple_select_list_paginate"):
|
||||
def _simple_select_list_paginate(
|
||||
self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
|
||||
):
|
||||
"""Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
@ -1468,11 +1419,16 @@ class SQLBaseStore(object):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_paginate_txn,
|
||||
table, keyvalues, pagevalues, retcols
|
||||
table,
|
||||
keyvalues,
|
||||
pagevalues,
|
||||
retcols,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
|
||||
def _simple_select_list_paginate_txn(
|
||||
cls, txn, table, keyvalues, pagevalues, retcols
|
||||
):
|
||||
"""Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
@ -1497,22 +1453,23 @@ class SQLBaseStore(object):
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
" ? ASC LIMIT ? OFFSET ?"
|
||||
" ? ASC LIMIT ? OFFSET ?",
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
|
||||
else:
|
||||
sql = "SELECT %s FROM %s ORDER BY %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" ? ASC LIMIT ? OFFSET ?"
|
||||
" ? ASC LIMIT ? OFFSET ?",
|
||||
)
|
||||
txn.execute(sql, pagevalues)
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
|
||||
desc="get_user_list_paginate"):
|
||||
def get_user_list_paginate(
|
||||
self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
|
||||
):
|
||||
"""Get a list of users from start row to a limit number of rows. This will
|
||||
return a json object with users and total number of users in users list.
|
||||
|
||||
@ -1532,16 +1489,13 @@ class SQLBaseStore(object):
|
||||
users = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_paginate_txn,
|
||||
table, keyvalues, pagevalues, retcols
|
||||
table,
|
||||
keyvalues,
|
||||
pagevalues,
|
||||
retcols,
|
||||
)
|
||||
count = yield self.runInteraction(
|
||||
desc,
|
||||
self.get_user_count_txn
|
||||
)
|
||||
retval = {
|
||||
"users": users,
|
||||
"total": count
|
||||
}
|
||||
count = yield self.runInteraction(desc, self.get_user_count_txn)
|
||||
retval = {"users": users, "total": count}
|
||||
defer.returnValue(retval)
|
||||
|
||||
def get_user_count_txn(self, txn):
|
||||
@ -1556,8 +1510,9 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql_count)
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def _simple_search_list(self, table, term, col, retcols,
|
||||
desc="_simple_search_list"):
|
||||
def _simple_search_list(
|
||||
self, table, term, col, retcols, desc="_simple_search_list"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@ -1572,9 +1527,7 @@ class SQLBaseStore(object):
|
||||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_search_list_txn,
|
||||
table, term, col, retcols
|
||||
desc, self._simple_search_list_txn, table, term, col, retcols
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1593,11 +1546,7 @@ class SQLBaseStore(object):
|
||||
defer.Deferred: resolves to list[dict[str, Any]] or None
|
||||
"""
|
||||
if term:
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
col
|
||||
)
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
|
||||
termvalues = ["%%" + term + "%%"]
|
||||
txn.execute(sql, termvalues)
|
||||
else:
|
||||
@ -1618,6 +1567,7 @@ class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
account_max = self.get_max_account_data_stream_id()
|
||||
self._account_data_stream_cache = StreamChangeCache(
|
||||
"AccountDataAndTagsChangeCache", account_max,
|
||||
"AccountDataAndTagsChangeCache", account_max
|
||||
)
|
||||
|
||||
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
|
||||
@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
def get_account_data_for_user_txn(txn):
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, "account_data", {"user_id": user_id},
|
||||
["account_data_type", "content"]
|
||||
txn,
|
||||
"account_data",
|
||||
{"user_id": user_id},
|
||||
["account_data_type", "content"],
|
||||
)
|
||||
|
||||
global_account_data = {
|
||||
@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
}
|
||||
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, "room_account_data", {"user_id": user_id},
|
||||
["room_id", "account_data_type", "content"]
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id},
|
||||
["room_id", "account_data_type", "content"],
|
||||
)
|
||||
|
||||
by_room = {}
|
||||
@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
result = yield self._simple_select_one_onecol(
|
||||
table="account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"account_data_type": data_type,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "account_data_type": data_type},
|
||||
retcol="content",
|
||||
desc="get_global_account_data_by_type_for_user",
|
||||
allow_none=True,
|
||||
@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
A deferred dict of the room account_data
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_txn(txn):
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
|
||||
["account_data_type", "content"]
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id, "room_id": room_id},
|
||||
["account_data_type", "content"],
|
||||
)
|
||||
|
||||
return {
|
||||
@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
A deferred of the room account_data for that type, or None if
|
||||
there isn't any set.
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_and_type_txn(txn):
|
||||
content_json = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"account_data_type": account_data_type,
|
||||
},
|
||||
retcol="content",
|
||||
allow_none=True
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return json.loads(content_json) if content_json else None
|
||||
|
||||
return self.runInteraction(
|
||||
"get_account_data_for_room_and_type",
|
||||
get_account_data_for_room_and_type_txn,
|
||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||
)
|
||||
|
||||
def get_all_updated_account_data(self, last_global_id, last_room_id,
|
||||
current_id, limit):
|
||||
def get_all_updated_account_data(
|
||||
self, last_global_id, last_room_id, current_id, limit
|
||||
):
|
||||
"""Get all the client account_data that has changed on the server
|
||||
Args:
|
||||
last_global_id(int): The position to fetch from for top level data
|
||||
@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (last_room_id, current_id, limit))
|
||||
room_results = txn.fetchall()
|
||||
return (global_results, room_results)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_updated_account_data_txn", get_updated_account_data_txn
|
||||
)
|
||||
@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
txn.execute(sql, (user_id, stream_id))
|
||||
|
||||
global_account_data = {
|
||||
row[0]: json.loads(row[1]) for row in txn
|
||||
}
|
||||
global_account_data = {row[0]: json.loads(row[1]) for row in txn}
|
||||
|
||||
sql = (
|
||||
"SELECT room_id, account_data_type, content FROM room_account_data"
|
||||
@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
||||
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
||||
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
|
||||
"m.ignored_user_list", ignorer_user_id,
|
||||
"m.ignored_user_list",
|
||||
ignorer_user_id,
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
if not ignored_account_data:
|
||||
@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
"room_id": room_id,
|
||||
"account_data_type": account_data_type,
|
||||
},
|
||||
values={
|
||||
"stream_id": next_id,
|
||||
"content": content_json,
|
||||
},
|
||||
values={"stream_id": next_id, "content": content_json},
|
||||
lock=False,
|
||||
)
|
||||
|
||||
@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
|
||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self.get_account_data_for_room.invalidate((user_id, room_id,))
|
||||
self.get_account_data_for_room.invalidate((user_id, room_id))
|
||||
self.get_account_data_for_room_and_type.prefill(
|
||||
(user_id, room_id, account_data_type,), content,
|
||||
(user_id, room_id, account_data_type), content
|
||||
)
|
||||
|
||||
result = self._account_data_id_gen.get_current_token()
|
||||
@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
yield self._simple_upsert(
|
||||
desc="add_user_account_data",
|
||||
table="account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"account_data_type": account_data_type,
|
||||
},
|
||||
values={
|
||||
"stream_id": next_id,
|
||||
"content": content_json,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
|
||||
values={"stream_id": next_id, "content": content_json},
|
||||
lock=False,
|
||||
)
|
||||
|
||||
@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
# transaction.
|
||||
yield self._update_max_stream_id(next_id)
|
||||
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, next_id,
|
||||
)
|
||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
(account_data_type, user_id,)
|
||||
(account_data_type, user_id)
|
||||
)
|
||||
|
||||
result = self._account_data_id_gen.get_current_token()
|
||||
@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
Args:
|
||||
next_id(int): The the revision to advance to.
|
||||
"""
|
||||
|
||||
def _update(txn):
|
||||
update_max_id_sql = (
|
||||
"UPDATE account_data_max_stream_id"
|
||||
@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
" WHERE stream_id < ?"
|
||||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
return self.runInteraction(
|
||||
"update_account_data_max_stream_id",
|
||||
_update,
|
||||
)
|
||||
|
||||
return self.runInteraction("update_account_data_max_stream_id", _update)
|
||||
|
@ -51,8 +51,7 @@ def _make_exclusive_regex(services_cache):
|
||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
self.services_cache = load_appservices(
|
||||
hs.hostname,
|
||||
hs.config.app_service_config_files
|
||||
hs.hostname, hs.config.app_service_config_files
|
||||
)
|
||||
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
||||
|
||||
@ -122,8 +121,9 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
||||
pass
|
||||
|
||||
|
||||
class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
EventsWorkerStore):
|
||||
class ApplicationServiceTransactionWorkerStore(
|
||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||
):
|
||||
@defer.inlineCallbacks
|
||||
def get_appservices_by_state(self, state):
|
||||
"""Get a list of application services based on their state.
|
||||
@ -135,9 +135,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
may be empty.
|
||||
"""
|
||||
results = yield self._simple_select_list(
|
||||
"application_services_state",
|
||||
dict(state=state),
|
||||
["as_id"]
|
||||
"application_services_state", dict(state=state), ["as_id"]
|
||||
)
|
||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||
as_list = self.get_app_services()
|
||||
@ -180,9 +178,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
A Deferred which resolves when the state was set successfully.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
"application_services_state",
|
||||
dict(as_id=service.id),
|
||||
dict(state=state)
|
||||
"application_services_state", dict(as_id=service.id), dict(state=state)
|
||||
)
|
||||
|
||||
def create_appservice_txn(self, service, events):
|
||||
@ -195,6 +191,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
Returns:
|
||||
AppServiceTransaction: A new transaction.
|
||||
"""
|
||||
|
||||
def _create_appservice_txn(txn):
|
||||
# work out new txn id (highest txn id for this service += 1)
|
||||
# The highest id may be the last one sent (in which case it is last_txn)
|
||||
@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
|
||||
txn.execute(
|
||||
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
||||
(service.id,)
|
||||
(service.id,),
|
||||
)
|
||||
highest_txn_id = txn.fetchone()[0]
|
||||
if highest_txn_id is None:
|
||||
@ -217,16 +214,11 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
txn.execute(
|
||||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
||||
"VALUES(?,?,?)",
|
||||
(service.id, new_txn_id, event_ids)
|
||||
)
|
||||
return AppServiceTransaction(
|
||||
service=service, id=new_txn_id, events=events
|
||||
(service.id, new_txn_id, event_ids),
|
||||
)
|
||||
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
|
||||
|
||||
return self.runInteraction(
|
||||
"create_appservice_txn",
|
||||
_create_appservice_txn,
|
||||
)
|
||||
return self.runInteraction("create_appservice_txn", _create_appservice_txn)
|
||||
|
||||
def complete_appservice_txn(self, txn_id, service):
|
||||
"""Completes an application service transaction.
|
||||
@ -252,26 +244,26 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
"appservice: Completing a transaction which has an ID > 1 from "
|
||||
"the last ID sent to this AS. We've either dropped events or "
|
||||
"sent it to the AS out of order. FIX ME. last_txn=%s "
|
||||
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
|
||||
service.id
|
||||
"completing_txn=%s service_id=%s",
|
||||
last_txn_id,
|
||||
txn_id,
|
||||
service.id,
|
||||
)
|
||||
|
||||
# Set current txn_id for AS to 'txn_id'
|
||||
self._simple_upsert_txn(
|
||||
txn, "application_services_state", dict(as_id=service.id),
|
||||
dict(last_txn=txn_id)
|
||||
txn,
|
||||
"application_services_state",
|
||||
dict(as_id=service.id),
|
||||
dict(last_txn=txn_id),
|
||||
)
|
||||
|
||||
# Delete txn
|
||||
self._simple_delete_txn(
|
||||
txn, "application_services_txns",
|
||||
dict(txn_id=txn_id, as_id=service.id)
|
||||
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"complete_appservice_txn",
|
||||
_complete_appservice_txn,
|
||||
)
|
||||
return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_oldest_unsent_txn(self, service):
|
||||
@ -284,13 +276,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
A Deferred which resolves to an AppServiceTransaction or
|
||||
None.
|
||||
"""
|
||||
|
||||
def _get_oldest_unsent_txn(txn):
|
||||
# Monotonically increasing txn ids, so just select the smallest
|
||||
# one in the txns table (we delete them when they are sent)
|
||||
txn.execute(
|
||||
"SELECT * FROM application_services_txns WHERE as_id=?"
|
||||
" ORDER BY txn_id ASC LIMIT 1",
|
||||
(service.id,)
|
||||
(service.id,),
|
||||
)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
@ -301,8 +294,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
return entry
|
||||
|
||||
entry = yield self.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn",
|
||||
_get_oldest_unsent_txn,
|
||||
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
|
||||
)
|
||||
|
||||
if not entry:
|
||||
@ -312,14 +304,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue(AppServiceTransaction(
|
||||
service=service, id=entry["txn_id"], events=events
|
||||
))
|
||||
defer.returnValue(
|
||||
AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id):
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
(service_id,)
|
||||
(service_id,),
|
||||
)
|
||||
last_txn_id = txn.fetchone()
|
||||
if last_txn_id is None or last_txn_id[0] is None: # no row exists
|
||||
@ -332,6 +324,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
@ -362,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.runInteraction(
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn
|
||||
)
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
@ -94,16 +94,13 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
self._all_done = False
|
||||
|
||||
def start_doing_background_updates(self):
|
||||
run_as_background_process(
|
||||
"background_updates", self._run_background_updates,
|
||||
)
|
||||
run_as_background_process("background_updates", self._run_background_updates)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _run_background_updates(self):
|
||||
logger.info("Starting background schema updates")
|
||||
while True:
|
||||
yield self.hs.get_clock().sleep(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||
yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
|
||||
|
||||
try:
|
||||
result = yield self.do_next_background_update(
|
||||
@ -187,8 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_background_update(self, update_name, desired_duration_ms):
|
||||
logger.info("Starting update batch on background update '%s'",
|
||||
update_name)
|
||||
logger.info("Starting update batch on background update '%s'", update_name)
|
||||
|
||||
update_handler = self._background_update_handlers[update_name]
|
||||
|
||||
@ -210,7 +206,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
progress_json = yield self._simple_select_one_onecol(
|
||||
"background_updates",
|
||||
keyvalues={"update_name": update_name},
|
||||
retcol="progress_json"
|
||||
retcol="progress_json",
|
||||
)
|
||||
|
||||
progress = json.loads(progress_json)
|
||||
@ -224,7 +220,9 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
logger.info(
|
||||
"Updating %r. Updated %r items in %rms."
|
||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
|
||||
update_name, items_updated, duration_ms,
|
||||
update_name,
|
||||
items_updated,
|
||||
duration_ms,
|
||||
performance.total_items_per_ms(),
|
||||
performance.average_items_per_ms(),
|
||||
performance.total_item_count,
|
||||
@ -264,6 +262,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
Args:
|
||||
update_name (str): Name of update
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def noop_update(progress, batch_size):
|
||||
yield self._end_background_update(update_name)
|
||||
@ -271,10 +270,16 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
self.register_background_update_handler(update_name, noop_update)
|
||||
|
||||
def register_background_index_update(self, update_name, index_name,
|
||||
table, columns, where_clause=None,
|
||||
unique=False,
|
||||
psql_only=False):
|
||||
def register_background_index_update(
|
||||
self,
|
||||
update_name,
|
||||
index_name,
|
||||
table,
|
||||
columns,
|
||||
where_clause=None,
|
||||
unique=False,
|
||||
psql_only=False,
|
||||
):
|
||||
"""Helper for store classes to do a background index addition
|
||||
|
||||
To use:
|
||||
@ -320,7 +325,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
"name": index_name,
|
||||
"table": table,
|
||||
"columns": ", ".join(columns),
|
||||
"where_clause": "WHERE " + where_clause if where_clause else ""
|
||||
"where_clause": "WHERE " + where_clause if where_clause else "",
|
||||
}
|
||||
logger.debug("[SQL] %s", sql)
|
||||
c.execute(sql)
|
||||
@ -387,7 +392,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
return self._simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": update_name, "progress_json": progress_json}
|
||||
{"update_name": update_name, "progress_json": progress_json},
|
||||
)
|
||||
|
||||
def _end_background_update(self, update_name):
|
||||
|
@ -37,9 +37,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
keylen=4,
|
||||
max_entries=50000 * CACHE_SIZE_FACTOR,
|
||||
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
|
||||
)
|
||||
|
||||
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||
@ -66,13 +64,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
)
|
||||
|
||||
self.register_background_update_handler(
|
||||
"user_ips_analyze",
|
||||
self._analyze_user_ip,
|
||||
"user_ips_analyze", self._analyze_user_ip
|
||||
)
|
||||
|
||||
self.register_background_update_handler(
|
||||
"user_ips_remove_dupes",
|
||||
self._remove_user_ip_dupes,
|
||||
"user_ips_remove_dupes", self._remove_user_ip_dupes
|
||||
)
|
||||
|
||||
# Register a unique index
|
||||
@ -86,8 +82,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
# Drop the old non-unique index
|
||||
self.register_background_update_handler(
|
||||
"user_ips_drop_nonunique_index",
|
||||
self._remove_user_ip_nonunique,
|
||||
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
|
||||
)
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
@ -104,9 +99,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS user_ips_user_ip"
|
||||
)
|
||||
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(f)
|
||||
@ -124,9 +117,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def user_ips_analyze(txn):
|
||||
txn.execute("ANALYZE user_ips")
|
||||
|
||||
yield self.runInteraction(
|
||||
"user_ips_analyze", user_ips_analyze
|
||||
)
|
||||
yield self.runInteraction("user_ips_analyze", user_ips_analyze)
|
||||
|
||||
yield self._end_background_update("user_ips_analyze")
|
||||
|
||||
@ -151,7 +142,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
LIMIT 1
|
||||
OFFSET ?
|
||||
""",
|
||||
(begin_last_seen, batch_size)
|
||||
(begin_last_seen, batch_size),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
@ -169,7 +160,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
logger.info(
|
||||
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
|
||||
begin_last_seen, end_last_seen,
|
||||
begin_last_seen,
|
||||
end_last_seen,
|
||||
)
|
||||
|
||||
def remove(txn):
|
||||
@ -207,8 +199,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
INNER JOIN user_ips USING (user_id, access_token, ip)
|
||||
GROUP BY user_id, access_token, ip
|
||||
HAVING count(*) > 1
|
||||
""".format(clause),
|
||||
args
|
||||
""".format(
|
||||
clause
|
||||
),
|
||||
args,
|
||||
)
|
||||
res = txn.fetchall()
|
||||
|
||||
@ -254,7 +248,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
DELETE FROM user_ips
|
||||
WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
|
||||
""",
|
||||
(user_id, access_token, ip, last_seen)
|
||||
(user_id, access_token, ip, last_seen),
|
||||
)
|
||||
if txn.rowcount == count - 1:
|
||||
# We deleted all but one of the duplicate rows, i.e. there
|
||||
@ -263,7 +257,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
continue
|
||||
elif txn.rowcount >= count:
|
||||
raise Exception(
|
||||
"We deleted more duplicate rows from 'user_ips' than expected",
|
||||
"We deleted more duplicate rows from 'user_ips' than expected"
|
||||
)
|
||||
|
||||
# The previous step didn't delete enough rows, so we fallback to
|
||||
@ -275,7 +269,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
DELETE FROM user_ips
|
||||
WHERE user_id = ? AND access_token = ? AND ip = ?
|
||||
""",
|
||||
(user_id, access_token, ip)
|
||||
(user_id, access_token, ip),
|
||||
)
|
||||
|
||||
# Add in one to be the last_seen
|
||||
@ -285,7 +279,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen)
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen),
|
||||
)
|
||||
|
||||
self._background_update_progress_txn(
|
||||
@ -300,8 +294,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
defer.returnValue(batch_size)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
|
||||
now=None):
|
||||
def insert_client_ip(
|
||||
self, user_id, access_token, ip, user_agent, device_id, now=None
|
||||
):
|
||||
if not now:
|
||||
now = int(self._clock.time_msec())
|
||||
key = (user_id, access_token, ip)
|
||||
@ -329,13 +324,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
return self.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn,
|
||||
to_update,
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
)
|
||||
|
||||
return run_as_background_process(
|
||||
"update_client_ips", update,
|
||||
)
|
||||
return run_as_background_process("update_client_ips", update)
|
||||
|
||||
def _update_client_ips_batch_txn(self, txn, to_update):
|
||||
if "user_ips" in self._unsafe_to_upsert_tables or (
|
||||
@ -383,7 +375,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
res = yield self.runInteraction(
|
||||
"get_last_client_ip_by_device",
|
||||
self._get_last_client_ip_by_device_txn,
|
||||
user_id, device_id,
|
||||
user_id,
|
||||
device_id,
|
||||
retcols=(
|
||||
"user_id",
|
||||
"access_token",
|
||||
@ -416,7 +409,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
bindings = []
|
||||
if device_id is None:
|
||||
where_clauses.append("user_id = ?")
|
||||
bindings.extend((user_id, ))
|
||||
bindings.extend((user_id,))
|
||||
else:
|
||||
where_clauses.append("(user_id = ? AND device_id = ?)")
|
||||
bindings.extend((user_id, device_id))
|
||||
@ -428,9 +421,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
|
||||
"WHERE %(where)s "
|
||||
"GROUP BY user_id, device_id"
|
||||
) % {
|
||||
"where": " OR ".join(where_clauses),
|
||||
}
|
||||
) % {"where": " OR ".join(where_clauses)}
|
||||
|
||||
sql = (
|
||||
"SELECT %(retcols)s FROM user_ips "
|
||||
@ -462,9 +453,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
rows = yield self._simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=[
|
||||
"access_token", "ip", "user_agent", "last_seen"
|
||||
],
|
||||
retcols=["access_token", "ip", "user_agent", "last_seen"],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
|
||||
@ -472,12 +461,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
|
||||
for row in rows
|
||||
)
|
||||
defer.returnValue(list(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||
))
|
||||
defer.returnValue(
|
||||
list(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||
)
|
||||
)
|
||||
|
@ -57,9 +57,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
user_id, device_id, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
txn.execute(
|
||||
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
|
||||
)
|
||||
messages = []
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
return (messages, stream_pos)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_new_messages_for_device", get_new_messages_for_device_txn,
|
||||
"get_new_messages_for_device", get_new_messages_for_device_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -146,9 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
destination, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
|
||||
messages = []
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
@ -172,6 +170,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
A deferred that resolves when the messages have been deleted.
|
||||
"""
|
||||
|
||||
def delete_messages_for_remote_destination_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM device_federation_outbox"
|
||||
@ -181,8 +180,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (destination, up_to_stream_id))
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_device_msgs_for_remote",
|
||||
delete_messages_for_remote_destination_txn
|
||||
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
|
||||
)
|
||||
|
||||
|
||||
@ -200,8 +198,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
)
|
||||
|
||||
self.register_background_update_handler(
|
||||
self.DEVICE_INBOX_STREAM_ID,
|
||||
self._background_drop_index_device_inbox,
|
||||
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
||||
)
|
||||
|
||||
# Map of (user_id, device_id) to the last stream_id that has been
|
||||
@ -214,8 +211,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
|
||||
remote_messages_by_destination):
|
||||
def add_messages_to_device_inbox(
|
||||
self, local_messages_by_user_then_device, remote_messages_by_destination
|
||||
):
|
||||
"""Used to send messages from this server.
|
||||
|
||||
Args:
|
||||
@ -252,15 +250,10 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
yield self.runInteraction(
|
||||
"add_messages_to_device_inbox",
|
||||
add_messages_txn,
|
||||
now_ms,
|
||||
stream_id,
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
)
|
||||
for user_id in local_messages_by_user_then_device.keys():
|
||||
self._device_inbox_stream_cache.entity_has_changed(
|
||||
user_id, stream_id
|
||||
)
|
||||
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
|
||||
for destination in remote_messages_by_destination.keys():
|
||||
self._device_federation_outbox_stream_cache.entity_has_changed(
|
||||
destination, stream_id
|
||||
@ -277,7 +270,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
# origin. This can happen if the origin doesn't receive our
|
||||
# acknowledgement from the first time we received the message.
|
||||
already_inserted = self._simple_select_one_txn(
|
||||
txn, table="device_federation_inbox",
|
||||
txn,
|
||||
table="device_federation_inbox",
|
||||
keyvalues={"origin": origin, "message_id": message_id},
|
||||
retcols=("message_id",),
|
||||
allow_none=True,
|
||||
@ -288,7 +282,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
# Add an entry for this message_id so that we know we've processed
|
||||
# it.
|
||||
self._simple_insert_txn(
|
||||
txn, table="device_federation_inbox",
|
||||
txn,
|
||||
table="device_federation_inbox",
|
||||
values={
|
||||
"origin": origin,
|
||||
"message_id": message_id,
|
||||
@ -311,19 +306,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
stream_id,
|
||||
)
|
||||
for user_id in local_messages_by_user_then_device.keys():
|
||||
self._device_inbox_stream_cache.entity_has_changed(
|
||||
user_id, stream_id
|
||||
)
|
||||
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
|
||||
|
||||
defer.returnValue(stream_id)
|
||||
|
||||
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
|
||||
messages_by_user_then_device):
|
||||
sql = (
|
||||
"UPDATE device_max_stream_id"
|
||||
" SET stream_id = ?"
|
||||
" WHERE stream_id < ?"
|
||||
)
|
||||
def _add_messages_to_local_device_inbox_txn(
|
||||
self, txn, stream_id, messages_by_user_then_device
|
||||
):
|
||||
sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
|
||||
txn.execute(sql, (stream_id, stream_id))
|
||||
|
||||
local_by_user_then_device = {}
|
||||
@ -332,10 +322,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
devices = list(messages_by_device.keys())
|
||||
if len(devices) == 1 and devices[0] == "*":
|
||||
# Handle wildcard device_ids.
|
||||
sql = (
|
||||
"SELECT device_id FROM devices"
|
||||
" WHERE user_id = ?"
|
||||
)
|
||||
sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
|
||||
txn.execute(sql, (user_id,))
|
||||
message_json = json.dumps(messages_by_device["*"])
|
||||
for row in txn:
|
||||
@ -428,9 +415,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||
def reindex_txn(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS device_inbox_stream_id"
|
||||
)
|
||||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(reindex_txn)
|
||||
|
@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_devices_by_user"
|
||||
desc="get_devices_by_user",
|
||||
)
|
||||
|
||||
defer.returnValue({d["device_id"]: d for d in devices})
|
||||
@ -87,21 +87,23 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
return (now_stream_id, [])
|
||||
|
||||
return self.runInteraction(
|
||||
"get_devices_by_remote", self._get_devices_by_remote_txn,
|
||||
destination, from_stream_id, now_stream_id,
|
||||
"get_devices_by_remote",
|
||||
self._get_devices_by_remote_txn,
|
||||
destination,
|
||||
from_stream_id,
|
||||
now_stream_id,
|
||||
)
|
||||
|
||||
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
|
||||
now_stream_id):
|
||||
def _get_devices_by_remote_txn(
|
||||
self, txn, destination, from_stream_id, now_stream_id
|
||||
):
|
||||
sql = """
|
||||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||
GROUP BY user_id, device_id
|
||||
LIMIT 20
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (destination, from_stream_id, now_stream_id, False)
|
||||
)
|
||||
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
|
||||
|
||||
# maps (user_id, device_id) -> stream_id
|
||||
query_map = {(r[0], r[1]): r[2] for r in txn}
|
||||
@ -112,7 +114,10 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
|
||||
txn,
|
||||
query_map.keys(),
|
||||
include_all_devices=True,
|
||||
include_deleted_devices=True,
|
||||
)
|
||||
|
||||
prev_sent_id_sql = """
|
||||
@ -157,8 +162,10 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
|
||||
destination, stream_id,
|
||||
"mark_as_sent_devices_by_remote",
|
||||
self._mark_as_sent_devices_by_remote_txn,
|
||||
destination,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
||||
@ -173,7 +180,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
WHERE destination = ? AND o.stream_id <= ?
|
||||
GROUP BY user_id
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id,))
|
||||
txn.execute(sql, (destination, stream_id))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = """
|
||||
@ -181,16 +188,14 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
SET stream_id = ?
|
||||
WHERE destination = ? AND user_id = ?
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((row[1], destination, row[0],) for row in rows if row[2])
|
||||
)
|
||||
txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
|
||||
|
||||
sql = """
|
||||
INSERT INTO device_lists_outbound_last_success
|
||||
(destination, user_id, stream_id) VALUES (?, ?, ?)
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
|
||||
sql, ((destination, row[0], row[1]) for row in rows if not row[2])
|
||||
)
|
||||
|
||||
# Delete all sent outbound pokes
|
||||
@ -198,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
DELETE FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND stream_id <= ?
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id,))
|
||||
txn.execute(sql, (destination, stream_id))
|
||||
|
||||
def get_device_stream_token(self):
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
@ -240,10 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
def _get_cached_user_device(self, user_id, device_id):
|
||||
content = yield self._simple_select_one_onecol(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
retcol="content",
|
||||
desc="_get_cached_user_device",
|
||||
)
|
||||
@ -253,16 +255,13 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
def _get_cached_devices_for_user(self, user_id):
|
||||
devices = yield self._simple_select_list(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("device_id", "content"),
|
||||
desc="_get_cached_devices_for_user",
|
||||
)
|
||||
defer.returnValue({
|
||||
device["device_id"]: db_to_json(device["content"])
|
||||
for device in devices
|
||||
})
|
||||
defer.returnValue(
|
||||
{device["device_id"]: db_to_json(device["content"]) for device in devices}
|
||||
)
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
@ -272,7 +271,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
self._get_devices_with_keys_by_user_txn, user_id,
|
||||
self._get_devices_with_keys_by_user_txn,
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
||||
@ -286,9 +286,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in iteritems(user_devices):
|
||||
result = {
|
||||
"device_id": device_id,
|
||||
}
|
||||
result = {"device_id": device_id}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
@ -315,7 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
sql = """
|
||||
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
|
||||
"""
|
||||
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
|
||||
rows = yield self._execute(
|
||||
"get_user_whose_devices_changed", None, sql, from_key
|
||||
)
|
||||
defer.returnValue(set(row[0] for row in rows))
|
||||
|
||||
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
|
||||
@ -333,8 +333,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
GROUP BY user_id, destination
|
||||
"""
|
||||
return self._execute(
|
||||
"get_all_device_list_changes_for_remotes", None,
|
||||
sql, from_key, to_key
|
||||
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
@ -350,21 +349,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids", inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids",
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
retcols=("user_id", "stream_id",),
|
||||
retcols=("user_id", "stream_id"),
|
||||
desc="get_device_list_last_stream_id_for_remotes",
|
||||
)
|
||||
|
||||
results = {user_id: None for user_id in user_ids}
|
||||
results.update({
|
||||
row["user_id"]: row["stream_id"] for row in rows
|
||||
})
|
||||
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
@ -376,14 +376,10 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||
# the device exists.
|
||||
self.device_id_exists_cache = Cache(
|
||||
name="device_id_exists",
|
||||
keylen=2,
|
||||
max_entries=10000,
|
||||
name="device_id_exists", keylen=2, max_entries=10000
|
||||
)
|
||||
|
||||
self._clock.looping_call(
|
||||
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
|
||||
)
|
||||
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
||||
|
||||
self.register_background_index_update(
|
||||
"device_lists_stream_idx",
|
||||
@ -417,8 +413,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_device(self, user_id, device_id,
|
||||
initial_device_display_name):
|
||||
def store_device(self, user_id, device_id, initial_device_display_name):
|
||||
"""Ensure the given device is known; add it to the store if not
|
||||
|
||||
Args:
|
||||
@ -440,7 +435,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"display_name": initial_device_display_name
|
||||
"display_name": initial_device_display_name,
|
||||
},
|
||||
desc="store_device",
|
||||
or_ignore=True,
|
||||
@ -448,12 +443,17 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
self.device_id_exists_cache.prefill(key, True)
|
||||
defer.returnValue(inserted)
|
||||
except Exception as e:
|
||||
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
|
||||
" display_name=%s(%r) failed: %s",
|
||||
type(device_id).__name__, device_id,
|
||||
type(user_id).__name__, user_id,
|
||||
type(initial_device_display_name).__name__,
|
||||
initial_device_display_name, e)
|
||||
logger.error(
|
||||
"store_device with device_id=%s(%r) user_id=%s(%r)"
|
||||
" display_name=%s(%r) failed: %s",
|
||||
type(device_id).__name__,
|
||||
device_id,
|
||||
type(user_id).__name__,
|
||||
user_id,
|
||||
type(initial_device_display_name).__name__,
|
||||
initial_device_display_name,
|
||||
e,
|
||||
)
|
||||
raise StoreError(500, "Problem storing device.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -525,15 +525,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
"""
|
||||
yield self._simple_delete(
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="mark_remote_user_device_list_as_unsubscribed",
|
||||
)
|
||||
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
|
||||
|
||||
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
|
||||
stream_id):
|
||||
def update_remote_device_list_cache_entry(
|
||||
self, user_id, device_id, content, stream_id
|
||||
):
|
||||
"""Updates a single device in the cache of a remote user's devicelist.
|
||||
|
||||
Note: assumes that we are the only thread that can be updating this user's
|
||||
@ -551,42 +550,35 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
return self.runInteraction(
|
||||
"update_remote_device_list_cache_entry",
|
||||
self._update_remote_device_list_cache_entry_txn,
|
||||
user_id, device_id, content, stream_id,
|
||||
user_id,
|
||||
device_id,
|
||||
content,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
|
||||
content, stream_id):
|
||||
def _update_remote_device_list_cache_entry_txn(
|
||||
self, txn, user_id, device_id, content, stream_id
|
||||
):
|
||||
if content.get("deleted"):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self.device_id_exists_cache.invalidate, (user_id, device_id,)
|
||||
)
|
||||
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
|
||||
else:
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"content": json.dumps(content),
|
||||
},
|
||||
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
values={"content": json.dumps(content)},
|
||||
# we don't need to lock, because we assume we are the only thread
|
||||
# updating this user's devices.
|
||||
lock=False,
|
||||
)
|
||||
|
||||
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
|
||||
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
|
||||
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
|
||||
txn.call_after(
|
||||
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
|
||||
@ -595,13 +587,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
|
||||
keyvalues={"user_id": user_id},
|
||||
values={"stream_id": stream_id},
|
||||
# again, we can assume we are the only thread updating this user's
|
||||
# extremity.
|
||||
lock=False,
|
||||
@ -624,17 +611,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
return self.runInteraction(
|
||||
"update_remote_device_list_cache",
|
||||
self._update_remote_device_list_cache_txn,
|
||||
user_id, devices, stream_id,
|
||||
user_id,
|
||||
devices,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
|
||||
stream_id):
|
||||
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
@ -647,7 +631,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
"content": json.dumps(content),
|
||||
}
|
||||
for content in devices
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
|
||||
@ -659,13 +643,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
|
||||
keyvalues={"user_id": user_id},
|
||||
values={"stream_id": stream_id},
|
||||
# we don't need to lock, because we can assume we are the only thread
|
||||
# updating this user's extremity.
|
||||
lock=False,
|
||||
@ -678,8 +657,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
"""
|
||||
with self._device_list_id_gen.get_next() as stream_id:
|
||||
yield self.runInteraction(
|
||||
"add_device_change_to_streams", self._add_device_change_txn,
|
||||
user_id, device_ids, hosts, stream_id,
|
||||
"add_device_change_to_streams",
|
||||
self._add_device_change_txn,
|
||||
user_id,
|
||||
device_ids,
|
||||
hosts,
|
||||
stream_id,
|
||||
)
|
||||
defer.returnValue(stream_id)
|
||||
|
||||
@ -687,13 +670,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
now = self._clock.time_msec()
|
||||
|
||||
txn.call_after(
|
||||
self._device_list_stream_cache.entity_has_changed,
|
||||
user_id, stream_id,
|
||||
self._device_list_stream_cache.entity_has_changed, user_id, stream_id
|
||||
)
|
||||
for host in hosts:
|
||||
txn.call_after(
|
||||
self._device_list_federation_stream_cache.entity_has_changed,
|
||||
host, stream_id,
|
||||
host,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
# Delete older entries in the table, as we really only care about
|
||||
@ -703,20 +686,16 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
DELETE FROM device_lists_stream
|
||||
WHERE user_id = ? AND device_id = ? AND stream_id < ?
|
||||
""",
|
||||
[(user_id, device_id, stream_id) for device_id in device_ids]
|
||||
[(user_id, device_id, stream_id) for device_id in device_ids],
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_stream",
|
||||
values=[
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
}
|
||||
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
|
||||
for device_id in device_ids
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
@ -733,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
}
|
||||
for destination in hosts
|
||||
for device_id in device_ids
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def _prune_old_outbound_device_pokes(self):
|
||||
@ -764,11 +743,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
"""
|
||||
|
||||
txn.executemany(
|
||||
delete_sql,
|
||||
(
|
||||
(yesterday, row[0], row[1], row[2])
|
||||
for row in rows
|
||||
)
|
||||
delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
|
||||
)
|
||||
|
||||
# Since we've deleted unsent deltas, we need to remove the entry
|
||||
@ -792,12 +767,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS device_lists_remote_cache_id"
|
||||
)
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS device_lists_remote_extremeties_id"
|
||||
)
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(f)
|
||||
|
@ -22,10 +22,7 @@ from synapse.util.caches.descriptors import cached
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
RoomAliasMapping = namedtuple(
|
||||
"RoomAliasMapping",
|
||||
("room_id", "room_alias", "servers",)
|
||||
)
|
||||
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
|
||||
|
||||
|
||||
class DirectoryWorkerStore(SQLBaseStore):
|
||||
@ -63,16 +60,12 @@ class DirectoryWorkerStore(SQLBaseStore):
|
||||
defer.returnValue(None)
|
||||
return
|
||||
|
||||
defer.returnValue(
|
||||
RoomAliasMapping(room_id, room_alias.to_string(), servers)
|
||||
)
|
||||
defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
|
||||
|
||||
def get_room_alias_creator(self, room_alias):
|
||||
return self._simple_select_one_onecol(
|
||||
table="room_aliases",
|
||||
keyvalues={
|
||||
"room_alias": room_alias,
|
||||
},
|
||||
keyvalues={"room_alias": room_alias},
|
||||
retcol="creator",
|
||||
desc="get_room_alias_creator",
|
||||
)
|
||||
@ -101,6 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
|
||||
def alias_txn(txn):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
@ -115,10 +109,10 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="room_alias_servers",
|
||||
values=[{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
} for server in servers],
|
||||
values=[
|
||||
{"room_alias": room_alias.to_string(), "server": server}
|
||||
for server in servers
|
||||
],
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
@ -126,9 +120,7 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
)
|
||||
|
||||
try:
|
||||
ret = yield self.runInteraction(
|
||||
"create_room_alias_association", alias_txn
|
||||
)
|
||||
ret = yield self.runInteraction("create_room_alias_association", alias_txn)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise SynapseError(
|
||||
409, "Room alias %s already exists" % room_alias.to_string()
|
||||
@ -138,9 +130,7 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_alias(self, room_alias):
|
||||
room_id = yield self.runInteraction(
|
||||
"delete_room_alias",
|
||||
self._delete_room_alias_txn,
|
||||
room_alias,
|
||||
"delete_room_alias", self._delete_room_alias_txn, room_alias
|
||||
)
|
||||
|
||||
defer.returnValue(room_id)
|
||||
@ -148,7 +138,7 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
def _delete_room_alias_txn(self, txn, room_alias):
|
||||
txn.execute(
|
||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||
(room_alias.to_string(),)
|
||||
(room_alias.to_string(),),
|
||||
)
|
||||
|
||||
res = txn.fetchone()
|
||||
@ -158,31 +148,29 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
return None
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM room_aliases WHERE room_alias = ?",
|
||||
(room_alias.to_string(),)
|
||||
"DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM room_alias_servers WHERE room_alias = ?",
|
||||
(room_alias.to_string(),)
|
||||
(room_alias.to_string(),),
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aliases_for_room, (room_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,))
|
||||
|
||||
return room_id
|
||||
|
||||
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
|
||||
def _update_aliases_for_room_txn(txn):
|
||||
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
|
||||
txn.execute(sql, (new_room_id, creator, old_room_id,))
|
||||
txn.execute(sql, (new_room_id, creator, old_room_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aliases_for_room, (old_room_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_aliases_for_room, (new_room_id,)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
|
||||
)
|
||||
|
@ -23,7 +23,6 @@ from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_room_key(self, user_id, version, room_id, session_id):
|
||||
"""Get the encrypted E2E room key for a given session from a given
|
||||
@ -97,9 +96,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_room_keys(
|
||||
self, user_id, version, room_id=None, session_id=None
|
||||
):
|
||||
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||
room, or a given session.
|
||||
|
||||
@ -123,10 +120,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
except ValueError:
|
||||
defer.returnValue({'rooms': {}})
|
||||
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
"version": version,
|
||||
}
|
||||
keyvalues = {"user_id": user_id, "version": version}
|
||||
if room_id:
|
||||
keyvalues['room_id'] = room_id
|
||||
if session_id:
|
||||
@ -160,9 +154,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
defer.returnValue(sessions)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_e2e_room_keys(
|
||||
self, user_id, version, room_id=None, session_id=None
|
||||
):
|
||||
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||
room or a given session.
|
||||
|
||||
@ -180,19 +172,14 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
A deferred of the deletion transaction
|
||||
"""
|
||||
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
"version": int(version),
|
||||
}
|
||||
keyvalues = {"user_id": user_id, "version": int(version)}
|
||||
if room_id:
|
||||
keyvalues['room_id'] = room_id
|
||||
if session_id:
|
||||
keyvalues['session_id'] = session_id
|
||||
|
||||
yield self._simple_delete(
|
||||
table="e2e_room_keys",
|
||||
keyvalues=keyvalues,
|
||||
desc="delete_e2e_room_keys",
|
||||
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -200,7 +187,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
txn.execute(
|
||||
"SELECT MAX(version) FROM e2e_room_keys_versions "
|
||||
"WHERE user_id=? AND deleted=0",
|
||||
(user_id,)
|
||||
(user_id,),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
@ -238,24 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
result = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": this_version,
|
||||
"deleted": 0,
|
||||
},
|
||||
retcols=(
|
||||
"version",
|
||||
"algorithm",
|
||||
"auth_data",
|
||||
),
|
||||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
|
||||
retcols=("version", "algorithm", "auth_data"),
|
||||
)
|
||||
result["auth_data"] = json.loads(result["auth_data"])
|
||||
result["version"] = str(result["version"])
|
||||
return result
|
||||
|
||||
return self.runInteraction(
|
||||
"get_e2e_room_keys_version_info",
|
||||
_get_e2e_room_keys_version_info_txn
|
||||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
|
||||
)
|
||||
|
||||
def create_e2e_room_keys_version(self, user_id, info):
|
||||
@ -273,7 +251,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
def _create_e2e_room_keys_version_txn(txn):
|
||||
txn.execute(
|
||||
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
|
||||
(user_id,)
|
||||
(user_id,),
|
||||
)
|
||||
current_version = txn.fetchone()[0]
|
||||
if current_version is None:
|
||||
@ -309,14 +287,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
|
||||
return self._simple_update(
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": version,
|
||||
},
|
||||
updatevalues={
|
||||
"auth_data": json.dumps(info["auth_data"]),
|
||||
},
|
||||
desc="update_e2e_room_keys_version"
|
||||
keyvalues={"user_id": user_id, "version": version},
|
||||
updatevalues={"auth_data": json.dumps(info["auth_data"])},
|
||||
desc="update_e2e_room_keys_version",
|
||||
)
|
||||
|
||||
def delete_e2e_room_keys_version(self, user_id, version=None):
|
||||
@ -341,16 +314,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
return self._simple_update_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": this_version,
|
||||
},
|
||||
updatevalues={
|
||||
"deleted": 1,
|
||||
}
|
||||
keyvalues={"user_id": user_id, "version": this_version},
|
||||
updatevalues={"deleted": 1},
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_e2e_room_keys_version",
|
||||
_delete_e2e_room_keys_version_txn
|
||||
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
|
||||
)
|
||||
|
@ -26,8 +26,7 @@ from ._base import SQLBaseStore, db_to_json
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_device_keys(
|
||||
self, query_list, include_all_devices=False,
|
||||
include_deleted_devices=False,
|
||||
self, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
):
|
||||
"""Fetch a list of device keys.
|
||||
Args:
|
||||
@ -45,8 +44,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
defer.returnValue({})
|
||||
|
||||
results = yield self.runInteraction(
|
||||
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
|
||||
query_list, include_all_devices, include_deleted_devices,
|
||||
"get_e2e_device_keys",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_list,
|
||||
include_all_devices,
|
||||
include_deleted_devices,
|
||||
)
|
||||
|
||||
for user_id, device_keys in iteritems(results):
|
||||
@ -56,8 +58,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
defer.returnValue(results)
|
||||
|
||||
def _get_e2e_device_keys_txn(
|
||||
self, txn, query_list, include_all_devices=False,
|
||||
include_deleted_devices=False,
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
):
|
||||
query_clauses = []
|
||||
query_params = []
|
||||
@ -87,7 +88,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
" WHERE %s"
|
||||
) % (
|
||||
"LEFT" if include_all_devices else "INNER",
|
||||
" OR ".join("(" + q + ")" for q in query_clauses)
|
||||
" OR ".join("(" + q + ")" for q in query_clauses),
|
||||
)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
@ -124,17 +125,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
iterable=key_ids,
|
||||
retcols=("algorithm", "key_id", "key_json",),
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
retcols=("algorithm", "key_id", "key_json"),
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
desc="add_e2e_one_time_keys_check",
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
||||
})
|
||||
defer.returnValue(
|
||||
{(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
|
||||
@ -155,7 +153,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
# `add_e2e_one_time_keys` then they'll conflict and we will only
|
||||
# insert one set.
|
||||
self._simple_insert_many_txn(
|
||||
txn, table="e2e_one_time_keys_json",
|
||||
txn,
|
||||
table="e2e_one_time_keys_json",
|
||||
values=[
|
||||
{
|
||||
"user_id": user_id,
|
||||
@ -169,8 +168,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
],
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||
)
|
||||
@ -181,6 +181,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
Dict mapping from algorithm to number of keys for that algorithm.
|
||||
"""
|
||||
|
||||
def _count_e2e_one_time_keys(txn):
|
||||
sql = (
|
||||
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
||||
@ -192,9 +193,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
for algorithm, key_count in txn:
|
||||
result[algorithm] = key_count
|
||||
return result
|
||||
return self.runInteraction(
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
|
||||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
@ -202,14 +202,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
"""
|
||||
|
||||
def _set_e2e_device_keys_txn(txn):
|
||||
old_key_json = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
retcol="key_json",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -224,24 +222,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"ts_added_ms": time_now,
|
||||
"key_json": new_key_json,
|
||||
}
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
values={"ts_added_ms": time_now, "key_json": new_key_json},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return self.runInteraction(
|
||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
||||
)
|
||||
return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
|
||||
|
||||
def claim_e2e_one_time_keys(self, query_list):
|
||||
"""Take a list of one time keys out of the database"""
|
||||
|
||||
def _claim_e2e_one_time_keys(txn):
|
||||
sql = (
|
||||
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
|
||||
@ -265,12 +256,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
for user_id, device_id, algorithm, key_id in delete:
|
||||
txn.execute(sql, (user_id, device_id, algorithm, key_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
return result
|
||||
return self.runInteraction(
|
||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||
)
|
||||
|
||||
return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
|
||||
|
||||
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||
def delete_e2e_keys_by_device_txn(txn):
|
||||
@ -285,8 +275,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
@ -20,10 +20,7 @@ from ._base import IncorrectDatabaseSetup
|
||||
from .postgres import PostgresEngine
|
||||
from .sqlite import Sqlite3Engine
|
||||
|
||||
SUPPORTED_MODULE = {
|
||||
"sqlite3": Sqlite3Engine,
|
||||
"psycopg2": PostgresEngine,
|
||||
}
|
||||
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
|
||||
|
||||
|
||||
def create_engine(database_config):
|
||||
@ -32,15 +29,12 @@ def create_engine(database_config):
|
||||
|
||||
if engine_class:
|
||||
# pypy requires psycopg2cffi rather than psycopg2
|
||||
if (name == "psycopg2" and
|
||||
platform.python_implementation() == "PyPy"):
|
||||
if name == "psycopg2" and platform.python_implementation() == "PyPy":
|
||||
name = "psycopg2cffi"
|
||||
module = importlib.import_module(name)
|
||||
return engine_class(module, database_config)
|
||||
|
||||
raise RuntimeError(
|
||||
"Unsupported database engine '%s'" % (name,)
|
||||
)
|
||||
raise RuntimeError("Unsupported database engine '%s'" % (name,))
|
||||
|
||||
|
||||
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
|
||||
|
@ -23,7 +23,7 @@ class PostgresEngine(object):
|
||||
self.module = database_module
|
||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||
self.synchronous_commit = database_config.get("synchronous_commit", True)
|
||||
self._version = None # unknown as yet
|
||||
self._version = None # unknown as yet
|
||||
|
||||
def check_database(self, txn):
|
||||
txn.execute("SHOW SERVER_ENCODING")
|
||||
@ -31,8 +31,7 @@ class PostgresEngine(object):
|
||||
if rows and rows[0][0] != "UTF8":
|
||||
raise IncorrectDatabaseSetup(
|
||||
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
|
||||
"See docs/postgres.rst for more information."
|
||||
% (rows[0][0],)
|
||||
"See docs/postgres.rst for more information." % (rows[0][0],)
|
||||
)
|
||||
|
||||
def convert_param_style(self, sql):
|
||||
@ -103,12 +102,6 @@ class PostgresEngine(object):
|
||||
|
||||
# https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
|
||||
if numver >= 100000:
|
||||
return "%i.%i" % (
|
||||
numver / 10000, numver % 10000,
|
||||
)
|
||||
return "%i.%i" % (numver / 10000, numver % 10000)
|
||||
else:
|
||||
return "%i.%i.%i" % (
|
||||
numver / 10000,
|
||||
(numver % 10000) / 100,
|
||||
numver % 100,
|
||||
)
|
||||
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
|
||||
|
@ -82,9 +82,10 @@ class Sqlite3Engine(object):
|
||||
|
||||
# Following functions taken from: https://github.com/coleifer/peewee
|
||||
|
||||
|
||||
def _parse_match_info(buf):
|
||||
bufsize = len(buf)
|
||||
return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
|
||||
return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
|
||||
|
||||
|
||||
def _rank(raw_match_info):
|
||||
@ -98,7 +99,7 @@ def _rank(raw_match_info):
|
||||
phrase_info_idx = 2 + (phrase_num * c * 3)
|
||||
for col_num in range(c):
|
||||
col_idx = phrase_info_idx + (col_num * 3)
|
||||
x1, x2 = match_info[col_idx:col_idx + 2]
|
||||
x1, x2 = match_info[col_idx : col_idx + 2]
|
||||
if x1 > 0:
|
||||
score += float(x1) / x2
|
||||
return score
|
||||
|
@ -32,8 +32,7 @@ from synapse.util.caches.descriptors import cached
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
SQLBaseStore):
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||
def get_auth_chain(self, event_ids, include_given=False):
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
@ -45,7 +44,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
list of events
|
||||
"""
|
||||
return self.get_auth_chain_ids(
|
||||
event_ids, include_given=include_given,
|
||||
event_ids, include_given=include_given
|
||||
).addCallback(self._get_events)
|
||||
|
||||
def get_auth_chain_ids(self, event_ids, include_given=False):
|
||||
@ -59,9 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
list of event_ids
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
self._get_auth_chain_ids_txn,
|
||||
event_ids, include_given
|
||||
"get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
|
||||
@ -70,23 +67,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
else:
|
||||
results = set()
|
||||
|
||||
base_sql = (
|
||||
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
|
||||
)
|
||||
base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
|
||||
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
new_front = set()
|
||||
front_list = list(front)
|
||||
chunks = [
|
||||
front_list[x:x + 100]
|
||||
for x in range(0, len(front), 100)
|
||||
]
|
||||
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
|
||||
for chunk in chunks:
|
||||
txn.execute(
|
||||
base_sql % (",".join(["?"] * len(chunk)),),
|
||||
chunk
|
||||
)
|
||||
txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
|
||||
new_front.update([r[0] for r in txn])
|
||||
|
||||
new_front -= results
|
||||
@ -98,9 +87,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
|
||||
def get_oldest_events_in_room(self, room_id):
|
||||
return self.runInteraction(
|
||||
"get_oldest_events_in_room",
|
||||
self._get_oldest_events_in_room_txn,
|
||||
room_id,
|
||||
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
|
||||
)
|
||||
|
||||
def get_oldest_events_with_depth_in_room(self, room_id):
|
||||
@ -121,7 +108,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
" GROUP BY b.event_id"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, False,))
|
||||
txn.execute(sql, (room_id, False))
|
||||
|
||||
return dict(txn)
|
||||
|
||||
@ -152,9 +139,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
return self._simple_select_onecol_txn(
|
||||
txn,
|
||||
table="event_backward_extremities",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
@ -209,9 +194,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
def get_latest_event_ids_in_room(self, room_id):
|
||||
return self._simple_select_onecol(
|
||||
table="event_forward_extremities",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="event_id",
|
||||
desc="get_latest_event_ids_in_room",
|
||||
)
|
||||
@ -225,14 +208,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
"WHERE f.room_id = ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, ))
|
||||
txn.execute(sql, (room_id,))
|
||||
|
||||
results = []
|
||||
for event_id, depth in txn.fetchall():
|
||||
hashes = self._get_event_reference_hashes_txn(txn, event_id)
|
||||
prev_hashes = {
|
||||
k: encode_base64(v) for k, v in hashes.items()
|
||||
if k == "sha256"
|
||||
k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
|
||||
}
|
||||
results.append((event_id, prev_hashes, depth))
|
||||
|
||||
@ -242,9 +224,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
""" For hte given room, get the minimum depth we have seen for it.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_min_depth",
|
||||
self._get_min_depth_interaction,
|
||||
room_id,
|
||||
"get_min_depth", self._get_min_depth_interaction, room_id
|
||||
)
|
||||
|
||||
def _get_min_depth_interaction(self, txn, room_id):
|
||||
@ -300,7 +280,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
if stream_ordering <= self.stream_ordering_month_ago:
|
||||
raise StoreError(400, "stream_ordering too old")
|
||||
|
||||
sql = ("""
|
||||
sql = """
|
||||
SELECT event_id FROM stream_ordering_to_exterm
|
||||
INNER JOIN (
|
||||
SELECT room_id, MAX(stream_ordering) AS stream_ordering
|
||||
@ -308,15 +288,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
WHERE stream_ordering <= ? GROUP BY room_id
|
||||
) AS rms USING (room_id, stream_ordering)
|
||||
WHERE room_id = ?
|
||||
""")
|
||||
"""
|
||||
|
||||
def get_forward_extremeties_for_room_txn(txn):
|
||||
txn.execute(sql, (stream_ordering, room_id))
|
||||
return [event_id for event_id, in txn]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_forward_extremeties_for_room",
|
||||
get_forward_extremeties_for_room_txn
|
||||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
||||
)
|
||||
|
||||
def get_backfill_events(self, room_id, event_list, limit):
|
||||
@ -329,19 +308,21 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
event_list (list)
|
||||
limit (int)
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events, room_id, event_list, limit
|
||||
).addCallback(
|
||||
self._get_events
|
||||
).addCallback(
|
||||
lambda l: sorted(l, key=lambda e: -e.depth)
|
||||
return (
|
||||
self.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events,
|
||||
room_id,
|
||||
event_list,
|
||||
limit,
|
||||
)
|
||||
.addCallback(self._get_events)
|
||||
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
|
||||
)
|
||||
|
||||
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
||||
logger.debug(
|
||||
"_get_backfill_events: %s, %s, %s",
|
||||
room_id, repr(event_list), limit
|
||||
"_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
|
||||
)
|
||||
|
||||
event_results = set()
|
||||
@ -364,10 +345,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
depth = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="events",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"event_id": event_id, "room_id": room_id},
|
||||
retcol="depth",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -386,10 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
|
||||
event_results.add(event_id)
|
||||
|
||||
txn.execute(
|
||||
query,
|
||||
(event_id, False, limit - len(event_results))
|
||||
)
|
||||
txn.execute(query, (event_id, False, limit - len(event_results)))
|
||||
|
||||
for row in txn:
|
||||
if row[1] not in event_results:
|
||||
@ -398,18 +373,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
return event_results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_missing_events(self, room_id, earliest_events, latest_events,
|
||||
limit):
|
||||
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
|
||||
ids = yield self.runInteraction(
|
||||
"get_missing_events",
|
||||
self._get_missing_events,
|
||||
room_id, earliest_events, latest_events, limit,
|
||||
room_id,
|
||||
earliest_events,
|
||||
latest_events,
|
||||
limit,
|
||||
)
|
||||
events = yield self._get_events(ids)
|
||||
defer.returnValue(events)
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
|
||||
limit):
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||
|
||||
seen_events = set(earliest_events)
|
||||
front = set(latest_events) - seen_events
|
||||
@ -425,8 +401,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
new_front = set()
|
||||
for event_id in front:
|
||||
txn.execute(
|
||||
query,
|
||||
(room_id, event_id, False, limit - len(event_results))
|
||||
query, (room_id, event_id, False, limit - len(event_results))
|
||||
)
|
||||
|
||||
new_results = set(t[0] for t in txn) - seen_events
|
||||
@ -457,12 +432,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
column="prev_event_id",
|
||||
iterable=event_ids,
|
||||
retcols=("event_id",),
|
||||
desc="get_successor_events"
|
||||
desc="get_successor_events",
|
||||
)
|
||||
|
||||
defer.returnValue([
|
||||
row["event_id"] for row in rows
|
||||
])
|
||||
defer.returnValue([row["event_id"] for row in rows])
|
||||
|
||||
|
||||
class EventFederationStore(EventFederationWorkerStore):
|
||||
@ -481,12 +454,11 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
super(EventFederationStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_AUTH_STATE_ONLY,
|
||||
self._background_delete_non_state_event_auth,
|
||||
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
|
||||
)
|
||||
|
||||
hs.get_clock().looping_call(
|
||||
self._delete_old_forward_extrem_cache, 60 * 60 * 1000,
|
||||
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
|
||||
)
|
||||
|
||||
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
||||
@ -498,12 +470,8 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="room_depth",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
values={
|
||||
"min_depth": depth,
|
||||
},
|
||||
keyvalues={"room_id": room_id},
|
||||
values={"min_depth": depth},
|
||||
)
|
||||
|
||||
def _handle_mult_prev_events(self, txn, events):
|
||||
@ -553,11 +521,15 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
" )"
|
||||
)
|
||||
|
||||
txn.executemany(query, [
|
||||
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||
for ev in events for e_id in ev.prev_event_ids()
|
||||
if not ev.internal_metadata.is_outlier()
|
||||
])
|
||||
txn.executemany(
|
||||
query,
|
||||
[
|
||||
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||
for ev in events
|
||||
for e_id in ev.prev_event_ids()
|
||||
if not ev.internal_metadata.is_outlier()
|
||||
],
|
||||
)
|
||||
|
||||
query = (
|
||||
"DELETE FROM event_backward_extremities"
|
||||
@ -566,16 +538,17 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
txn.executemany(
|
||||
query,
|
||||
[
|
||||
(ev.event_id, ev.room_id) for ev in events
|
||||
(ev.event_id, ev.room_id)
|
||||
for ev in events
|
||||
if not ev.internal_metadata.is_outlier()
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def _delete_old_forward_extrem_cache(self):
|
||||
def _delete_old_forward_extrem_cache_txn(txn):
|
||||
# Delete entries older than a month, while making sure we don't delete
|
||||
# the only entries for a room.
|
||||
sql = ("""
|
||||
sql = """
|
||||
DELETE FROM stream_ordering_to_exterm
|
||||
WHERE
|
||||
room_id IN (
|
||||
@ -583,11 +556,11 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
FROM stream_ordering_to_exterm
|
||||
WHERE stream_ordering > ?
|
||||
) AND stream_ordering < ?
|
||||
""")
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
|
||||
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
|
||||
)
|
||||
|
||||
return run_as_background_process(
|
||||
"delete_old_forward_extrem_cache",
|
||||
self.runInteraction,
|
||||
@ -597,9 +570,7 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
|
||||
def clean_room_for_join(self, room_id):
|
||||
return self.runInteraction(
|
||||
"clean_room_for_join",
|
||||
self._clean_room_for_join_txn,
|
||||
room_id,
|
||||
"clean_room_for_join", self._clean_room_for_join_txn, room_id
|
||||
)
|
||||
|
||||
def _clean_room_for_join_txn(self, txn, room_id):
|
||||
@ -635,7 +606,7 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
)
|
||||
"""
|
||||
|
||||
txn.execute(sql, (min_stream_id, max_stream_id,))
|
||||
txn.execute(sql, (min_stream_id, max_stream_id))
|
||||
|
||||
new_progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
|
@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
|
||||
DEFAULT_HIGHLIGHT_ACTION = [
|
||||
"notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
|
||||
"notify",
|
||||
{"set_tweak": "sound", "value": "default"},
|
||||
{"set_tweak": "highlight"},
|
||||
]
|
||||
|
||||
|
||||
@ -91,25 +93,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
||||
def get_unread_event_push_actions_by_room_for_user(
|
||||
self, room_id, user_id, last_read_event_id
|
||||
self, room_id, user_id, last_read_event_id
|
||||
):
|
||||
ret = yield self.runInteraction(
|
||||
"get_unread_event_push_actions_by_room",
|
||||
self._get_unread_counts_by_receipt_txn,
|
||||
room_id, user_id, last_read_event_id
|
||||
room_id,
|
||||
user_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
|
||||
last_read_event_id):
|
||||
def _get_unread_counts_by_receipt_txn(
|
||||
self, txn, room_id, user_id, last_read_event_id
|
||||
):
|
||||
sql = (
|
||||
"SELECT stream_ordering"
|
||||
" FROM events"
|
||||
" WHERE room_id = ? AND event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql, (room_id, last_read_event_id)
|
||||
)
|
||||
txn.execute(sql, (room_id, last_read_event_id))
|
||||
results = txn.fetchall()
|
||||
if len(results) == 0:
|
||||
return {"notify_count": 0, "highlight_count": 0}
|
||||
@ -138,10 +141,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
row = txn.fetchone()
|
||||
notify_count = row[0] if row else 0
|
||||
|
||||
txn.execute("""
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT notif_count FROM event_push_summary
|
||||
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
|
||||
""", (room_id, user_id, stream_ordering,))
|
||||
""",
|
||||
(room_id, user_id, stream_ordering),
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
if rows:
|
||||
notify_count += rows[0][0]
|
||||
@ -161,10 +167,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
row = txn.fetchone()
|
||||
highlight_count = row[0] if row else 0
|
||||
|
||||
return {
|
||||
"notify_count": notify_count,
|
||||
"highlight_count": highlight_count,
|
||||
}
|
||||
return {"notify_count": notify_count, "highlight_count": highlight_count}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
|
||||
@ -175,6 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
||||
return [r[0] for r in txn]
|
||||
|
||||
ret = yield self.runInteraction("get_push_action_users_in_range", f)
|
||||
defer.returnValue(ret)
|
||||
|
||||
@ -223,12 +227,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
after_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||
)
|
||||
@ -253,12 +255,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
no_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||
)
|
||||
@ -269,7 +269,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
"room_id": row[1],
|
||||
"stream_ordering": row[2],
|
||||
"actions": _deserialize_action(row[3], row[4]),
|
||||
} for row in after_read_receipt + no_read_receipt
|
||||
}
|
||||
for row in after_read_receipt + no_read_receipt
|
||||
]
|
||||
|
||||
# Now sort it so it's ordered correctly, since currently it will
|
||||
@ -326,12 +327,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
after_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||
)
|
||||
@ -356,12 +355,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
no_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||
)
|
||||
@ -374,7 +371,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
"stream_ordering": row[2],
|
||||
"actions": _deserialize_action(row[3], row[4]),
|
||||
"received_ts": row[5],
|
||||
} for row in after_read_receipt + no_read_receipt
|
||||
}
|
||||
for row in after_read_receipt + no_read_receipt
|
||||
]
|
||||
|
||||
# Now sort it so it's ordered correctly, since currently it will
|
||||
@ -408,7 +406,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, min_stream_ordering,))
|
||||
txn.execute(sql, (user_id, min_stream_ordering))
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return self.runInteraction(
|
||||
@ -454,10 +452,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
txn.executemany(sql, (
|
||||
_gen_entry(user_id, actions)
|
||||
for user_id, actions in iteritems(user_id_actions)
|
||||
))
|
||||
txn.executemany(
|
||||
sql,
|
||||
(
|
||||
_gen_entry(user_id, actions)
|
||||
for user_id, actions in iteritems(user_id_actions)
|
||||
),
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
|
||||
@ -475,9 +476,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
try:
|
||||
res = yield self._simple_delete(
|
||||
table="event_push_actions_staging",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
},
|
||||
keyvalues={"event_id": event_id},
|
||||
desc="remove_push_actions_from_staging",
|
||||
)
|
||||
defer.returnValue(res)
|
||||
@ -486,7 +485,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
# another exception here really isn't helpful - there's nothing
|
||||
# the caller can do about it. Just log the exception and move on.
|
||||
logger.exception(
|
||||
"Error removing push actions after event persistence failure",
|
||||
"Error removing push actions after event persistence failure"
|
||||
)
|
||||
|
||||
def _find_stream_orderings_for_times(self):
|
||||
@ -503,16 +502,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
|
||||
)
|
||||
logger.info(
|
||||
"Found stream ordering 1 month ago: it's %d",
|
||||
self.stream_ordering_month_ago
|
||||
"Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago
|
||||
)
|
||||
logger.info("Searching for stream ordering 1 day ago")
|
||||
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
|
||||
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
|
||||
)
|
||||
logger.info(
|
||||
"Found stream ordering 1 day ago: it's %d",
|
||||
self.stream_ordering_day_ago
|
||||
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
|
||||
)
|
||||
|
||||
def find_first_stream_ordering_after_ts(self, ts):
|
||||
@ -631,16 +628,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
index_name="event_push_actions_highlights_index",
|
||||
table="event_push_actions",
|
||||
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
|
||||
where_clause="highlight=1"
|
||||
where_clause="highlight=1",
|
||||
)
|
||||
|
||||
self._doing_notif_rotation = False
|
||||
self._rotate_notif_loop = self._clock.looping_call(
|
||||
self._start_rotate_notifs, 30 * 60 * 1000,
|
||||
self._start_rotate_notifs, 30 * 60 * 1000
|
||||
)
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
|
||||
all_events_and_contexts):
|
||||
def _set_push_actions_for_event_and_users_txn(
|
||||
self, txn, events_and_contexts, all_events_and_contexts
|
||||
):
|
||||
"""Handles moving push actions from staging table to main
|
||||
event_push_actions table for all events in `events_and_contexts`.
|
||||
|
||||
@ -667,43 +665,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
"""
|
||||
|
||||
if events_and_contexts:
|
||||
txn.executemany(sql, (
|
||||
txn.executemany(
|
||||
sql,
|
||||
(
|
||||
event.room_id, event.internal_metadata.stream_ordering,
|
||||
event.depth, event.event_id,
|
||||
)
|
||||
for event, _ in events_and_contexts
|
||||
))
|
||||
(
|
||||
event.room_id,
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.depth,
|
||||
event.event_id,
|
||||
)
|
||||
for event, _ in events_and_contexts
|
||||
),
|
||||
)
|
||||
|
||||
for event, _ in events_and_contexts:
|
||||
user_ids = self._simple_select_onecol_txn(
|
||||
txn,
|
||||
table="event_push_actions_staging",
|
||||
keyvalues={
|
||||
"event_id": event.event_id,
|
||||
},
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcol="user_id",
|
||||
)
|
||||
|
||||
for uid in user_ids:
|
||||
txn.call_after(
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||
(event.room_id, uid,)
|
||||
(event.room_id, uid),
|
||||
)
|
||||
|
||||
# Now we delete the staging area for *all* events that were being
|
||||
# persisted.
|
||||
txn.executemany(
|
||||
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
|
||||
(
|
||||
(event.event_id,)
|
||||
for event, _ in all_events_and_contexts
|
||||
)
|
||||
((event.event_id,) for event, _ in all_events_and_contexts),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_actions_for_user(self, user_id, before=None, limit=50,
|
||||
only_highlight=False):
|
||||
def get_push_actions_for_user(
|
||||
self, user_id, before=None, limit=50, only_highlight=False
|
||||
):
|
||||
def f(txn):
|
||||
before_clause = ""
|
||||
if before:
|
||||
@ -727,15 +726,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
" WHERE epa.event_id = e.event_id"
|
||||
" AND epa.user_id = ? %s"
|
||||
" ORDER BY epa.stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
% (before_clause,)
|
||||
" LIMIT ?" % (before_clause,)
|
||||
)
|
||||
txn.execute(sql, args)
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
push_actions = yield self.runInteraction(
|
||||
"get_push_actions_for_user", f
|
||||
)
|
||||
push_actions = yield self.runInteraction("get_push_actions_for_user", f)
|
||||
for pa in push_actions:
|
||||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
||||
defer.returnValue(push_actions)
|
||||
@ -753,6 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
)
|
||||
txn.execute(sql, (stream_ordering,))
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
|
||||
defer.returnValue(result[0] if result else None)
|
||||
|
||||
@ -761,24 +758,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
def f(txn):
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
|
||||
return txn.fetchone()
|
||||
result = yield self.runInteraction(
|
||||
"get_latest_push_action_stream_ordering", f
|
||||
)
|
||||
|
||||
result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
|
||||
defer.returnValue(result[0] or 0)
|
||||
|
||||
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
|
||||
# Sad that we have to blow away the cache for the whole room here
|
||||
txn.call_after(
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||
(room_id,)
|
||||
(room_id,),
|
||||
)
|
||||
txn.execute(
|
||||
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
|
||||
(room_id, event_id)
|
||||
(room_id, event_id),
|
||||
)
|
||||
|
||||
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
|
||||
stream_ordering):
|
||||
def _remove_old_push_actions_before_txn(
|
||||
self, txn, room_id, user_id, stream_ordering
|
||||
):
|
||||
"""
|
||||
Purges old push actions for a user and room before a given
|
||||
stream_ordering.
|
||||
@ -795,7 +792,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
"""
|
||||
txn.call_after(
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||
(room_id, user_id, )
|
||||
(room_id, user_id),
|
||||
)
|
||||
|
||||
# We need to join on the events table to get the received_ts for
|
||||
@ -811,13 +808,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
" WHERE user_id = ? AND room_id = ? AND "
|
||||
" stream_ordering <= ?"
|
||||
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
|
||||
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
|
||||
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
|
||||
)
|
||||
|
||||
txn.execute("""
|
||||
txn.execute(
|
||||
"""
|
||||
DELETE FROM event_push_summary
|
||||
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
|
||||
""", (room_id, user_id, stream_ordering))
|
||||
""",
|
||||
(room_id, user_id, stream_ordering),
|
||||
)
|
||||
|
||||
def _start_rotate_notifs(self):
|
||||
return run_as_background_process("rotate_notifs", self._rotate_notifs)
|
||||
@ -833,8 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
logger.info("Rotating notifications")
|
||||
|
||||
caught_up = yield self.runInteraction(
|
||||
"_rotate_notifs",
|
||||
self._rotate_notifs_txn
|
||||
"_rotate_notifs", self._rotate_notifs_txn
|
||||
)
|
||||
if caught_up:
|
||||
break
|
||||
@ -856,11 +855,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
|
||||
# We don't to try and rotate millions of rows at once, so we cap the
|
||||
# maximum stream ordering we'll rotate before.
|
||||
txn.execute("""
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT stream_ordering FROM event_push_actions
|
||||
WHERE stream_ordering > ?
|
||||
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
|
||||
""", (old_rotate_stream_ordering, self._rotate_count))
|
||||
""",
|
||||
(old_rotate_stream_ordering, self._rotate_count),
|
||||
)
|
||||
stream_row = txn.fetchone()
|
||||
if stream_row:
|
||||
offset_stream_ordering, = stream_row
|
||||
@ -904,7 +906,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
|
||||
"""
|
||||
|
||||
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
|
||||
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
|
||||
rows = txn.fetchall()
|
||||
|
||||
logger.info("Rotating notifications, handling %d rows", len(rows))
|
||||
@ -922,8 +924,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
"notif_count": row[2],
|
||||
"stream_ordering": row[3],
|
||||
}
|
||||
for row in rows if row[4] is None
|
||||
]
|
||||
for row in rows
|
||||
if row[4] is None
|
||||
],
|
||||
)
|
||||
|
||||
txn.executemany(
|
||||
@ -931,20 +934,20 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
|
||||
WHERE user_id = ? AND room_id = ?
|
||||
""",
|
||||
((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
|
||||
((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM event_push_actions"
|
||||
" WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
|
||||
(old_rotate_stream_ordering, rotate_to_stream_ordering,)
|
||||
(old_rotate_stream_ordering, rotate_to_stream_ordering),
|
||||
)
|
||||
|
||||
logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
|
||||
|
||||
txn.execute(
|
||||
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
|
||||
(rotate_to_stream_ordering,)
|
||||
(rotate_to_stream_ordering,),
|
||||
)
|
||||
|
||||
|
||||
|
@ -71,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_select_one_onecol(
|
||||
table="events",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
},
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="received_ts",
|
||||
desc="get_received_ts",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
allow_none=False, check_room_id=None):
|
||||
def get_event(
|
||||
self,
|
||||
event_id,
|
||||
check_redacted=True,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
allow_none=False,
|
||||
check_room_id=None,
|
||||
):
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
@ -118,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(self, event_ids, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
def get_events(
|
||||
self,
|
||||
event_ids,
|
||||
check_redacted=True,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
):
|
||||
"""Get events from the database
|
||||
|
||||
Args:
|
||||
@ -143,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
defer.returnValue({e.event_id: e for e in events})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events(self, event_ids, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
def _get_events(
|
||||
self,
|
||||
event_ids,
|
||||
check_redacted=True,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
):
|
||||
if not event_ids:
|
||||
defer.returnValue([])
|
||||
|
||||
@ -152,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
event_ids = set(event_ids)
|
||||
|
||||
event_entry_map = self._get_events_from_cache(
|
||||
event_ids,
|
||||
allow_rejected=allow_rejected,
|
||||
event_ids, allow_rejected=allow_rejected
|
||||
)
|
||||
|
||||
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
|
||||
@ -169,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
#
|
||||
# _enqueue_events is a bit of a rubbish name but naming is hard.
|
||||
missing_events = yield self._enqueue_events(
|
||||
missing_events_ids,
|
||||
allow_rejected=allow_rejected,
|
||||
missing_events_ids, allow_rejected=allow_rejected
|
||||
)
|
||||
|
||||
event_entry_map.update(missing_events)
|
||||
@ -214,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
expected_domain = get_domain_from_id(entry.event.sender)
|
||||
if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
|
||||
if (
|
||||
orig_sender
|
||||
and get_domain_from_id(orig_sender) == expected_domain
|
||||
):
|
||||
# This redaction event is allowed. Mark as not needing a
|
||||
# recheck.
|
||||
entry.event.internal_metadata.recheck_redaction = False
|
||||
@ -267,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
for event_id in events:
|
||||
ret = self._get_event_cache.get(
|
||||
(event_id,), None,
|
||||
update_metrics=update_metrics,
|
||||
(event_id,), None, update_metrics=update_metrics
|
||||
)
|
||||
if not ret:
|
||||
continue
|
||||
@ -318,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
with Measure(self._clock, "_fetch_event_list"):
|
||||
try:
|
||||
event_id_lists = list(zip(*event_list))[0]
|
||||
event_ids = [
|
||||
item for sublist in event_id_lists for item in sublist
|
||||
]
|
||||
event_ids = [item for sublist in event_id_lists for item in sublist]
|
||||
|
||||
rows = self._new_transaction(
|
||||
conn, "do_fetch", [], [],
|
||||
self._fetch_event_rows, event_ids,
|
||||
conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
|
||||
)
|
||||
|
||||
row_dict = {
|
||||
r["event_id"]: r
|
||||
for r in rows
|
||||
}
|
||||
row_dict = {r["event_id"]: r for r in rows}
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
def fire(lst, res):
|
||||
@ -338,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
if not d.called:
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
d.callback([
|
||||
res[i]
|
||||
for i in ids
|
||||
if i in res
|
||||
])
|
||||
d.callback([res[i] for i in ids if i in res])
|
||||
except Exception:
|
||||
logger.exception("Failed to callback")
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
|
||||
except Exception as e:
|
||||
@ -371,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
events_d = defer.Deferred()
|
||||
with self._event_fetch_lock:
|
||||
self._event_fetch_list.append(
|
||||
(events, events_d)
|
||||
)
|
||||
self._event_fetch_list.append((events, events_d))
|
||||
|
||||
self._event_fetch_lock.notify()
|
||||
|
||||
@ -385,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
if should_start:
|
||||
run_as_background_process(
|
||||
"fetch_events",
|
||||
self.runWithConnection,
|
||||
self._do_fetch,
|
||||
"fetch_events", self.runWithConnection, self._do_fetch
|
||||
)
|
||||
|
||||
logger.debug("Loading %d events", len(events))
|
||||
@ -398,29 +399,30 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
if not allow_rejected:
|
||||
rows[:] = [r for r in rows if not r["rejects"]]
|
||||
|
||||
res = yield make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
self._get_event_from_row,
|
||||
row["internal_metadata"], row["json"], row["redacts"],
|
||||
rejected_reason=row["rejects"],
|
||||
format_version=row["format_version"],
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
consumeErrors=True
|
||||
))
|
||||
res = yield make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
self._get_event_from_row,
|
||||
row["internal_metadata"],
|
||||
row["json"],
|
||||
row["redacts"],
|
||||
rejected_reason=row["rejects"],
|
||||
format_version=row["format_version"],
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
e.event.event_id: e
|
||||
for e in res if e
|
||||
})
|
||||
defer.returnValue({e.event.event_id: e for e in res if e})
|
||||
|
||||
def _fetch_event_rows(self, txn, events):
|
||||
rows = []
|
||||
N = 200
|
||||
for i in range(1 + len(events) // N):
|
||||
evs = events[i * N:(i + 1) * N]
|
||||
evs = events[i * N : (i + 1) * N]
|
||||
if not evs:
|
||||
break
|
||||
|
||||
@ -444,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
return rows
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||
format_version, rejected_reason=None):
|
||||
def _get_event_from_row(
|
||||
self, internal_metadata, js, redacted, format_version, rejected_reason=None
|
||||
):
|
||||
with Measure(self._clock, "_get_event_from_row"):
|
||||
d = json.loads(js)
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
@ -484,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# Get the redaction event.
|
||||
|
||||
because = yield self.get_event(
|
||||
redaction_id,
|
||||
check_redacted=False,
|
||||
allow_none=True,
|
||||
redaction_id, check_redacted=False, allow_none=True
|
||||
)
|
||||
|
||||
if because:
|
||||
@ -508,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
redacted_event = None
|
||||
|
||||
cache_entry = _EventCacheEntry(
|
||||
event=original_ev,
|
||||
redacted_event=redacted_event,
|
||||
event=original_ev, redacted_event=redacted_event
|
||||
)
|
||||
|
||||
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
||||
@ -545,23 +545,17 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
results = set()
|
||||
|
||||
def have_seen_events_txn(txn, chunk):
|
||||
sql = (
|
||||
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
|
||||
% (",".join("?" * len(chunk)), )
|
||||
sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
|
||||
",".join("?" * len(chunk)),
|
||||
)
|
||||
txn.execute(sql, chunk)
|
||||
for (event_id, ) in txn:
|
||||
for (event_id,) in txn:
|
||||
results.add(event_id)
|
||||
|
||||
# break the input up into chunks of 100
|
||||
input_iterator = iter(event_ids)
|
||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
|
||||
[]):
|
||||
yield self.runInteraction(
|
||||
"have_seen_events",
|
||||
have_seen_events_txn,
|
||||
chunk,
|
||||
)
|
||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
|
||||
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
|
||||
defer.returnValue(results)
|
||||
|
||||
def get_seen_events_with_rejections(self, event_ids):
|
||||
|
@ -35,10 +35,7 @@ class FilteringStore(SQLBaseStore):
|
||||
|
||||
def_json = yield self._simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={
|
||||
"user_id": user_localpart,
|
||||
"filter_id": filter_id,
|
||||
},
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
@ -61,10 +58,7 @@ class FilteringStore(SQLBaseStore):
|
||||
if filter_id_response is not None:
|
||||
return filter_id_response[0]
|
||||
|
||||
sql = (
|
||||
"SELECT MAX(filter_id) FROM user_filters "
|
||||
"WHERE user_id = ?"
|
||||
)
|
||||
sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
|
||||
txn.execute(sql, (user_localpart,))
|
||||
max_id = txn.fetchone()[0]
|
||||
if max_id is None:
|
||||
|
@ -38,24 +38,22 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_update_one(
|
||||
table="groups",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
updatevalues={
|
||||
"join_policy": join_policy,
|
||||
},
|
||||
keyvalues={"group_id": group_id},
|
||||
updatevalues={"join_policy": join_policy},
|
||||
desc="set_group_join_policy",
|
||||
)
|
||||
|
||||
def get_group(self, group_id):
|
||||
return self._simple_select_one(
|
||||
table="groups",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id},
|
||||
retcols=(
|
||||
"name", "short_description", "long_description",
|
||||
"avatar_url", "is_public", "join_policy",
|
||||
"name",
|
||||
"short_description",
|
||||
"long_description",
|
||||
"avatar_url",
|
||||
"is_public",
|
||||
"join_policy",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_group",
|
||||
@ -64,16 +62,14 @@ class GroupServerStore(SQLBaseStore):
|
||||
def get_users_in_group(self, group_id, include_private=False):
|
||||
# TODO: Pagination
|
||||
|
||||
keyvalues = {
|
||||
"group_id": group_id,
|
||||
}
|
||||
keyvalues = {"group_id": group_id}
|
||||
if not include_private:
|
||||
keyvalues["is_public"] = True
|
||||
|
||||
return self._simple_select_list(
|
||||
table="group_users",
|
||||
keyvalues=keyvalues,
|
||||
retcols=("user_id", "is_public", "is_admin",),
|
||||
retcols=("user_id", "is_public", "is_admin"),
|
||||
desc="get_users_in_group",
|
||||
)
|
||||
|
||||
@ -82,9 +78,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
return self._simple_select_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id},
|
||||
retcol="user_id",
|
||||
desc="get_invited_users_in_group",
|
||||
)
|
||||
@ -92,16 +86,14 @@ class GroupServerStore(SQLBaseStore):
|
||||
def get_rooms_in_group(self, group_id, include_private=False):
|
||||
# TODO: Pagination
|
||||
|
||||
keyvalues = {
|
||||
"group_id": group_id,
|
||||
}
|
||||
keyvalues = {"group_id": group_id}
|
||||
if not include_private:
|
||||
keyvalues["is_public"] = True
|
||||
|
||||
return self._simple_select_list(
|
||||
table="group_rooms",
|
||||
keyvalues=keyvalues,
|
||||
retcols=("room_id", "is_public",),
|
||||
retcols=("room_id", "is_public"),
|
||||
desc="get_rooms_in_group",
|
||||
)
|
||||
|
||||
@ -110,10 +102,9 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
Returns ([rooms], [categories])
|
||||
"""
|
||||
|
||||
def _get_rooms_for_summary_txn(txn):
|
||||
keyvalues = {
|
||||
"group_id": group_id,
|
||||
}
|
||||
keyvalues = {"group_id": group_id}
|
||||
if not include_private:
|
||||
keyvalues["is_public"] = True
|
||||
|
||||
@ -162,18 +153,23 @@ class GroupServerStore(SQLBaseStore):
|
||||
}
|
||||
|
||||
return rooms, categories
|
||||
return self.runInteraction(
|
||||
"get_rooms_for_summary", _get_rooms_for_summary_txn
|
||||
)
|
||||
|
||||
return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
|
||||
|
||||
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
|
||||
return self.runInteraction(
|
||||
"add_room_to_summary", self._add_room_to_summary_txn,
|
||||
group_id, room_id, category_id, order, is_public,
|
||||
"add_room_to_summary",
|
||||
self._add_room_to_summary_txn,
|
||||
group_id,
|
||||
room_id,
|
||||
category_id,
|
||||
order,
|
||||
is_public,
|
||||
)
|
||||
|
||||
def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
|
||||
is_public):
|
||||
def _add_room_to_summary_txn(
|
||||
self, txn, group_id, room_id, category_id, order, is_public
|
||||
):
|
||||
"""Add (or update) room's entry in summary.
|
||||
|
||||
Args:
|
||||
@ -188,10 +184,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
room_in_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_rooms",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "room_id": room_id},
|
||||
retcol="room_id",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -204,10 +197,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
cat_exists = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_room_categories",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"category_id": category_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||
retcol="group_id",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -218,22 +208,22 @@ class GroupServerStore(SQLBaseStore):
|
||||
cat_exists = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_summary_room_categories",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"category_id": category_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||
retcol="group_id",
|
||||
allow_none=True,
|
||||
)
|
||||
if not cat_exists:
|
||||
# If not, add it with an order larger than all others
|
||||
txn.execute("""
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO group_summary_room_categories
|
||||
(group_id, category_id, cat_order)
|
||||
SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
|
||||
FROM group_summary_room_categories
|
||||
WHERE group_id = ? AND category_id = ?
|
||||
""", (group_id, category_id, group_id, category_id))
|
||||
""",
|
||||
(group_id, category_id, group_id, category_id),
|
||||
)
|
||||
|
||||
existing = self._simple_select_one_txn(
|
||||
txn,
|
||||
@ -243,7 +233,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"room_id": room_id,
|
||||
"category_id": category_id,
|
||||
},
|
||||
retcols=("room_order", "is_public",),
|
||||
retcols=("room_order", "is_public"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@ -253,13 +243,13 @@ class GroupServerStore(SQLBaseStore):
|
||||
UPDATE group_summary_rooms SET room_order = room_order + 1
|
||||
WHERE group_id = ? AND category_id = ? AND room_order >= ?
|
||||
"""
|
||||
txn.execute(sql, (group_id, category_id, order,))
|
||||
txn.execute(sql, (group_id, category_id, order))
|
||||
elif not existing:
|
||||
sql = """
|
||||
SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
|
||||
WHERE group_id = ? AND category_id = ?
|
||||
"""
|
||||
txn.execute(sql, (group_id, category_id,))
|
||||
txn.execute(sql, (group_id, category_id))
|
||||
order, = txn.fetchone()
|
||||
|
||||
if existing:
|
||||
@ -312,29 +302,26 @@ class GroupServerStore(SQLBaseStore):
|
||||
def get_group_categories(self, group_id):
|
||||
rows = yield self._simple_select_list(
|
||||
table="group_room_categories",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id},
|
||||
retcols=("category_id", "is_public", "profile"),
|
||||
desc="get_group_categories",
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
row["category_id"]: {
|
||||
"is_public": row["is_public"],
|
||||
"profile": json.loads(row["profile"]),
|
||||
defer.returnValue(
|
||||
{
|
||||
row["category_id"]: {
|
||||
"is_public": row["is_public"],
|
||||
"profile": json.loads(row["profile"]),
|
||||
}
|
||||
for row in rows
|
||||
}
|
||||
for row in rows
|
||||
})
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_category(self, group_id, category_id):
|
||||
category = yield self._simple_select_one(
|
||||
table="group_room_categories",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"category_id": category_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||
retcols=("is_public", "profile"),
|
||||
desc="get_group_category",
|
||||
)
|
||||
@ -361,10 +348,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
return self._simple_upsert(
|
||||
table="group_room_categories",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"category_id": category_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||
values=update_values,
|
||||
insertion_values=insertion_values,
|
||||
desc="upsert_group_category",
|
||||
@ -373,10 +357,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
def remove_group_category(self, group_id, category_id):
|
||||
return self._simple_delete(
|
||||
table="group_room_categories",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"category_id": category_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||
desc="remove_group_category",
|
||||
)
|
||||
|
||||
@ -384,29 +365,26 @@ class GroupServerStore(SQLBaseStore):
|
||||
def get_group_roles(self, group_id):
|
||||
rows = yield self._simple_select_list(
|
||||
table="group_roles",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id},
|
||||
retcols=("role_id", "is_public", "profile"),
|
||||
desc="get_group_roles",
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
row["role_id"]: {
|
||||
"is_public": row["is_public"],
|
||||
"profile": json.loads(row["profile"]),
|
||||
defer.returnValue(
|
||||
{
|
||||
row["role_id"]: {
|
||||
"is_public": row["is_public"],
|
||||
"profile": json.loads(row["profile"]),
|
||||
}
|
||||
for row in rows
|
||||
}
|
||||
for row in rows
|
||||
})
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_role(self, group_id, role_id):
|
||||
role = yield self._simple_select_one(
|
||||
table="group_roles",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"role_id": role_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||
retcols=("is_public", "profile"),
|
||||
desc="get_group_role",
|
||||
)
|
||||
@ -433,10 +411,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
return self._simple_upsert(
|
||||
table="group_roles",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"role_id": role_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||
values=update_values,
|
||||
insertion_values=insertion_values,
|
||||
desc="upsert_group_role",
|
||||
@ -445,21 +420,24 @@ class GroupServerStore(SQLBaseStore):
|
||||
def remove_group_role(self, group_id, role_id):
|
||||
return self._simple_delete(
|
||||
table="group_roles",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"role_id": role_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||
desc="remove_group_role",
|
||||
)
|
||||
|
||||
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
|
||||
return self.runInteraction(
|
||||
"add_user_to_summary", self._add_user_to_summary_txn,
|
||||
group_id, user_id, role_id, order, is_public,
|
||||
"add_user_to_summary",
|
||||
self._add_user_to_summary_txn,
|
||||
group_id,
|
||||
user_id,
|
||||
role_id,
|
||||
order,
|
||||
is_public,
|
||||
)
|
||||
|
||||
def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
|
||||
is_public):
|
||||
def _add_user_to_summary_txn(
|
||||
self, txn, group_id, user_id, role_id, order, is_public
|
||||
):
|
||||
"""Add (or update) user's entry in summary.
|
||||
|
||||
Args:
|
||||
@ -474,10 +452,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
user_in_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -490,10 +465,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
role_exists = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_roles",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"role_id": role_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||
retcol="group_id",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -504,32 +476,28 @@ class GroupServerStore(SQLBaseStore):
|
||||
role_exists = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_summary_roles",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"role_id": role_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||
retcol="group_id",
|
||||
allow_none=True,
|
||||
)
|
||||
if not role_exists:
|
||||
# If not, add it with an order larger than all others
|
||||
txn.execute("""
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO group_summary_roles
|
||||
(group_id, role_id, role_order)
|
||||
SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
|
||||
FROM group_summary_roles
|
||||
WHERE group_id = ? AND role_id = ?
|
||||
""", (group_id, role_id, group_id, role_id))
|
||||
""",
|
||||
(group_id, role_id, group_id, role_id),
|
||||
)
|
||||
|
||||
existing = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="group_summary_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"role_id": role_id,
|
||||
},
|
||||
retcols=("user_order", "is_public",),
|
||||
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
|
||||
retcols=("user_order", "is_public"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@ -539,13 +507,13 @@ class GroupServerStore(SQLBaseStore):
|
||||
UPDATE group_summary_users SET user_order = user_order + 1
|
||||
WHERE group_id = ? AND role_id = ? AND user_order >= ?
|
||||
"""
|
||||
txn.execute(sql, (group_id, role_id, order,))
|
||||
txn.execute(sql, (group_id, role_id, order))
|
||||
elif not existing:
|
||||
sql = """
|
||||
SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
|
||||
WHERE group_id = ? AND role_id = ?
|
||||
"""
|
||||
txn.execute(sql, (group_id, role_id,))
|
||||
txn.execute(sql, (group_id, role_id))
|
||||
order, = txn.fetchone()
|
||||
|
||||
if existing:
|
||||
@ -586,11 +554,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
return self._simple_delete(
|
||||
table="group_summary_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"role_id": role_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
|
||||
desc="remove_user_from_summary",
|
||||
)
|
||||
|
||||
@ -599,10 +563,9 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
Returns ([users], [roles])
|
||||
"""
|
||||
|
||||
def _get_users_for_summary_txn(txn):
|
||||
keyvalues = {
|
||||
"group_id": group_id,
|
||||
}
|
||||
keyvalues = {"group_id": group_id}
|
||||
if not include_private:
|
||||
keyvalues["is_public"] = True
|
||||
|
||||
@ -651,6 +614,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
}
|
||||
|
||||
return users, roles
|
||||
|
||||
return self.runInteraction(
|
||||
"get_users_for_summary_by_role", _get_users_for_summary_txn
|
||||
)
|
||||
@ -658,10 +622,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
def is_user_in_group(self, user_id, group_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="is_user_in_group",
|
||||
@ -670,10 +631,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
def is_user_admin_in_group(self, group_id, user_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="is_admin",
|
||||
allow_none=True,
|
||||
desc="is_user_admin_in_group",
|
||||
@ -684,10 +642,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_insert(
|
||||
table="group_invites",
|
||||
values={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={"group_id": group_id, "user_id": user_id},
|
||||
desc="add_group_invite",
|
||||
)
|
||||
|
||||
@ -696,10 +651,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_select_one_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="user_id",
|
||||
desc="is_user_invited_to_local_group",
|
||||
allow_none=True,
|
||||
@ -718,14 +670,12 @@ class GroupServerStore(SQLBaseStore):
|
||||
|
||||
Returns an empty dict if the user is not join/invite/etc
|
||||
"""
|
||||
|
||||
def _get_users_membership_in_group_txn(txn):
|
||||
row = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="group_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcols=("is_admin", "is_public"),
|
||||
allow_none=True,
|
||||
)
|
||||
@ -740,27 +690,29 @@ class GroupServerStore(SQLBaseStore):
|
||||
row = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="group_invites",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if row:
|
||||
return {
|
||||
"membership": "invite",
|
||||
}
|
||||
return {"membership": "invite"}
|
||||
|
||||
return {}
|
||||
|
||||
return self.runInteraction(
|
||||
"get_users_membership_info_in_group", _get_users_membership_in_group_txn,
|
||||
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
|
||||
)
|
||||
|
||||
def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
|
||||
local_attestation=None, remote_attestation=None):
|
||||
def add_user_to_group(
|
||||
self,
|
||||
group_id,
|
||||
user_id,
|
||||
is_admin=False,
|
||||
is_public=True,
|
||||
local_attestation=None,
|
||||
remote_attestation=None,
|
||||
):
|
||||
"""Add a user to the group server.
|
||||
|
||||
Args:
|
||||
@ -774,6 +726,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
remote_attestation (dict): The attestation given to GS by remote
|
||||
server. Optional if the user and group are on the same server
|
||||
"""
|
||||
|
||||
def _add_user_to_group_txn(txn):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
@ -789,10 +742,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_invites",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
if local_attestation:
|
||||
@ -817,75 +767,52 @@ class GroupServerStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"add_user_to_group", _add_user_to_group_txn
|
||||
)
|
||||
return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
|
||||
|
||||
def remove_user_from_group(self, group_id, user_id):
|
||||
def _remove_user_from_group_txn(txn):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_invites",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_attestations_renewals",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_attestations_remote",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_summary_users",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
|
||||
|
||||
return self.runInteraction(
|
||||
"remove_user_from_group", _remove_user_from_group_txn
|
||||
)
|
||||
|
||||
def add_room_to_group(self, group_id, room_id, is_public):
|
||||
return self._simple_insert(
|
||||
table="group_rooms",
|
||||
values={
|
||||
"group_id": group_id,
|
||||
"room_id": room_id,
|
||||
"is_public": is_public,
|
||||
},
|
||||
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
|
||||
desc="add_room_to_group",
|
||||
)
|
||||
|
||||
def update_room_in_group_visibility(self, group_id, room_id, is_public):
|
||||
return self._simple_update(
|
||||
table="group_rooms",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
updatevalues={
|
||||
"is_public": is_public,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "room_id": room_id},
|
||||
updatevalues={"is_public": is_public},
|
||||
desc="update_room_in_group_visibility",
|
||||
)
|
||||
|
||||
@ -894,22 +821,17 @@ class GroupServerStore(SQLBaseStore):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_rooms",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "room_id": room_id},
|
||||
)
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_summary_rooms",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "room_id": room_id},
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"remove_room_from_group", _remove_room_from_group_txn,
|
||||
"remove_room_from_group", _remove_room_from_group_txn
|
||||
)
|
||||
|
||||
def get_publicised_groups_for_user(self, user_id):
|
||||
@ -917,11 +839,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_select_onecol(
|
||||
table="local_group_membership",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"membership": "join",
|
||||
"is_publicised": True,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
|
||||
retcol="group_id",
|
||||
desc="get_publicised_groups_for_user",
|
||||
)
|
||||
@ -931,23 +849,23 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_update_one(
|
||||
table="local_group_membership",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
updatevalues={
|
||||
"is_publicised": publicise,
|
||||
},
|
||||
desc="update_group_publicity"
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
updatevalues={"is_publicised": publicise},
|
||||
desc="update_group_publicity",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register_user_group_membership(self, group_id, user_id, membership,
|
||||
is_admin=False, content={},
|
||||
local_attestation=None,
|
||||
remote_attestation=None,
|
||||
is_publicised=False,
|
||||
):
|
||||
def register_user_group_membership(
|
||||
self,
|
||||
group_id,
|
||||
user_id,
|
||||
membership,
|
||||
is_admin=False,
|
||||
content={},
|
||||
local_attestation=None,
|
||||
remote_attestation=None,
|
||||
is_publicised=False,
|
||||
):
|
||||
"""Registers that a local user is a member of a (local or remote) group.
|
||||
|
||||
Args:
|
||||
@ -962,15 +880,13 @@ class GroupServerStore(SQLBaseStore):
|
||||
remote_attestation (dict): If remote group then store the remote
|
||||
attestation from the group, else None.
|
||||
"""
|
||||
|
||||
def _register_user_group_membership_txn(txn, next_id):
|
||||
# TODO: Upsert?
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="local_group_membership",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
@ -993,8 +909,10 @@ class GroupServerStore(SQLBaseStore):
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"type": "membership",
|
||||
"content": json.dumps({"membership": membership, "content": content}),
|
||||
}
|
||||
"content": json.dumps(
|
||||
{"membership": membership, "content": content}
|
||||
),
|
||||
},
|
||||
)
|
||||
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
|
||||
|
||||
@ -1009,7 +927,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"valid_until_ms": local_attestation["valid_until_ms"],
|
||||
}
|
||||
},
|
||||
)
|
||||
if remote_attestation:
|
||||
self._simple_insert_txn(
|
||||
@ -1020,24 +938,18 @@ class GroupServerStore(SQLBaseStore):
|
||||
"user_id": user_id,
|
||||
"valid_until_ms": remote_attestation["valid_until_ms"],
|
||||
"attestation_json": json.dumps(remote_attestation),
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_attestations_renewals",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="group_attestations_remote",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
return next_id
|
||||
@ -1045,13 +957,15 @@ class GroupServerStore(SQLBaseStore):
|
||||
with self._group_updates_id_gen.get_next() as next_id:
|
||||
res = yield self.runInteraction(
|
||||
"register_user_group_membership",
|
||||
_register_user_group_membership_txn, next_id,
|
||||
_register_user_group_membership_txn,
|
||||
next_id,
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, name, avatar_url, short_description,
|
||||
long_description,):
|
||||
def create_group(
|
||||
self, group_id, user_id, name, avatar_url, short_description, long_description
|
||||
):
|
||||
yield self._simple_insert(
|
||||
table="groups",
|
||||
values={
|
||||
@ -1066,12 +980,10 @@ class GroupServerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_profile(self, group_id, profile,):
|
||||
def update_group_profile(self, group_id, profile):
|
||||
yield self._simple_update_one(
|
||||
table="groups",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id},
|
||||
updatevalues=profile,
|
||||
desc="update_group_profile",
|
||||
)
|
||||
@ -1079,6 +991,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
def get_attestations_need_renewals(self, valid_until_ms):
|
||||
"""Get all attestations that need to be renewed until givent time
|
||||
"""
|
||||
|
||||
def _get_attestations_need_renewals_txn(txn):
|
||||
sql = """
|
||||
SELECT group_id, user_id FROM group_attestations_renewals
|
||||
@ -1086,6 +999,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
txn.execute(sql, (valid_until_ms,))
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
|
||||
)
|
||||
@ -1095,13 +1009,8 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_update_one(
|
||||
table="group_attestations_renewals",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
updatevalues={
|
||||
"valid_until_ms": attestation["valid_until_ms"],
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
|
||||
desc="update_attestation_renewal",
|
||||
)
|
||||
|
||||
@ -1110,13 +1019,10 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_update_one(
|
||||
table="group_attestations_remote",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
updatevalues={
|
||||
"valid_until_ms": attestation["valid_until_ms"],
|
||||
"attestation_json": json.dumps(attestation)
|
||||
"attestation_json": json.dumps(attestation),
|
||||
},
|
||||
desc="update_remote_attestion",
|
||||
)
|
||||
@ -1132,10 +1038,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
return self._simple_delete(
|
||||
table="group_attestations_renewals",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
desc="remove_attestation_renewal",
|
||||
)
|
||||
|
||||
@ -1146,10 +1049,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
"""
|
||||
row = yield self._simple_select_one(
|
||||
table="group_attestations_remote",
|
||||
keyvalues={
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcols=("valid_until_ms", "attestation_json"),
|
||||
desc="get_remote_attestation",
|
||||
allow_none=True,
|
||||
@ -1164,10 +1064,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
def get_joined_groups(self, user_id):
|
||||
return self._simple_select_onecol(
|
||||
table="local_group_membership",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"membership": "join",
|
||||
},
|
||||
keyvalues={"user_id": user_id, "membership": "join"},
|
||||
retcol="group_id",
|
||||
desc="get_joined_groups",
|
||||
)
|
||||
@ -1181,7 +1078,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
WHERE user_id = ? AND membership != 'leave'
|
||||
AND stream_id <= ?
|
||||
"""
|
||||
txn.execute(sql, (user_id, now_token,))
|
||||
txn.execute(sql, (user_id, now_token))
|
||||
return [
|
||||
{
|
||||
"group_id": row[0],
|
||||
@ -1191,14 +1088,15 @@ class GroupServerStore(SQLBaseStore):
|
||||
}
|
||||
for row in txn
|
||||
]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_groups_for_user", _get_all_groups_for_user_txn,
|
||||
"get_all_groups_for_user", _get_all_groups_for_user_txn
|
||||
)
|
||||
|
||||
def get_groups_changes_for_user(self, user_id, from_token, to_token):
|
||||
from_token = int(from_token)
|
||||
has_changed = self._group_updates_stream_cache.has_entity_changed(
|
||||
user_id, from_token,
|
||||
user_id, from_token
|
||||
)
|
||||
if not has_changed:
|
||||
return []
|
||||
@ -1210,21 +1108,25 @@ class GroupServerStore(SQLBaseStore):
|
||||
INNER JOIN local_group_membership USING (group_id, user_id)
|
||||
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
|
||||
"""
|
||||
txn.execute(sql, (user_id, from_token, to_token,))
|
||||
return [{
|
||||
"group_id": group_id,
|
||||
"membership": membership,
|
||||
"type": gtype,
|
||||
"content": json.loads(content_json),
|
||||
} for group_id, membership, gtype, content_json in txn]
|
||||
txn.execute(sql, (user_id, from_token, to_token))
|
||||
return [
|
||||
{
|
||||
"group_id": group_id,
|
||||
"membership": membership,
|
||||
"type": gtype,
|
||||
"content": json.loads(content_json),
|
||||
}
|
||||
for group_id, membership, gtype, content_json in txn
|
||||
]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_groups_changes_for_user", _get_groups_changes_for_user_txn,
|
||||
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
|
||||
)
|
||||
|
||||
def get_all_groups_changes(self, from_token, to_token, limit):
|
||||
from_token = int(from_token)
|
||||
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
|
||||
from_token,
|
||||
from_token
|
||||
)
|
||||
if not has_changed:
|
||||
return []
|
||||
@ -1236,16 +1138,14 @@ class GroupServerStore(SQLBaseStore):
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (from_token, to_token, limit,))
|
||||
return [(
|
||||
stream_id,
|
||||
group_id,
|
||||
user_id,
|
||||
gtype,
|
||||
json.loads(content_json),
|
||||
) for stream_id, group_id, user_id, gtype, content_json in txn]
|
||||
txn.execute(sql, (from_token, to_token, limit))
|
||||
return [
|
||||
(stream_id, group_id, user_id, gtype, json.loads(content_json))
|
||||
for stream_id, group_id, user_id, gtype, content_json in txn
|
||||
]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_groups_changes", _get_all_groups_changes_txn,
|
||||
"get_all_groups_changes", _get_all_groups_changes_txn
|
||||
)
|
||||
|
||||
def get_group_stream_token(self):
|
||||
|
@ -56,12 +56,13 @@ class KeyStore(SQLBaseStore):
|
||||
desc="get_server_certificate",
|
||||
)
|
||||
tls_certificate = OpenSSL.crypto.load_certificate(
|
||||
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
|
||||
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes
|
||||
)
|
||||
defer.returnValue(tls_certificate)
|
||||
|
||||
def store_server_certificate(self, server_name, from_server, time_now_ms,
|
||||
tls_certificate):
|
||||
def store_server_certificate(
|
||||
self, server_name, from_server, time_now_ms, tls_certificate
|
||||
):
|
||||
"""Stores the TLS X.509 certificate for the given server
|
||||
Args:
|
||||
server_name (str): The name of the server.
|
||||
@ -75,10 +76,7 @@ class KeyStore(SQLBaseStore):
|
||||
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
|
||||
return self._simple_upsert(
|
||||
table="server_tls_certificates",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"fingerprint": fingerprint,
|
||||
},
|
||||
keyvalues={"server_name": server_name, "fingerprint": fingerprint},
|
||||
values={
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": time_now_ms,
|
||||
@ -91,19 +89,14 @@ class KeyStore(SQLBaseStore):
|
||||
def _get_server_verify_key(self, server_name, key_id):
|
||||
verify_key_bytes = yield self._simple_select_one_onecol(
|
||||
table="server_signature_keys",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"key_id": key_id,
|
||||
},
|
||||
keyvalues={"server_name": server_name, "key_id": key_id},
|
||||
retcol="verify_key",
|
||||
desc="_get_server_verify_key",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if verify_key_bytes:
|
||||
defer.returnValue(decode_verify_key_bytes(
|
||||
key_id, bytes(verify_key_bytes)
|
||||
))
|
||||
defer.returnValue(decode_verify_key_bytes(key_id, bytes(verify_key_bytes)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_keys(self, server_name, key_ids):
|
||||
@ -123,8 +116,9 @@ class KeyStore(SQLBaseStore):
|
||||
keys[key_id] = key
|
||||
defer.returnValue(keys)
|
||||
|
||||
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
||||
verify_key):
|
||||
def store_server_verify_key(
|
||||
self, server_name, from_server, time_now_ms, verify_key
|
||||
):
|
||||
"""Stores a NACL verification key for the given server.
|
||||
Args:
|
||||
server_name (str): The name of the server.
|
||||
@ -139,10 +133,7 @@ class KeyStore(SQLBaseStore):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="server_signature_keys",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"key_id": key_id,
|
||||
},
|
||||
keyvalues={"server_name": server_name, "key_id": key_id},
|
||||
values={
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": time_now_ms,
|
||||
@ -150,14 +141,14 @@ class KeyStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
txn.call_after(
|
||||
self._get_server_verify_key.invalidate,
|
||||
(server_name, key_id)
|
||||
self._get_server_verify_key.invalidate, (server_name, key_id)
|
||||
)
|
||||
|
||||
return self.runInteraction("store_server_verify_key", _txn)
|
||||
|
||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||
def store_server_keys_json(
|
||||
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
|
||||
):
|
||||
"""Stores the JSON bytes for a set of keys from a server
|
||||
The JSON should be signed by the originating server, the intermediate
|
||||
server, and by this server. Updates the value for the
|
||||
@ -200,6 +191,7 @@ class KeyStore(SQLBaseStore):
|
||||
Dict mapping (server_name, key_id, source) triplets to dicts with
|
||||
"ts_valid_until_ms" and "key_json" keys.
|
||||
"""
|
||||
|
||||
def _get_server_keys_json_txn(txn):
|
||||
results = {}
|
||||
for server_name, key_id, from_server in server_keys:
|
||||
@ -222,6 +214,5 @@ class KeyStore(SQLBaseStore):
|
||||
)
|
||||
results[(server_name, key_id, from_server)] = rows
|
||||
return results
|
||||
return self.runInteraction(
|
||||
"get_server_keys_json", _get_server_keys_json_txn
|
||||
)
|
||||
|
||||
return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
|
||||
|
@ -38,15 +38,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
"local_media_repository",
|
||||
{"media_id": media_id},
|
||||
(
|
||||
"media_type", "media_length", "upload_name", "created_ts",
|
||||
"quarantined_by", "url_cache",
|
||||
"media_type",
|
||||
"media_length",
|
||||
"upload_name",
|
||||
"created_ts",
|
||||
"quarantined_by",
|
||||
"url_cache",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_local_media",
|
||||
)
|
||||
|
||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||
media_length, user_id, url_cache=None):
|
||||
def store_local_media(
|
||||
self,
|
||||
media_id,
|
||||
media_type,
|
||||
time_now_ms,
|
||||
upload_name,
|
||||
media_length,
|
||||
user_id,
|
||||
url_cache=None,
|
||||
):
|
||||
return self._simple_insert(
|
||||
"local_media_repository",
|
||||
{
|
||||
@ -66,6 +78,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
Returns:
|
||||
None if the URL isn't cached.
|
||||
"""
|
||||
|
||||
def get_url_cache_txn(txn):
|
||||
# get the most recently cached result (relative to the given ts)
|
||||
sql = (
|
||||
@ -92,16 +105,25 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return dict(zip((
|
||||
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
|
||||
), row))
|
||||
return dict(
|
||||
zip(
|
||||
(
|
||||
'response_code',
|
||||
'etag',
|
||||
'expires_ts',
|
||||
'og',
|
||||
'media_id',
|
||||
'download_ts',
|
||||
),
|
||||
row,
|
||||
)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_url_cache", get_url_cache_txn
|
||||
)
|
||||
return self.runInteraction("get_url_cache", get_url_cache_txn)
|
||||
|
||||
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
|
||||
download_ts):
|
||||
def store_url_cache(
|
||||
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
||||
):
|
||||
return self._simple_insert(
|
||||
"local_media_repository_url_cache",
|
||||
{
|
||||
@ -121,15 +143,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
"local_media_repository_thumbnails",
|
||||
{"media_id": media_id},
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length",
|
||||
"thumbnail_width",
|
||||
"thumbnail_height",
|
||||
"thumbnail_method",
|
||||
"thumbnail_type",
|
||||
"thumbnail_length",
|
||||
),
|
||||
desc="get_local_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||
thumbnail_height, thumbnail_type,
|
||||
thumbnail_method, thumbnail_length):
|
||||
def store_local_thumbnail(
|
||||
self,
|
||||
media_id,
|
||||
thumbnail_width,
|
||||
thumbnail_height,
|
||||
thumbnail_type,
|
||||
thumbnail_method,
|
||||
thumbnail_length,
|
||||
):
|
||||
return self._simple_insert(
|
||||
"local_media_repository_thumbnails",
|
||||
{
|
||||
@ -148,16 +179,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
"remote_media_cache",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
"media_type", "media_length", "upload_name", "created_ts",
|
||||
"filesystem_id", "quarantined_by",
|
||||
"media_type",
|
||||
"media_length",
|
||||
"upload_name",
|
||||
"created_ts",
|
||||
"filesystem_id",
|
||||
"quarantined_by",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_cached_remote_media",
|
||||
)
|
||||
|
||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||
media_length, time_now_ms, upload_name,
|
||||
filesystem_id):
|
||||
def store_cached_remote_media(
|
||||
self,
|
||||
origin,
|
||||
media_id,
|
||||
media_type,
|
||||
media_length,
|
||||
time_now_ms,
|
||||
upload_name,
|
||||
filesystem_id,
|
||||
):
|
||||
return self._simple_insert(
|
||||
"remote_media_cache",
|
||||
{
|
||||
@ -181,26 +223,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
|
||||
time_ms: Current time in milliseconds
|
||||
"""
|
||||
|
||||
def update_cache_txn(txn):
|
||||
sql = (
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||
" WHERE media_origin = ? AND media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, (
|
||||
(time_ms, media_origin, media_id)
|
||||
for media_origin, media_id in remote_media
|
||||
))
|
||||
txn.executemany(
|
||||
sql,
|
||||
(
|
||||
(time_ms, media_origin, media_id)
|
||||
for media_origin, media_id in remote_media
|
||||
),
|
||||
)
|
||||
|
||||
sql = (
|
||||
"UPDATE local_media_repository SET last_access_ts = ?"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, (
|
||||
(time_ms, media_id)
|
||||
for media_id in local_media
|
||||
))
|
||||
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
|
||||
|
||||
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
|
||||
|
||||
@ -209,16 +252,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
"remote_media_cache_thumbnails",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||
"thumbnail_width",
|
||||
"thumbnail_height",
|
||||
"thumbnail_method",
|
||||
"thumbnail_type",
|
||||
"thumbnail_length",
|
||||
"filesystem_id",
|
||||
),
|
||||
desc="get_remote_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||
thumbnail_width, thumbnail_height,
|
||||
thumbnail_type, thumbnail_method,
|
||||
thumbnail_length):
|
||||
def store_remote_media_thumbnail(
|
||||
self,
|
||||
origin,
|
||||
media_id,
|
||||
filesystem_id,
|
||||
thumbnail_width,
|
||||
thumbnail_height,
|
||||
thumbnail_type,
|
||||
thumbnail_method,
|
||||
thumbnail_length,
|
||||
):
|
||||
return self._simple_insert(
|
||||
"remote_media_cache_thumbnails",
|
||||
{
|
||||
@ -250,17 +304,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
"remote_media_cache",
|
||||
keyvalues={
|
||||
"media_origin": media_origin, "media_id": media_id
|
||||
},
|
||||
keyvalues={"media_origin": media_origin, "media_id": media_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
"remote_media_cache_thumbnails",
|
||||
keyvalues={
|
||||
"media_origin": media_origin, "media_id": media_id
|
||||
},
|
||||
keyvalues={"media_origin": media_origin, "media_id": media_id},
|
||||
)
|
||||
|
||||
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
||||
|
||||
def get_expired_url_cache(self, now_ts):
|
||||
@ -281,10 +332,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
if len(media_ids) == 0:
|
||||
return
|
||||
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_url_cache"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
|
||||
|
||||
def _delete_url_cache_txn(txn):
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
@ -304,7 +352,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_url_cache_media_before", _get_url_cache_media_before_txn,
|
||||
"get_url_cache_media_before", _get_url_cache_media_before_txn
|
||||
)
|
||||
|
||||
def delete_url_cache_media(self, media_ids):
|
||||
@ -312,20 +360,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
return
|
||||
|
||||
def _delete_url_cache_media_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
|
||||
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_thumbnails"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
|
||||
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_url_cache_media", _delete_url_cache_media_txn,
|
||||
"delete_url_cache_media", _delete_url_cache_media_txn
|
||||
)
|
||||
|
@ -35,9 +35,12 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
self.reserved_users = ()
|
||||
# Do not add more reserved users than the total allowable number
|
||||
self._new_transaction(
|
||||
dbconn, "initialise_mau_threepids", [], [],
|
||||
dbconn,
|
||||
"initialise_mau_threepids",
|
||||
[],
|
||||
[],
|
||||
self._initialise_reserved_users,
|
||||
hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
|
||||
hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
|
||||
)
|
||||
|
||||
def _initialise_reserved_users(self, txn, threepids):
|
||||
@ -51,10 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
reserved_user_list = []
|
||||
|
||||
for tp in threepids:
|
||||
user_id = self.get_user_id_by_threepid_txn(
|
||||
txn,
|
||||
tp["medium"], tp["address"]
|
||||
)
|
||||
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
|
||||
|
||||
if user_id:
|
||||
is_support = self.is_support_user_txn(txn, user_id)
|
||||
@ -62,9 +62,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
self.upsert_monthly_active_user_txn(txn, user_id)
|
||||
reserved_user_list.append(user_id)
|
||||
else:
|
||||
logger.warning(
|
||||
"mau limit reserved threepid %s not found in db" % tp
|
||||
)
|
||||
logger.warning("mau limit reserved threepid %s not found in db" % tp)
|
||||
self.reserved_users = tuple(reserved_user_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -75,12 +73,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
Returns:
|
||||
Deferred[]
|
||||
"""
|
||||
|
||||
def _reap_users(txn):
|
||||
# Purge stale users
|
||||
|
||||
thirty_days_ago = (
|
||||
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
)
|
||||
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
query_args = [thirty_days_ago]
|
||||
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
|
||||
|
||||
@ -158,6 +155,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
txn.execute(sql)
|
||||
count, = txn.fetchone()
|
||||
return count
|
||||
|
||||
return self.runInteraction("count_users", _count_users)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -198,14 +196,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
return
|
||||
|
||||
yield self.runInteraction(
|
||||
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
|
||||
user_id
|
||||
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
|
||||
)
|
||||
|
||||
user_in_mau = self.user_last_seen_monthly_active.cache.get(
|
||||
(user_id,),
|
||||
None,
|
||||
update_metrics=False
|
||||
(user_id,), None, update_metrics=False
|
||||
)
|
||||
if user_in_mau is None:
|
||||
self.get_monthly_active_count.invalidate(())
|
||||
@ -247,12 +242,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
is_insert = self._simple_upsert_txn(
|
||||
txn,
|
||||
table="monthly_active_users",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={
|
||||
"timestamp": int(self._clock.time_msec()),
|
||||
},
|
||||
keyvalues={"user_id": user_id},
|
||||
values={"timestamp": int(self._clock.time_msec())},
|
||||
)
|
||||
|
||||
return is_insert
|
||||
@ -268,15 +259,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
|
||||
"""
|
||||
|
||||
return(self._simple_select_one_onecol(
|
||||
return self._simple_select_one_onecol(
|
||||
table="monthly_active_users",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="timestamp",
|
||||
allow_none=True,
|
||||
desc="user_last_seen_monthly_active",
|
||||
))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def populate_monthly_active_users(self, user_id):
|
||||
|
@ -10,7 +10,7 @@ class OpenIdStore(SQLBaseStore):
|
||||
"ts_valid_until_ms": ts_valid_until_ms,
|
||||
"user_id": user_id,
|
||||
},
|
||||
desc="insert_open_id_token"
|
||||
desc="insert_open_id_token",
|
||||
)
|
||||
|
||||
def get_user_id_for_open_id_token(self, token, ts_now_ms):
|
||||
@ -27,6 +27,5 @@ class OpenIdStore(SQLBaseStore):
|
||||
return None
|
||||
else:
|
||||
return rows[0][0]
|
||||
return self.runInteraction(
|
||||
"get_user_id_for_token", get_user_id_for_token_txn
|
||||
)
|
||||
|
||||
return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
|
||||
|
@ -143,10 +143,9 @@ def _setup_new_database(cur, database_engine):
|
||||
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)"
|
||||
"INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
|
||||
),
|
||||
(max_current_ver, False,)
|
||||
(max_current_ver, False),
|
||||
)
|
||||
|
||||
_upgrade_existing_database(
|
||||
@ -160,8 +159,15 @@ def _setup_new_database(cur, database_engine):
|
||||
)
|
||||
|
||||
|
||||
def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
upgraded, database_engine, config, is_empty=False):
|
||||
def _upgrade_existing_database(
|
||||
cur,
|
||||
current_version,
|
||||
applied_delta_files,
|
||||
upgraded,
|
||||
database_engine,
|
||||
config,
|
||||
is_empty=False,
|
||||
):
|
||||
"""Upgrades an existing database.
|
||||
|
||||
Delta files can either be SQL stored in *.sql files, or python modules
|
||||
@ -209,8 +215,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
|
||||
if current_version > SCHEMA_VERSION:
|
||||
raise ValueError(
|
||||
"Cannot use this database as it is too " +
|
||||
"new for the server to understand"
|
||||
"Cannot use this database as it is too "
|
||||
+ "new for the server to understand"
|
||||
)
|
||||
|
||||
start_ver = current_version
|
||||
@ -239,20 +245,14 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
if relative_path in applied_delta_files:
|
||||
continue
|
||||
|
||||
absolute_path = os.path.join(
|
||||
dir_path, "schema", "delta", relative_path,
|
||||
)
|
||||
absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
|
||||
root_name, ext = os.path.splitext(file_name)
|
||||
if ext == ".py":
|
||||
# This is a python upgrade module. We need to import into some
|
||||
# package and then execute its `run_upgrade` function.
|
||||
module_name = "synapse.storage.v%d_%s" % (
|
||||
v, root_name
|
||||
)
|
||||
module_name = "synapse.storage.v%d_%s" % (v, root_name)
|
||||
with open(absolute_path) as python_file:
|
||||
module = imp.load_source(
|
||||
module_name, absolute_path, python_file
|
||||
)
|
||||
module = imp.load_source(module_name, absolute_path, python_file)
|
||||
logger.info("Running script %s", relative_path)
|
||||
module.run_create(cur, database_engine)
|
||||
if not is_empty:
|
||||
@ -269,8 +269,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
else:
|
||||
# Not a valid delta file.
|
||||
logger.warn(
|
||||
"Found directory entry that did not end in .py or"
|
||||
" .sql: %s",
|
||||
"Found directory entry that did not end in .py or" " .sql: %s",
|
||||
relative_path,
|
||||
)
|
||||
continue
|
||||
@ -278,19 +277,17 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
# Mark as done.
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO applied_schema_deltas (version, file)"
|
||||
" VALUES (?,?)",
|
||||
"INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
|
||||
),
|
||||
(v, relative_path)
|
||||
(v, relative_path),
|
||||
)
|
||||
|
||||
cur.execute("DELETE FROM schema_version")
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)",
|
||||
"INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
|
||||
),
|
||||
(v, True)
|
||||
(v, True),
|
||||
)
|
||||
|
||||
|
||||
@ -308,7 +305,7 @@ def _apply_module_schemas(txn, database_engine, config):
|
||||
continue
|
||||
modname = ".".join((mod.__module__, mod.__name__))
|
||||
_apply_module_schema_files(
|
||||
txn, database_engine, modname, mod.get_db_schema_files(),
|
||||
txn, database_engine, modname, mod.get_db_schema_files()
|
||||
)
|
||||
|
||||
|
||||
@ -326,7 +323,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
|
||||
database_engine.convert_param_style(
|
||||
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
|
||||
),
|
||||
(modname,)
|
||||
(modname,),
|
||||
)
|
||||
applied_deltas = set(d for d, in cur)
|
||||
for (name, stream) in names_and_streams:
|
||||
@ -336,7 +333,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
|
||||
root_name, ext = os.path.splitext(name)
|
||||
if ext != '.sql':
|
||||
raise PrepareDatabaseException(
|
||||
"only .sql files are currently supported for module schemas",
|
||||
"only .sql files are currently supported for module schemas"
|
||||
)
|
||||
|
||||
logger.info("applying schema %s for %s", name, modname)
|
||||
@ -346,10 +343,9 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
|
||||
# Mark as done.
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO applied_module_schemas (module_name, file)"
|
||||
" VALUES (?,?)",
|
||||
"INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
|
||||
),
|
||||
(modname, name)
|
||||
(modname, name),
|
||||
)
|
||||
|
||||
|
||||
@ -386,10 +382,7 @@ def get_statements(f):
|
||||
statements = line.split(";")
|
||||
|
||||
# We must prepend statement_buffer to the first statement
|
||||
first_statement = "%s %s" % (
|
||||
statement_buffer.strip(),
|
||||
statements[0].strip()
|
||||
)
|
||||
first_statement = "%s %s" % (statement_buffer.strip(), statements[0].strip())
|
||||
statements[0] = first_statement
|
||||
|
||||
# Every entry, except the last, is a full statement
|
||||
@ -409,9 +402,7 @@ def executescript(txn, schema_path):
|
||||
|
||||
def _get_or_create_schema_state(txn, database_engine):
|
||||
# Bluntly try creating the schema_version tables.
|
||||
schema_path = os.path.join(
|
||||
dir_path, "schema", "schema_version.sql",
|
||||
)
|
||||
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
|
||||
executescript(txn, schema_path)
|
||||
|
||||
txn.execute("SELECT version, upgraded FROM schema_version")
|
||||
@ -424,7 +415,7 @@ def _get_or_create_schema_state(txn, database_engine):
|
||||
database_engine.convert_param_style(
|
||||
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
|
||||
),
|
||||
(current_version,)
|
||||
(current_version,),
|
||||
)
|
||||
applied_deltas = [d for d, in txn]
|
||||
return current_version, applied_deltas, upgraded
|
||||
|
@ -24,10 +24,20 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cache
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class UserPresenceState(namedtuple("UserPresenceState",
|
||||
("user_id", "state", "last_active_ts",
|
||||
"last_federation_update_ts", "last_user_sync_ts",
|
||||
"status_msg", "currently_active"))):
|
||||
class UserPresenceState(
|
||||
namedtuple(
|
||||
"UserPresenceState",
|
||||
(
|
||||
"user_id",
|
||||
"state",
|
||||
"last_active_ts",
|
||||
"last_federation_update_ts",
|
||||
"last_user_sync_ts",
|
||||
"status_msg",
|
||||
"currently_active",
|
||||
),
|
||||
)
|
||||
):
|
||||
"""Represents the current presence state of the user.
|
||||
|
||||
user_id (str)
|
||||
@ -75,22 +85,21 @@ class PresenceStore(SQLBaseStore):
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
yield self.runInteraction(
|
||||
"update_presence",
|
||||
self._update_presence_txn, stream_orderings, presence_states,
|
||||
self._update_presence_txn,
|
||||
stream_orderings,
|
||||
presence_states,
|
||||
)
|
||||
|
||||
defer.returnValue((
|
||||
stream_orderings[-1], self._presence_id_gen.get_current_token()
|
||||
))
|
||||
defer.returnValue(
|
||||
(stream_orderings[-1], self._presence_id_gen.get_current_token())
|
||||
)
|
||||
|
||||
def _update_presence_txn(self, txn, stream_orderings, presence_states):
|
||||
for stream_id, state in zip(stream_orderings, presence_states):
|
||||
txn.call_after(
|
||||
self.presence_stream_cache.entity_has_changed,
|
||||
state.user_id, stream_id,
|
||||
)
|
||||
txn.call_after(
|
||||
self._get_presence_for_user.invalidate, (state.user_id,)
|
||||
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
|
||||
)
|
||||
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
|
||||
|
||||
# Actually insert new rows
|
||||
self._simple_insert_many_txn(
|
||||
@ -113,18 +122,13 @@ class PresenceStore(SQLBaseStore):
|
||||
|
||||
# Delete old rows to stop database from getting really big
|
||||
sql = (
|
||||
"DELETE FROM presence_stream WHERE"
|
||||
" stream_id < ?"
|
||||
" AND user_id IN (%s)"
|
||||
"DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
|
||||
)
|
||||
|
||||
for states in batch_iter(presence_states, 50):
|
||||
args = [stream_id]
|
||||
args.extend(s.user_id for s in states)
|
||||
txn.execute(
|
||||
sql % (",".join("?" for _ in states),),
|
||||
args
|
||||
)
|
||||
txn.execute(sql % (",".join("?" for _ in states),), args)
|
||||
|
||||
def get_all_presence_updates(self, last_id, current_id):
|
||||
if last_id == current_id:
|
||||
@ -149,8 +153,12 @@ class PresenceStore(SQLBaseStore):
|
||||
def _get_presence_for_user(self, user_id):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids",
|
||||
num_args=1, inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="_get_presence_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def get_presence_for_users(self, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="presence_stream",
|
||||
@ -180,8 +188,10 @@ class PresenceStore(SQLBaseStore):
|
||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self._simple_insert(
|
||||
table="presence_allow_inbound",
|
||||
values={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
values={
|
||||
"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid,
|
||||
},
|
||||
desc="allow_presence_visible",
|
||||
or_ignore=True,
|
||||
)
|
||||
@ -189,17 +199,21 @@ class PresenceStore(SQLBaseStore):
|
||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self._simple_delete_one(
|
||||
table="presence_allow_inbound",
|
||||
keyvalues={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
keyvalues={
|
||||
"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid,
|
||||
},
|
||||
desc="disallow_presence_visible",
|
||||
)
|
||||
|
||||
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
||||
return self._simple_insert(
|
||||
table="presence_list",
|
||||
values={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid,
|
||||
"accepted": False},
|
||||
values={
|
||||
"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid,
|
||||
"accepted": False,
|
||||
},
|
||||
desc="add_presence_list_pending",
|
||||
)
|
||||
|
||||
@ -210,7 +224,7 @@ class PresenceStore(SQLBaseStore):
|
||||
table="presence_list",
|
||||
keyvalues={
|
||||
"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid
|
||||
"observed_user_id": observed_userid,
|
||||
},
|
||||
updatevalues={"accepted": True},
|
||||
)
|
||||
@ -225,7 +239,7 @@ class PresenceStore(SQLBaseStore):
|
||||
return result
|
||||
|
||||
return self.runInteraction(
|
||||
"set_presence_list_accepted", update_presence_list_txn,
|
||||
"set_presence_list_accepted", update_presence_list_txn
|
||||
)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
@ -261,16 +275,16 @@ class PresenceStore(SQLBaseStore):
|
||||
desc="get_presence_list_accepted",
|
||||
)
|
||||
|
||||
defer.returnValue([
|
||||
"@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
|
||||
])
|
||||
defer.returnValue(["@%s:%s" % (u, self.hs.hostname) for u in user_localparts])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def del_presence_list(self, observer_localpart, observed_userid):
|
||||
yield self._simple_delete_one(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
keyvalues={
|
||||
"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid,
|
||||
},
|
||||
desc="del_presence_list",
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||
|
@ -41,8 +41,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
|
||||
defer.returnValue(
|
||||
ProfileInfo(
|
||||
avatar_url=profile['avatar_url'],
|
||||
display_name=profile['displayname'],
|
||||
avatar_url=profile['avatar_url'], display_name=profile['displayname']
|
||||
)
|
||||
)
|
||||
|
||||
@ -66,16 +65,14 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
return self._simple_select_one(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("displayname", "avatar_url",),
|
||||
retcols=("displayname", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_from_remote_profile_cache",
|
||||
)
|
||||
|
||||
def create_profile(self, user_localpart):
|
||||
return self._simple_insert(
|
||||
table="profiles",
|
||||
values={"user_id": user_localpart},
|
||||
desc="create_profile",
|
||||
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
|
||||
)
|
||||
|
||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||
@ -141,6 +138,7 @@ class ProfileStore(ProfileWorkerStore):
|
||||
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||
"""Get all users who haven't been checked since `last_checked`
|
||||
"""
|
||||
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||
sql = """
|
||||
SELECT user_id, displayname, avatar_url
|
||||
|
@ -57,11 +57,13 @@ def _load_rules(rawrules, enabled_map):
|
||||
return rules
|
||||
|
||||
|
||||
class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
ReceiptsWorkerStore,
|
||||
PusherWorkerStore,
|
||||
RoomMemberWorkerStore,
|
||||
SQLBaseStore):
|
||||
class PushRulesWorkerStore(
|
||||
ApplicationServiceWorkerStore,
|
||||
ReceiptsWorkerStore,
|
||||
PusherWorkerStore,
|
||||
RoomMemberWorkerStore,
|
||||
SQLBaseStore,
|
||||
):
|
||||
"""This is an abstract base class where subclasses must implement
|
||||
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||
"""
|
||||
@ -74,14 +76,16 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
|
||||
|
||||
push_rules_prefill, push_rules_id = self._get_cache_dict(
|
||||
db_conn, "push_rules_stream",
|
||||
db_conn,
|
||||
"push_rules_stream",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self.get_max_push_rules_stream_id(),
|
||||
)
|
||||
|
||||
self.push_rules_stream_cache = StreamChangeCache(
|
||||
"PushRulesStreamChangeCache", push_rules_id,
|
||||
"PushRulesStreamChangeCache",
|
||||
push_rules_id,
|
||||
prefilled_cache=push_rules_prefill,
|
||||
)
|
||||
|
||||
@ -98,19 +102,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
def get_push_rules_for_user(self, user_id):
|
||||
rows = yield self._simple_select_list(
|
||||
table="push_rules",
|
||||
keyvalues={
|
||||
"user_name": user_id,
|
||||
},
|
||||
keyvalues={"user_name": user_id},
|
||||
retcols=(
|
||||
"user_name", "rule_id", "priority_class", "priority",
|
||||
"conditions", "actions",
|
||||
"user_name",
|
||||
"rule_id",
|
||||
"priority_class",
|
||||
"priority",
|
||||
"conditions",
|
||||
"actions",
|
||||
),
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
)
|
||||
|
||||
rows.sort(
|
||||
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
|
||||
)
|
||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
||||
|
||||
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
|
||||
|
||||
@ -122,22 +126,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
def get_push_rules_enabled_for_user(self, user_id):
|
||||
results = yield self._simple_select_list(
|
||||
table="push_rules_enable",
|
||||
keyvalues={
|
||||
'user_name': user_id
|
||||
},
|
||||
retcols=(
|
||||
"user_name", "rule_id", "enabled",
|
||||
),
|
||||
keyvalues={'user_name': user_id},
|
||||
retcols=("user_name", "rule_id", "enabled"),
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
)
|
||||
defer.returnValue({
|
||||
r['rule_id']: False if r['enabled'] == 0 else True for r in results
|
||||
})
|
||||
defer.returnValue(
|
||||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||
)
|
||||
|
||||
def have_push_rules_changed_for_user(self, user_id, last_id):
|
||||
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
||||
return defer.succeed(False)
|
||||
else:
|
||||
|
||||
def have_push_rules_changed_txn(txn):
|
||||
sql = (
|
||||
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
||||
@ -146,20 +147,22 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
txn.execute(sql, (user_id, last_id))
|
||||
count, = txn.fetchone()
|
||||
return bool(count)
|
||||
|
||||
return self.runInteraction(
|
||||
"have_push_rules_changed", have_push_rules_changed_txn
|
||||
)
|
||||
|
||||
@cachedList(cached_method_name="get_push_rules_for_user",
|
||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def bulk_get_push_rules(self, user_ids):
|
||||
if not user_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
results = {
|
||||
user_id: []
|
||||
for user_id in user_ids
|
||||
}
|
||||
results = {user_id: [] for user_id in user_ids}
|
||||
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="push_rules",
|
||||
@ -169,9 +172,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
desc="bulk_get_push_rules",
|
||||
)
|
||||
|
||||
rows.sort(
|
||||
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
|
||||
)
|
||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
||||
|
||||
for row in rows:
|
||||
results.setdefault(row['user_name'], []).append(row)
|
||||
@ -179,16 +180,12 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
|
||||
|
||||
for user_id, rules in results.items():
|
||||
results[user_id] = _load_rules(
|
||||
rules, enabled_map_by_user.get(user_id, {})
|
||||
)
|
||||
results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def move_push_rule_from_room_to_room(
|
||||
self, new_room_id, user_id, rule,
|
||||
):
|
||||
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
|
||||
"""Move a single push rule from one room to another for a specific user.
|
||||
|
||||
Args:
|
||||
@ -219,7 +216,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def move_push_rules_from_room_to_room_for_user(
|
||||
self, old_room_id, new_room_id, user_id,
|
||||
self, old_room_id, new_room_id, user_id
|
||||
):
|
||||
"""Move all of the push rules from one room to another for a specific
|
||||
user.
|
||||
@ -236,11 +233,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
# delete them from the old room
|
||||
for rule in user_push_rules:
|
||||
conditions = rule.get("conditions", [])
|
||||
if any((c.get("key") == "room_id" and
|
||||
c.get("pattern") == old_room_id) for c in conditions):
|
||||
self.move_push_rule_from_room_to_room(
|
||||
new_room_id, user_id, rule,
|
||||
)
|
||||
if any(
|
||||
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
|
||||
for c in conditions
|
||||
):
|
||||
self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_push_rules_for_room(self, event, context):
|
||||
@ -259,8 +256,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
defer.returnValue(result)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
|
||||
cache_context, event=None):
|
||||
def _bulk_get_push_rules_for_room(
|
||||
self, room_id, state_group, current_state_ids, cache_context, event=None
|
||||
):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
# with a state_group of None are likely to be different.
|
||||
@ -273,7 +271,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
# sent a read receipt into the room.
|
||||
|
||||
users_in_room = yield self._get_joined_users_from_context(
|
||||
room_id, state_group, current_state_ids,
|
||||
room_id,
|
||||
state_group,
|
||||
current_state_ids,
|
||||
on_invalidate=cache_context.invalidate,
|
||||
event=event,
|
||||
)
|
||||
@ -282,7 +282,8 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
# up the `get_if_users_have_pushers` cache with AS entries that we
|
||||
# know don't have pushers, nor even read receipts.
|
||||
local_users_in_room = set(
|
||||
u for u in users_in_room
|
||||
u
|
||||
for u in users_in_room
|
||||
if self.hs.is_mine_id(u)
|
||||
and not self.get_if_app_services_interested_in_user(u)
|
||||
)
|
||||
@ -290,15 +291,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||
local_users_in_room,
|
||||
on_invalidate=cache_context.invalidate,
|
||||
local_users_in_room, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
user_ids = set(
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
)
|
||||
|
||||
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate,
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
# any users with pushers must be ours: they have pushers
|
||||
@ -307,29 +307,30 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
user_ids.add(uid)
|
||||
|
||||
rules_by_user = yield self.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=cache_context.invalidate,
|
||||
user_ids, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||
|
||||
defer.returnValue(rules_by_user)
|
||||
|
||||
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def bulk_get_push_rules_enabled(self, user_ids):
|
||||
if not user_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
results = {
|
||||
user_id: {}
|
||||
for user_id in user_ids
|
||||
}
|
||||
results = {user_id: {} for user_id in user_ids}
|
||||
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="push_rules_enable",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
retcols=("user_name", "rule_id", "enabled",),
|
||||
retcols=("user_name", "rule_id", "enabled"),
|
||||
desc="bulk_get_push_rules_enabled",
|
||||
)
|
||||
for row in rows:
|
||||
@ -341,8 +342,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
||||
class PushRuleStore(PushRulesWorkerStore):
|
||||
@defer.inlineCallbacks
|
||||
def add_push_rule(
|
||||
self, user_id, rule_id, priority_class, conditions, actions,
|
||||
before=None, after=None
|
||||
self,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions,
|
||||
actions,
|
||||
before=None,
|
||||
after=None,
|
||||
):
|
||||
conditions_json = json.dumps(conditions)
|
||||
actions_json = json.dumps(actions)
|
||||
@ -352,20 +359,41 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
yield self.runInteraction(
|
||||
"_add_push_rule_relative_txn",
|
||||
self._add_push_rule_relative_txn,
|
||||
stream_id, event_stream_ordering, user_id, rule_id, priority_class,
|
||||
conditions_json, actions_json, before, after,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
before,
|
||||
after,
|
||||
)
|
||||
else:
|
||||
yield self.runInteraction(
|
||||
"_add_push_rule_highest_priority_txn",
|
||||
self._add_push_rule_highest_priority_txn,
|
||||
stream_id, event_stream_ordering, user_id, rule_id, priority_class,
|
||||
conditions_json, actions_json,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
)
|
||||
|
||||
def _add_push_rule_relative_txn(
|
||||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
|
||||
conditions_json, actions_json, before, after
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
before,
|
||||
after,
|
||||
):
|
||||
# Lock the table since otherwise we'll have annoying races between the
|
||||
# SELECT here and the UPSERT below.
|
||||
@ -376,10 +404,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="push_rules",
|
||||
keyvalues={
|
||||
"user_name": user_id,
|
||||
"rule_id": relative_to_rule,
|
||||
},
|
||||
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
|
||||
retcols=["priority_class", "priority"],
|
||||
allow_none=True,
|
||||
)
|
||||
@ -416,13 +441,27 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
txn.execute(sql, (user_id, priority_class, new_rule_priority))
|
||||
|
||||
self._upsert_push_rule_txn(
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
|
||||
new_rule_priority, conditions_json, actions_json,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
new_rule_priority,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
)
|
||||
|
||||
def _add_push_rule_highest_priority_txn(
|
||||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
|
||||
conditions_json, actions_json
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
):
|
||||
# Lock the table since otherwise we'll have annoying races between the
|
||||
# SELECT here and the UPSERT below.
|
||||
@ -443,13 +482,28 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
self._upsert_push_rule_txn(
|
||||
txn,
|
||||
stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
|
||||
conditions_json, actions_json,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
new_prio,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
)
|
||||
|
||||
def _upsert_push_rule_txn(
|
||||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
|
||||
priority, conditions_json, actions_json, update_stream=True
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
priority,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
update_stream=True,
|
||||
):
|
||||
"""Specialised version of _simple_upsert_txn that picks a push_rule_id
|
||||
using the _push_rule_id_gen if it needs to insert the rule. It assumes
|
||||
@ -461,10 +515,10 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
" WHERE user_name = ? AND rule_id = ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (
|
||||
priority_class, priority, conditions_json, actions_json,
|
||||
user_id, rule_id,
|
||||
))
|
||||
txn.execute(
|
||||
sql,
|
||||
(priority_class, priority, conditions_json, actions_json, user_id, rule_id),
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
# We didn't update a row with the given rule_id so insert one
|
||||
@ -486,14 +540,18 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
if update_stream:
|
||||
self._insert_push_rules_update_txn(
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
op="ADD",
|
||||
data={
|
||||
"priority_class": priority_class,
|
||||
"priority": priority,
|
||||
"conditions": conditions_json,
|
||||
"actions": actions_json,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -507,22 +565,23 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
user_id (str): The matrix ID of the push rule owner
|
||||
rule_id (str): The rule_id of the rule to be deleted
|
||||
"""
|
||||
|
||||
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
|
||||
self._simple_delete_one_txn(
|
||||
txn,
|
||||
"push_rules",
|
||||
{'user_name': user_id, 'rule_id': rule_id},
|
||||
txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
|
||||
)
|
||||
|
||||
self._insert_push_rules_update_txn(
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id,
|
||||
op="DELETE"
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
yield self.runInteraction(
|
||||
"delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
|
||||
"delete_push_rule",
|
||||
delete_push_rule_txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -532,7 +591,11 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
yield self.runInteraction(
|
||||
"_set_push_rule_enabled_txn",
|
||||
self._set_push_rule_enabled_txn,
|
||||
stream_id, event_stream_ordering, user_id, rule_id, enabled
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
enabled,
|
||||
)
|
||||
|
||||
def _set_push_rule_enabled_txn(
|
||||
@ -548,8 +611,12 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
|
||||
self._insert_push_rules_update_txn(
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id,
|
||||
op="ENABLE" if enabled else "DISABLE"
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
op="ENABLE" if enabled else "DISABLE",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -563,9 +630,16 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
priority_class = -1
|
||||
priority = 1
|
||||
self._upsert_push_rule_txn(
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id,
|
||||
priority_class, priority, "[]", actions_json,
|
||||
update_stream=False
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
priority,
|
||||
"[]",
|
||||
actions_json,
|
||||
update_stream=False,
|
||||
)
|
||||
else:
|
||||
self._simple_update_one_txn(
|
||||
@ -576,15 +650,22 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
|
||||
self._insert_push_rules_update_txn(
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id,
|
||||
op="ACTIONS", data={"actions": actions_json}
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
op="ACTIONS",
|
||||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
yield self.runInteraction(
|
||||
"set_push_rule_actions", set_push_rule_actions_txn,
|
||||
stream_id, event_stream_ordering
|
||||
"set_push_rule_actions",
|
||||
set_push_rule_actions_txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
)
|
||||
|
||||
def _insert_push_rules_update_txn(
|
||||
@ -602,12 +683,8 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
self._simple_insert_txn(txn, "push_rules_stream", values=values)
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, (user_id,)
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
|
||||
)
|
||||
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
|
||||
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
|
||||
txn.call_after(
|
||||
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
|
||||
)
|
||||
@ -627,6 +704,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_push_rule_updates", get_all_push_rule_updates_txn
|
||||
)
|
||||
|
@ -47,7 +47,9 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Invalid JSON in data for pusher %d: %s, %s",
|
||||
r['id'], dataJson, e.args[0],
|
||||
r['id'],
|
||||
dataJson,
|
||||
e.args[0],
|
||||
)
|
||||
pass
|
||||
|
||||
@ -64,20 +66,16 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
defer.returnValue(ret is not None)
|
||||
|
||||
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
|
||||
return self.get_pushers_by({
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
})
|
||||
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
|
||||
|
||||
def get_pushers_by_user_id(self, user_id):
|
||||
return self.get_pushers_by({
|
||||
"user_name": user_id,
|
||||
})
|
||||
return self.get_pushers_by({"user_name": user_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_pushers_by(self, keyvalues):
|
||||
ret = yield self._simple_select_list(
|
||||
"pushers", keyvalues,
|
||||
"pushers",
|
||||
keyvalues,
|
||||
[
|
||||
"id",
|
||||
"user_name",
|
||||
@ -94,7 +92,8 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
"last_stream_ordering",
|
||||
"last_success",
|
||||
"failing_since",
|
||||
], desc="get_pushers_by"
|
||||
],
|
||||
desc="get_pushers_by",
|
||||
)
|
||||
defer.returnValue(self._decode_pushers_rows(ret))
|
||||
|
||||
@ -135,6 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
deleted = txn.fetchall()
|
||||
|
||||
return (updated, deleted)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_updated_pushers", get_all_updated_pushers_txn
|
||||
)
|
||||
@ -177,6 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
results.sort() # Sort so that they're ordered by stream id
|
||||
|
||||
return results
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
|
||||
)
|
||||
@ -186,15 +187,19 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
# This only exists for the cachedList decorator
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="get_if_user_has_pusher",
|
||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="get_if_user_has_pusher",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def get_if_users_have_pushers(self, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table='pushers',
|
||||
column='user_name',
|
||||
iterable=user_ids,
|
||||
retcols=['user_name'],
|
||||
desc='get_if_users_have_pushers'
|
||||
desc='get_if_users_have_pushers',
|
||||
)
|
||||
|
||||
result = {user_id: False for user_id in user_ids}
|
||||
@ -208,20 +213,27 @@ class PusherStore(PusherWorkerStore):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_pusher(self, user_id, access_token, kind, app_id,
|
||||
app_display_name, device_display_name,
|
||||
pushkey, pushkey_ts, lang, data, last_stream_ordering,
|
||||
profile_tag=""):
|
||||
def add_pusher(
|
||||
self,
|
||||
user_id,
|
||||
access_token,
|
||||
kind,
|
||||
app_id,
|
||||
app_display_name,
|
||||
device_display_name,
|
||||
pushkey,
|
||||
pushkey_ts,
|
||||
lang,
|
||||
data,
|
||||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
):
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so _simple_upsert will retry
|
||||
yield self._simple_upsert(
|
||||
table="pushers",
|
||||
keyvalues={
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_name": user_id,
|
||||
},
|
||||
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
|
||||
values={
|
||||
"access_token": access_token,
|
||||
"kind": kind,
|
||||
@ -247,7 +259,8 @@ class PusherStore(PusherWorkerStore):
|
||||
yield self.runInteraction(
|
||||
"add_pusher",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.get_if_user_has_pusher, (user_id,)
|
||||
self.get_if_user_has_pusher,
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -260,7 +273,7 @@ class PusherStore(PusherWorkerStore):
|
||||
self._simple_delete_one_txn(
|
||||
txn,
|
||||
"pushers",
|
||||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
|
||||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
|
||||
)
|
||||
|
||||
# it's possible for us to end up with duplicate rows for
|
||||
@ -278,13 +291,12 @@ class PusherStore(PusherWorkerStore):
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
yield self.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id,
|
||||
last_stream_ordering):
|
||||
def update_pusher_last_stream_ordering(
|
||||
self, app_id, pushkey, user_id, last_stream_ordering
|
||||
):
|
||||
yield self._simple_update_one(
|
||||
"pushers",
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
|
||||
@ -293,23 +305,21 @@ class PusherStore(PusherWorkerStore):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey,
|
||||
user_id,
|
||||
last_stream_ordering,
|
||||
last_success):
|
||||
def update_pusher_last_stream_ordering_and_success(
|
||||
self, app_id, pushkey, user_id, last_stream_ordering, last_success
|
||||
):
|
||||
yield self._simple_update_one(
|
||||
"pushers",
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
|
||||
{
|
||||
'last_stream_ordering': last_stream_ordering,
|
||||
'last_success': last_success
|
||||
'last_success': last_success,
|
||||
},
|
||||
desc="update_pusher_last_stream_ordering_and_success",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_failing_since(self, app_id, pushkey, user_id,
|
||||
failing_since):
|
||||
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
|
||||
yield self._simple_update_one(
|
||||
"pushers",
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
|
||||
@ -323,14 +333,14 @@ class PusherStore(PusherWorkerStore):
|
||||
"pusher_throttle",
|
||||
{"pusher": pusher_id},
|
||||
["room_id", "last_sent_ts", "throttle_ms"],
|
||||
desc="get_throttle_params_by_room"
|
||||
desc="get_throttle_params_by_room",
|
||||
)
|
||||
|
||||
params_by_room = {}
|
||||
for row in res:
|
||||
params_by_room[row["room_id"]] = {
|
||||
"last_sent_ts": row["last_sent_ts"],
|
||||
"throttle_ms": row["throttle_ms"]
|
||||
"throttle_ms": row["throttle_ms"],
|
||||
}
|
||||
|
||||
defer.returnValue(params_by_room)
|
||||
|
@ -64,10 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
def get_receipts_for_room(self, room_id, receipt_type):
|
||||
return self._simple_select_list(
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
},
|
||||
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
|
||||
retcols=("user_id", "event_id"),
|
||||
desc="get_receipts_for_room",
|
||||
)
|
||||
@ -79,7 +76,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id
|
||||
"user_id": user_id,
|
||||
},
|
||||
retcol="event_id",
|
||||
desc="get_own_receipt_for_user",
|
||||
@ -90,10 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
def get_receipts_for_user(self, user_id, receipt_type):
|
||||
rows = yield self._simple_select_list(
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"receipt_type": receipt_type,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
|
||||
retcols=("room_id", "event_id"),
|
||||
desc="get_receipts_for_user",
|
||||
)
|
||||
@ -114,16 +108,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
txn.execute(sql, (user_id,))
|
||||
return txn.fetchall()
|
||||
rows = yield self.runInteraction(
|
||||
"get_receipts_for_user_with_orderings", f
|
||||
|
||||
rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
|
||||
defer.returnValue(
|
||||
{
|
||||
row[0]: {
|
||||
"event_id": row[1],
|
||||
"topological_ordering": row[2],
|
||||
"stream_ordering": row[3],
|
||||
}
|
||||
for row in rows
|
||||
}
|
||||
)
|
||||
defer.returnValue({
|
||||
row[0]: {
|
||||
"event_id": row[1],
|
||||
"topological_ordering": row[2],
|
||||
"stream_ordering": row[3],
|
||||
} for row in rows
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
@ -177,6 +173,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""See get_linearized_receipts_for_room
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
if from_key:
|
||||
sql = (
|
||||
@ -184,48 +181,40 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
" room_id = ? AND stream_id > ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(room_id, from_key, to_key)
|
||||
)
|
||||
txn.execute(sql, (room_id, from_key, to_key))
|
||||
else:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id = ? AND stream_id <= ?"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(room_id, to_key)
|
||||
)
|
||||
txn.execute(sql, (room_id, to_key))
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.runInteraction(
|
||||
"get_linearized_receipts_for_room", f
|
||||
)
|
||||
rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
|
||||
|
||||
if not rows:
|
||||
defer.returnValue([])
|
||||
|
||||
content = {}
|
||||
for row in rows:
|
||||
content.setdefault(
|
||||
row["event_id"], {}
|
||||
).setdefault(
|
||||
row["receipt_type"], {}
|
||||
)[row["user_id"]] = json.loads(row["data"])
|
||||
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
|
||||
row["user_id"]
|
||||
] = json.loads(row["data"])
|
||||
|
||||
defer.returnValue([{
|
||||
"type": "m.receipt",
|
||||
"room_id": room_id,
|
||||
"content": content,
|
||||
}])
|
||||
defer.returnValue(
|
||||
[{"type": "m.receipt", "room_id": room_id, "content": content}]
|
||||
)
|
||||
|
||||
@cachedList(cached_method_name="_get_linearized_receipts_for_room",
|
||||
list_name="room_ids", num_args=3, inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="_get_linearized_receipts_for_room",
|
||||
list_name="room_ids",
|
||||
num_args=3,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
if not room_ids:
|
||||
defer.returnValue({})
|
||||
@ -235,9 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
|
||||
) % (
|
||||
",".join(["?"] * len(room_ids))
|
||||
)
|
||||
) % (",".join(["?"] * len(room_ids)))
|
||||
args = list(room_ids)
|
||||
args.extend([from_key, to_key])
|
||||
|
||||
@ -246,9 +233,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
" room_id IN (%s) AND stream_id <= ?"
|
||||
) % (
|
||||
",".join(["?"] * len(room_ids))
|
||||
)
|
||||
) % (",".join(["?"] * len(room_ids)))
|
||||
|
||||
args = list(room_ids)
|
||||
args.append(to_key)
|
||||
@ -257,19 +242,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
txn_results = yield self.runInteraction(
|
||||
"_get_linearized_receipts_for_rooms", f
|
||||
)
|
||||
txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
|
||||
|
||||
results = {}
|
||||
for row in txn_results:
|
||||
# We want a single event per room, since we want to batch the
|
||||
# receipts by room, event and type.
|
||||
room_event = results.setdefault(row["room_id"], {
|
||||
"type": "m.receipt",
|
||||
"room_id": row["room_id"],
|
||||
"content": {},
|
||||
})
|
||||
room_event = results.setdefault(
|
||||
row["room_id"],
|
||||
{"type": "m.receipt", "room_id": row["room_id"], "content": {}},
|
||||
)
|
||||
|
||||
# The content is of the form:
|
||||
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
|
||||
@ -301,21 +283,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
args.append(limit)
|
||||
txn.execute(sql, args)
|
||||
|
||||
return (
|
||||
r[0:5] + (json.loads(r[5]), ) for r in txn
|
||||
)
|
||||
return (r[0:5] + (json.loads(r[5]),) for r in txn)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_updated_receipts", get_all_updated_receipts_txn
|
||||
)
|
||||
|
||||
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
||||
user_id):
|
||||
def _invalidate_get_users_with_receipts_in_room(
|
||||
self, room_id, receipt_type, user_id
|
||||
):
|
||||
if receipt_type != "m.read":
|
||||
return
|
||||
|
||||
# Returns either an ObservableDeferred or the raw result
|
||||
res = self.get_users_with_read_receipts_in_room.cache.get(
|
||||
room_id, None, update_metrics=False,
|
||||
room_id, None, update_metrics=False
|
||||
)
|
||||
|
||||
# first handle the Deferred case
|
||||
@ -346,8 +328,9 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_id, data, stream_id):
|
||||
def insert_linearized_receipt_txn(
|
||||
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
|
||||
):
|
||||
"""Inserts a read-receipt into the database if it's newer than the current RR
|
||||
|
||||
Returns: int|None
|
||||
@ -360,7 +343,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
table="events",
|
||||
retcols=["stream_ordering", "received_ts"],
|
||||
keyvalues={"event_id": event_id},
|
||||
allow_none=True
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||
@ -381,31 +364,31 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
logger.debug(
|
||||
"Ignoring new receipt for %s in favour of existing "
|
||||
"one for later event %s",
|
||||
event_id, eid,
|
||||
event_id,
|
||||
eid,
|
||||
)
|
||||
return None
|
||||
|
||||
txn.call_after(
|
||||
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
|
||||
)
|
||||
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
|
||||
txn.call_after(
|
||||
self._invalidate_get_users_with_receipts_in_room,
|
||||
room_id, receipt_type, user_id,
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
)
|
||||
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
|
||||
# FIXME: This shouldn't invalidate the whole cache
|
||||
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
txn.call_after(
|
||||
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self._receipts_stream_cache.entity_has_changed,
|
||||
room_id, stream_id
|
||||
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self.get_last_receipt_event_id_for_user.invalidate,
|
||||
(user_id, room_id, receipt_type)
|
||||
(user_id, room_id, receipt_type),
|
||||
)
|
||||
|
||||
self._simple_delete_txn(
|
||||
@ -415,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
@ -428,15 +411,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"user_id": user_id,
|
||||
"event_id": event_id,
|
||||
"data": json.dumps(data),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if receipt_type == "m.read" and stream_ordering is not None:
|
||||
self._remove_old_push_actions_before_txn(
|
||||
txn,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
stream_ordering=stream_ordering,
|
||||
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
|
||||
)
|
||||
|
||||
return rx_ts
|
||||
@ -479,7 +459,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
event_ts = yield self.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
room_id, receipt_type, user_id, linearized_event_id,
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
linearized_event_id,
|
||||
data,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
@ -490,39 +473,43 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
now = self._clock.time_msec()
|
||||
logger.debug(
|
||||
"RR for event %s in %s (%i ms old)",
|
||||
linearized_event_id, room_id, now - event_ts,
|
||||
linearized_event_id,
|
||||
room_id,
|
||||
now - event_ts,
|
||||
)
|
||||
|
||||
yield self.insert_graph_receipt(
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
)
|
||||
yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
|
||||
|
||||
max_persisted_id = self._receipts_id_gen.get_current_token()
|
||||
|
||||
defer.returnValue((stream_id, max_persisted_id))
|
||||
|
||||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
|
||||
data):
|
||||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
return self.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self.insert_graph_receipt_txn,
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
event_ids,
|
||||
data,
|
||||
)
|
||||
|
||||
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_ids, data):
|
||||
txn.call_after(
|
||||
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
|
||||
)
|
||||
def insert_graph_receipt_txn(
|
||||
self, txn, room_id, receipt_type, user_id, event_ids, data
|
||||
):
|
||||
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
|
||||
txn.call_after(
|
||||
self._invalidate_get_users_with_receipts_in_room,
|
||||
room_id, receipt_type, user_id,
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
)
|
||||
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
|
||||
# FIXME: This shouldn't invalidate the whole cache
|
||||
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
txn.call_after(
|
||||
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
|
||||
)
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
@ -531,7 +518,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
@ -542,5 +529,5 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"user_id": user_id,
|
||||
"event_ids": json.dumps(event_ids),
|
||||
"data": json.dumps(data),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
@ -37,13 +37,15 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
},
|
||||
keyvalues={"name": user_id},
|
||||
retcols=[
|
||||
"name", "password_hash", "is_guest",
|
||||
"consent_version", "consent_server_notice_sent",
|
||||
"appservice_id", "creation_ts",
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"consent_version",
|
||||
"consent_server_notice_sent",
|
||||
"appservice_id",
|
||||
"creation_ts",
|
||||
],
|
||||
allow_none=True,
|
||||
desc="get_user_by_id",
|
||||
@ -81,9 +83,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_user_by_access_token",
|
||||
self._query_for_auth,
|
||||
token
|
||||
"get_user_by_access_token", self._query_for_auth, token
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -143,10 +143,10 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
"""Gets users that match user_id case insensitively.
|
||||
Returns a mapping of user_id -> password_hash.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT name, password_hash FROM users"
|
||||
" WHERE lower(name) = lower(?)"
|
||||
"SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
|
||||
)
|
||||
txn.execute(sql, (user_id,))
|
||||
return dict(txn)
|
||||
@ -156,6 +156,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def count_all_users(self):
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
def _count_users(txn):
|
||||
txn.execute("SELECT COUNT(*) AS users FROM users")
|
||||
rows = self.cursor_to_dict(txn)
|
||||
@ -173,6 +174,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
3) bridged users
|
||||
who registered on the homeserver in the past 24 hours
|
||||
"""
|
||||
|
||||
def _count_daily_user_type(txn):
|
||||
yesterday = int(self._clock.time()) - (60 * 60 * 24)
|
||||
|
||||
@ -193,15 +195,18 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
for row in txn:
|
||||
results[row[0]] = row[1]
|
||||
return results
|
||||
|
||||
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_nonbridged_users(self):
|
||||
def _count_users(txn):
|
||||
txn.execute("""
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT COALESCE(COUNT(*), 0) FROM users
|
||||
WHERE appservice_id IS NULL
|
||||
""")
|
||||
"""
|
||||
)
|
||||
count, = txn.fetchone()
|
||||
return count
|
||||
|
||||
@ -220,6 +225,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
avoid the case of ID 10000000 being pre-allocated, so us wasting the
|
||||
first (and shortest) many generated user IDs.
|
||||
"""
|
||||
|
||||
def _find_next_generated_user_id(txn):
|
||||
txn.execute("SELECT name FROM users")
|
||||
|
||||
@ -227,7 +233,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
|
||||
found = set()
|
||||
|
||||
for user_id, in txn:
|
||||
for (user_id,) in txn:
|
||||
match = regex.search(user_id)
|
||||
if match:
|
||||
found.add(int(match.group(1)))
|
||||
@ -235,20 +241,22 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
if i not in found:
|
||||
return i
|
||||
|
||||
defer.returnValue((yield self.runInteraction(
|
||||
"find_next_generated_user_id",
|
||||
_find_next_generated_user_id
|
||||
)))
|
||||
defer.returnValue(
|
||||
(
|
||||
yield self.runInteraction(
|
||||
"find_next_generated_user_id", _find_next_generated_user_id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_3pid_guest_access_token(self, medium, address):
|
||||
ret = yield self._simple_select_one(
|
||||
"threepid_guest_access_tokens",
|
||||
{
|
||||
"medium": medium,
|
||||
"address": address
|
||||
},
|
||||
["guest_access_token"], True, 'get_3pid_guest_access_token'
|
||||
{"medium": medium, "address": address},
|
||||
["guest_access_token"],
|
||||
True,
|
||||
'get_3pid_guest_access_token',
|
||||
)
|
||||
if ret:
|
||||
defer.returnValue(ret["guest_access_token"])
|
||||
@ -266,8 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
Deferred[str|None]: user id or None if no user id/threepid mapping exists
|
||||
"""
|
||||
user_id = yield self.runInteraction(
|
||||
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
|
||||
medium, address
|
||||
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
|
||||
)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@ -285,11 +292,9 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
ret = self._simple_select_one_txn(
|
||||
txn,
|
||||
"user_threepids",
|
||||
{
|
||||
"medium": medium,
|
||||
"address": address
|
||||
},
|
||||
['user_id'], True
|
||||
{"medium": medium, "address": address},
|
||||
['user_id'],
|
||||
True,
|
||||
)
|
||||
if ret:
|
||||
return ret['user_id']
|
||||
@ -297,41 +302,33 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
yield self._simple_upsert("user_threepids", {
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
}, {
|
||||
"user_id": user_id,
|
||||
"validated_at": validated_at,
|
||||
"added_at": added_at,
|
||||
})
|
||||
yield self._simple_upsert(
|
||||
"user_threepids",
|
||||
{"medium": medium, "address": address},
|
||||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_get_threepids(self, user_id):
|
||||
ret = yield self._simple_select_list(
|
||||
"user_threepids", {
|
||||
"user_id": user_id
|
||||
},
|
||||
"user_threepids",
|
||||
{"user_id": user_id},
|
||||
['medium', 'address', 'validated_at', 'added_at'],
|
||||
'user_get_threepids'
|
||||
'user_get_threepids',
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def user_delete_threepid(self, user_id, medium, address):
|
||||
return self._simple_delete(
|
||||
"user_threepids",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "medium": medium, "address": address},
|
||||
desc="user_delete_threepids",
|
||||
)
|
||||
|
||||
|
||||
class RegistrationStore(RegistrationWorkerStore,
|
||||
background_updates.BackgroundUpdateStore):
|
||||
|
||||
class RegistrationStore(
|
||||
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
|
||||
):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||
|
||||
@ -372,18 +369,22 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
|
||||
yield self._simple_insert(
|
||||
"access_tokens",
|
||||
{
|
||||
"id": next_id,
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"device_id": device_id,
|
||||
},
|
||||
{"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
|
||||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
def register(self, user_id, token=None, password_hash=None,
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
create_profile_with_displayname=None, admin=False, user_type=None):
|
||||
def register(
|
||||
self,
|
||||
user_id,
|
||||
token=None,
|
||||
password_hash=None,
|
||||
was_guest=False,
|
||||
make_guest=False,
|
||||
appservice_id=None,
|
||||
create_profile_with_displayname=None,
|
||||
admin=False,
|
||||
user_type=None,
|
||||
):
|
||||
"""Attempts to register an account.
|
||||
|
||||
Args:
|
||||
@ -417,7 +418,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
appservice_id,
|
||||
create_profile_with_displayname,
|
||||
admin,
|
||||
user_type
|
||||
user_type,
|
||||
)
|
||||
|
||||
def _register(
|
||||
@ -447,10 +448,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
self._simple_select_one_txn(
|
||||
txn,
|
||||
"users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
"is_guest": 1,
|
||||
},
|
||||
keyvalues={"name": user_id, "is_guest": 1},
|
||||
retcols=("name",),
|
||||
allow_none=False,
|
||||
)
|
||||
@ -458,10 +456,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
"users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
"is_guest": 1,
|
||||
},
|
||||
keyvalues={"name": user_id, "is_guest": 1},
|
||||
updatevalues={
|
||||
"password_hash": password_hash,
|
||||
"upgrade_ts": now,
|
||||
@ -469,7 +464,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
"appservice_id": appservice_id,
|
||||
"admin": 1 if admin else 0,
|
||||
"user_type": user_type,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
self._simple_insert_txn(
|
||||
@ -483,20 +478,17 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
"appservice_id": appservice_id,
|
||||
"admin": 1 if admin else 0,
|
||||
"user_type": user_type,
|
||||
}
|
||||
},
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise StoreError(
|
||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||
)
|
||||
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
|
||||
|
||||
if token:
|
||||
# it's possible for this to get a conflict, but only for a single user
|
||||
# since tokens are namespaced based on their user ID
|
||||
txn.execute(
|
||||
"INSERT INTO access_tokens(id, user_id, token)"
|
||||
" VALUES (?,?,?)",
|
||||
(next_id, user_id, token,)
|
||||
"INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
|
||||
(next_id, user_id, token),
|
||||
)
|
||||
|
||||
if create_profile_with_displayname:
|
||||
@ -507,12 +499,10 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
# while everything else uses the full mxid.
|
||||
txn.execute(
|
||||
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
||||
(user_id_obj.localpart, create_profile_with_displayname)
|
||||
(user_id_obj.localpart, create_profile_with_displayname),
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
def user_set_password_hash(self, user_id, password_hash):
|
||||
@ -521,22 +511,14 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
removes most of the entries subsequently anyway so it would be
|
||||
pointless. Use flush_user separately.
|
||||
"""
|
||||
|
||||
def user_set_password_hash_txn(txn):
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
'users', {
|
||||
'name': user_id
|
||||
},
|
||||
{
|
||||
'password_hash': password_hash
|
||||
}
|
||||
txn, 'users', {'name': user_id}, {'password_hash': password_hash}
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
)
|
||||
return self.runInteraction(
|
||||
"user_set_password_hash", user_set_password_hash_txn
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
|
||||
|
||||
def user_set_consent_version(self, user_id, consent_version):
|
||||
"""Updates the user table to record privacy policy consent
|
||||
@ -549,16 +531,16 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
Raises:
|
||||
StoreError(404) if user not found
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table='users',
|
||||
keyvalues={'name': user_id, },
|
||||
updatevalues={'consent_version': consent_version, },
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
keyvalues={'name': user_id},
|
||||
updatevalues={'consent_version': consent_version},
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
return self.runInteraction("user_set_consent_version", f)
|
||||
|
||||
def user_set_consent_server_notice_sent(self, user_id, consent_version):
|
||||
@ -573,20 +555,19 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
Raises:
|
||||
StoreError(404) if user not found
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table='users',
|
||||
keyvalues={'name': user_id, },
|
||||
updatevalues={'consent_server_notice_sent': consent_version, },
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
keyvalues={'name': user_id},
|
||||
updatevalues={'consent_server_notice_sent': consent_version},
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
return self.runInteraction("user_set_consent_server_notice_sent", f)
|
||||
|
||||
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||
device_id=None):
|
||||
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
|
||||
"""
|
||||
Invalidate access tokens belonging to a user
|
||||
|
||||
@ -601,10 +582,9 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
defer.Deferred[list[str, int, str|None, int]]: a list of
|
||||
(token, token id, device id) for each of the deleted tokens
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
}
|
||||
keyvalues = {"user_id": user_id}
|
||||
if device_id is not None:
|
||||
keyvalues["device_id"] = device_id
|
||||
|
||||
@ -616,8 +596,9 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
values.append(except_token_id)
|
||||
|
||||
txn.execute(
|
||||
"SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
"SELECT token, id, device_id FROM access_tokens WHERE %s"
|
||||
% where_clause,
|
||||
values,
|
||||
)
|
||||
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
|
||||
|
||||
@ -626,25 +607,16 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
txn, self.get_user_by_access_token, (token,)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
|
||||
|
||||
return tokens_and_devices
|
||||
|
||||
return self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
)
|
||||
return self.runInteraction("user_delete_access_tokens", f)
|
||||
|
||||
def delete_access_token(self, access_token):
|
||||
def f(txn):
|
||||
self._simple_delete_one_txn(
|
||||
txn,
|
||||
table="access_tokens",
|
||||
keyvalues={
|
||||
"token": access_token
|
||||
},
|
||||
txn, table="access_tokens", keyvalues={"token": access_token}
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
@ -667,7 +639,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def save_or_get_3pid_guest_access_token(
|
||||
self, medium, address, access_token, inviter_user_id
|
||||
self, medium, address, access_token, inviter_user_id
|
||||
):
|
||||
"""
|
||||
Gets the 3pid's guest access token if exists, else saves access_token.
|
||||
@ -683,12 +655,13 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
deferred str: Whichever access token is persisted at the end
|
||||
of this function call.
|
||||
"""
|
||||
|
||||
def insert(txn):
|
||||
txn.execute(
|
||||
"INSERT INTO threepid_guest_access_tokens "
|
||||
"(medium, address, guest_access_token, first_inviter) "
|
||||
"VALUES (?, ?, ?, ?)",
|
||||
(medium, address, access_token, inviter_user_id)
|
||||
(medium, address, access_token, inviter_user_id),
|
||||
)
|
||||
|
||||
try:
|
||||
@ -705,9 +678,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
"""
|
||||
return self._simple_insert(
|
||||
"users_pending_deactivation",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
},
|
||||
values={"user_id": user_id},
|
||||
desc="add_user_pending_deactivation",
|
||||
)
|
||||
|
||||
@ -720,9 +691,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
# the table, so somehow duplicate entries have ended up in it.
|
||||
return self._simple_delete(
|
||||
"users_pending_deactivation",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="del_user_pending_deactivation",
|
||||
)
|
||||
|
||||
|
@ -36,9 +36,7 @@ class RejectionsStore(SQLBaseStore):
|
||||
return self._simple_select_one_onecol(
|
||||
table="rejections",
|
||||
retcol="reason",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
},
|
||||
keyvalues={"event_id": event_id},
|
||||
allow_none=True,
|
||||
desc="get_rejection_reason",
|
||||
)
|
||||
|
@ -30,13 +30,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OpsLevel = collections.namedtuple(
|
||||
"OpsLevel",
|
||||
("ban_level", "kick_level", "redact_level",)
|
||||
"OpsLevel", ("ban_level", "kick_level", "redact_level")
|
||||
)
|
||||
|
||||
RatelimitOverride = collections.namedtuple(
|
||||
"RatelimitOverride",
|
||||
("messages_per_second", "burst_count",)
|
||||
"RatelimitOverride", ("messages_per_second", "burst_count")
|
||||
)
|
||||
|
||||
|
||||
@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
def get_public_room_ids(self):
|
||||
return self._simple_select_onecol(
|
||||
table="rooms",
|
||||
keyvalues={
|
||||
"is_public": True,
|
||||
},
|
||||
keyvalues={"is_public": True},
|
||||
retcol="room_id",
|
||||
desc="get_public_room_ids",
|
||||
)
|
||||
@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
return self.runInteraction(
|
||||
"get_public_room_ids_at_stream_id",
|
||||
self.get_public_room_ids_at_stream_id_txn,
|
||||
stream_id, network_tuple=network_tuple
|
||||
stream_id,
|
||||
network_tuple=network_tuple,
|
||||
)
|
||||
|
||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
|
||||
network_tuple):
|
||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
|
||||
return {
|
||||
rm
|
||||
for rm, vis in self.get_published_at_stream_id_txn(
|
||||
@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
if network_tuple:
|
||||
# We want to get from a particular list. No aggregation required.
|
||||
|
||||
sql = ("""
|
||||
sql = """
|
||||
SELECT room_id, visibility FROM public_room_list_stream
|
||||
INNER JOIN (
|
||||
SELECT room_id, max(stream_id) AS stream_id
|
||||
@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
WHERE stream_id <= ? %s
|
||||
GROUP BY room_id
|
||||
) grouped USING (room_id, stream_id)
|
||||
""")
|
||||
"""
|
||||
|
||||
if network_tuple.appservice_id is not None:
|
||||
txn.execute(
|
||||
sql % ("AND appservice_id = ? AND network_id = ?",),
|
||||
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
|
||||
(stream_id, network_tuple.appservice_id, network_tuple.network_id),
|
||||
)
|
||||
else:
|
||||
txn.execute(
|
||||
sql % ("AND appservice_id IS NULL",),
|
||||
(stream_id,)
|
||||
)
|
||||
txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
|
||||
return dict(txn)
|
||||
else:
|
||||
# We want to get from all lists, so we need to aggregate the results
|
||||
|
||||
logger.info("Executing full list")
|
||||
|
||||
sql = ("""
|
||||
sql = """
|
||||
SELECT room_id, visibility
|
||||
FROM public_room_list_stream
|
||||
INNER JOIN (
|
||||
@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
WHERE stream_id <= ?
|
||||
GROUP BY room_id, appservice_id, network_id
|
||||
) grouped USING (room_id, stream_id)
|
||||
""")
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(stream_id,)
|
||||
)
|
||||
txn.execute(sql, (stream_id,))
|
||||
|
||||
results = {}
|
||||
# A room is visible if its visible on any list.
|
||||
@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return results
|
||||
|
||||
def get_public_room_changes(self, prev_stream_id, new_stream_id,
|
||||
network_tuple):
|
||||
def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
|
||||
def get_public_room_changes_txn(txn):
|
||||
then_rooms = self.get_public_room_ids_at_stream_id_txn(
|
||||
txn, prev_stream_id, network_tuple
|
||||
@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
txn, new_stream_id, network_tuple
|
||||
)
|
||||
|
||||
now_rooms_visible = set(
|
||||
rm for rm, vis in now_rooms_dict.items() if vis
|
||||
)
|
||||
now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
|
||||
now_rooms_not_visible = set(
|
||||
rm for rm, vis in now_rooms_dict.items() if not vis
|
||||
)
|
||||
@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
def is_room_blocked(self, room_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="blocked_rooms",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="1",
|
||||
allow_none=True,
|
||||
desc="is_room_blocked",
|
||||
@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
if row:
|
||||
defer.returnValue(RatelimitOverride(
|
||||
messages_per_second=row["messages_per_second"],
|
||||
burst_count=row["burst_count"],
|
||||
))
|
||||
defer.returnValue(
|
||||
RatelimitOverride(
|
||||
messages_per_second=row["messages_per_second"],
|
||||
burst_count=row["burst_count"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
||||
|
||||
class RoomStore(RoomWorkerStore, SearchStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
||||
"""Stores a room.
|
||||
@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
StoreError if the room could not be stored.
|
||||
"""
|
||||
try:
|
||||
|
||||
def store_room_txn(txn, next_id):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"stream_id": next_id,
|
||||
"room_id": room_id,
|
||||
"visibility": is_public,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.runInteraction(
|
||||
"store_room_txn",
|
||||
store_room_txn, next_id,
|
||||
)
|
||||
yield self.runInteraction("store_room_txn", store_room_txn, next_id)
|
||||
except Exception as e:
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"visibility": is_public,
|
||||
"appservice_id": None,
|
||||
"network_id": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.runInteraction(
|
||||
"set_room_is_public",
|
||||
set_room_is_public_txn, next_id,
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
|
||||
is_public):
|
||||
def set_room_is_public_appservice(
|
||||
self, room_id, appservice_id, network_id, is_public
|
||||
):
|
||||
"""Edit the appservice/network specific public room list.
|
||||
|
||||
Each appservice can have a number of published room lists associated
|
||||
@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
is_public (bool): Whether to publish or unpublish the room from the
|
||||
list.
|
||||
"""
|
||||
|
||||
def set_room_is_public_appservice_txn(txn, next_id):
|
||||
if is_public:
|
||||
try:
|
||||
@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
values={
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
"room_id": room_id
|
||||
"room_id": room_id,
|
||||
},
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
keyvalues={
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
"room_id": room_id
|
||||
"room_id": room_id,
|
||||
},
|
||||
)
|
||||
|
||||
@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"visibility": is_public,
|
||||
"appservice_id": appservice_id,
|
||||
"network_id": network_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn, next_id,
|
||||
set_room_is_public_appservice_txn,
|
||||
next_id,
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
row = txn.fetchone()
|
||||
return row[0] or 0
|
||||
|
||||
return self.runInteraction(
|
||||
"get_rooms", f
|
||||
)
|
||||
return self.runInteraction("get_rooms", f)
|
||||
|
||||
def _store_room_topic_txn(self, txn, event):
|
||||
if hasattr(event, "content") and "topic" in event.content:
|
||||
@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
)
|
||||
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.topic", event.content["topic"],
|
||||
txn, event, "content.topic", event.content["topic"]
|
||||
)
|
||||
|
||||
def _store_room_name_txn(self, txn, event):
|
||||
@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"name": event.content["name"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.name", event.content["name"],
|
||||
txn, event, "content.name", event.content["name"]
|
||||
)
|
||||
|
||||
def _store_room_message_txn(self, txn, event):
|
||||
if hasattr(event, "content") and "body" in event.content:
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.body", event.content["body"],
|
||||
txn, event, "content.body", event.content["body"]
|
||||
)
|
||||
|
||||
def _store_history_visibility_txn(self, txn, event):
|
||||
@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
" (event_id, room_id, %(key)s)"
|
||||
" VALUES (?, ?, ?)" % {"key": key}
|
||||
)
|
||||
txn.execute(sql, (
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.content[key]
|
||||
))
|
||||
txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
|
||||
|
||||
def add_event_report(self, room_id, event_id, user_id, reason, content,
|
||||
received_ts):
|
||||
def add_event_report(
|
||||
self, room_id, event_id, user_id, reason, content, received_ts
|
||||
):
|
||||
next_id = self._event_reports_id_gen.get_next()
|
||||
return self._simple_insert(
|
||||
table="event_reports",
|
||||
@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"reason": reason,
|
||||
"content": json.dumps(content),
|
||||
},
|
||||
desc="add_event_report"
|
||||
desc="add_event_report",
|
||||
)
|
||||
|
||||
def get_current_public_room_stream_id(self):
|
||||
@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
|
||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||
def get_all_new_public_rooms(txn):
|
||||
sql = ("""
|
||||
sql = """
|
||||
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id > ? AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
""")
|
||||
"""
|
||||
|
||||
txn.execute(sql, (prev_id, current_id, limit,))
|
||||
txn.execute(sql, (prev_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
if prev_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_new_public_rooms", get_all_new_public_rooms
|
||||
)
|
||||
return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def block_room(self, room_id, user_id):
|
||||
@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"""
|
||||
yield self._simple_upsert(
|
||||
table="blocked_rooms",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"room_id": room_id},
|
||||
values={},
|
||||
insertion_values={
|
||||
"user_id": user_id,
|
||||
},
|
||||
insertion_values={"user_id": user_id},
|
||||
desc="block_room",
|
||||
)
|
||||
yield self.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked, (room_id,),
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
def get_media_mxcs_in_room(self, room_id):
|
||||
@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
The local and remote media as a lists of tuples where the key is
|
||||
the hostname and the value is the media ID.
|
||||
"""
|
||||
|
||||
def _get_media_mxcs_in_room_txn(txn):
|
||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||
local_media_mxcs = []
|
||||
@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
|
||||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
|
||||
|
||||
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
|
||||
"""For a room loops through all events with media and quarantines
|
||||
the associated media
|
||||
"""
|
||||
|
||||
def _quarantine_media_in_room_txn(txn):
|
||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||
total_media_quarantined = 0
|
||||
|
||||
# Now update all the tables to set the quarantined_by flag
|
||||
|
||||
txn.executemany("""
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
""", ((quarantined_by, media_id) for media_id in local_mxcs))
|
||||
""",
|
||||
((quarantined_by, media_id) for media_id in local_mxcs),
|
||||
)
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
(
|
||||
(quarantined_by, origin, media_id)
|
||||
for origin, media_id in remote_mxcs
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
total_media_quarantined += len(local_mxcs)
|
||||
@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
return total_media_quarantined
|
||||
|
||||
return self.runInteraction(
|
||||
"quarantine_media_in_room",
|
||||
_quarantine_media_in_room_txn,
|
||||
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
||||
)
|
||||
|
||||
def _get_media_mxcs_in_room_txn(self, txn, room_id):
|
||||
|
@ -30,10 +30,10 @@ from .background_updates import BackgroundUpdateStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchEntry = namedtuple('SearchEntry', [
|
||||
'key', 'value', 'event_id', 'room_id', 'stream_ordering',
|
||||
'origin_server_ts',
|
||||
])
|
||||
SearchEntry = namedtuple(
|
||||
'SearchEntry',
|
||||
['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
|
||||
)
|
||||
|
||||
|
||||
class SearchStore(BackgroundUpdateStore):
|
||||
@ -53,8 +53,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
||||
)
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||
self._background_reindex_search_order
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
|
||||
)
|
||||
|
||||
# we used to have a background update to turn the GIN index into a
|
||||
@ -62,13 +61,10 @@ class SearchStore(BackgroundUpdateStore):
|
||||
# a GIN index. However, it's possible that some people might still have
|
||||
# the background update queued, so we register a handler to clear the
|
||||
# background update.
|
||||
self.register_noop_background_update(
|
||||
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
|
||||
)
|
||||
self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
|
||||
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
|
||||
self._background_reindex_gin_search
|
||||
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -138,21 +134,23 @@ class SearchStore(BackgroundUpdateStore):
|
||||
# then skip over it
|
||||
continue
|
||||
|
||||
event_search_rows.append(SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event_id,
|
||||
room_id=room_id,
|
||||
stream_ordering=stream_ordering,
|
||||
origin_server_ts=origin_server_ts,
|
||||
))
|
||||
event_search_rows.append(
|
||||
SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event_id,
|
||||
room_id=room_id,
|
||||
stream_ordering=stream_ordering,
|
||||
origin_server_ts=origin_server_ts,
|
||||
)
|
||||
)
|
||||
|
||||
self.store_search_entries_txn(txn, event_search_rows)
|
||||
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
"max_stream_id_exclusive": min_stream_id,
|
||||
"rows_inserted": rows_inserted + len(event_search_rows)
|
||||
"rows_inserted": rows_inserted + len(event_search_rows),
|
||||
}
|
||||
|
||||
self._background_update_progress_txn(
|
||||
@ -191,6 +189,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
# doesn't support CREATE INDEX IF EXISTS so we just catch the
|
||||
# exception and ignore it.
|
||||
import psycopg2
|
||||
|
||||
try:
|
||||
c.execute(
|
||||
"CREATE INDEX CONCURRENTLY event_search_fts_idx"
|
||||
@ -198,14 +197,11 @@ class SearchStore(BackgroundUpdateStore):
|
||||
)
|
||||
except psycopg2.ProgrammingError as e:
|
||||
logger.warn(
|
||||
"Ignoring error %r when trying to switch from GIST to GIN",
|
||||
e
|
||||
"Ignoring error %r when trying to switch from GIST to GIN", e
|
||||
)
|
||||
|
||||
# we should now be able to delete the GIST index.
|
||||
c.execute(
|
||||
"DROP INDEX IF EXISTS event_search_fts_idx_gist"
|
||||
)
|
||||
c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
|
||||
finally:
|
||||
conn.set_session(autocommit=False)
|
||||
|
||||
@ -223,6 +219,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
have_added_index = progress['have_added_indexes']
|
||||
|
||||
if not have_added_index:
|
||||
|
||||
def create_index(conn):
|
||||
conn.rollback()
|
||||
conn.set_session(autocommit=True)
|
||||
@ -248,7 +245,8 @@ class SearchStore(BackgroundUpdateStore):
|
||||
yield self.runInteraction(
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||
self._background_update_progress_txn,
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||
pg,
|
||||
)
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
@ -302,14 +300,16 @@ class SearchStore(BackgroundUpdateStore):
|
||||
"""
|
||||
self.store_search_entries_txn(
|
||||
txn,
|
||||
(SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event.event_id,
|
||||
room_id=event.room_id,
|
||||
stream_ordering=event.internal_metadata.stream_ordering,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
),),
|
||||
(
|
||||
SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event.event_id,
|
||||
room_id=event.room_id,
|
||||
stream_ordering=event.internal_metadata.stream_ordering,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def store_search_entries_txn(self, txn, entries):
|
||||
@ -329,10 +329,17 @@ class SearchStore(BackgroundUpdateStore):
|
||||
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
|
||||
)
|
||||
|
||||
args = ((
|
||||
entry.event_id, entry.room_id, entry.key, entry.value,
|
||||
entry.stream_ordering, entry.origin_server_ts,
|
||||
) for entry in entries)
|
||||
args = (
|
||||
(
|
||||
entry.event_id,
|
||||
entry.room_id,
|
||||
entry.key,
|
||||
entry.value,
|
||||
entry.stream_ordering,
|
||||
entry.origin_server_ts,
|
||||
)
|
||||
for entry in entries
|
||||
)
|
||||
|
||||
# inserts to a GIN index are normally batched up into a pending
|
||||
# list, and then all committed together once the list gets to a
|
||||
@ -363,9 +370,10 @@ class SearchStore(BackgroundUpdateStore):
|
||||
"INSERT INTO event_search (event_id, room_id, key, value)"
|
||||
" VALUES (?,?,?,?)"
|
||||
)
|
||||
args = ((
|
||||
entry.event_id, entry.room_id, entry.key, entry.value,
|
||||
) for entry in entries)
|
||||
args = (
|
||||
(entry.event_id, entry.room_id, entry.key, entry.value)
|
||||
for entry in entries
|
||||
)
|
||||
|
||||
txn.executemany(sql, args)
|
||||
else:
|
||||
@ -394,9 +402,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
# Make sure we don't explode because the person is in too many rooms.
|
||||
# We filter the results below regardless.
|
||||
if len(room_ids) < 500:
|
||||
clauses.append(
|
||||
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
|
||||
)
|
||||
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
|
||||
args.extend(room_ids)
|
||||
|
||||
local_clauses = []
|
||||
@ -404,9 +410,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
local_clauses.append("key = ?")
|
||||
args.append(key)
|
||||
|
||||
clauses.append(
|
||||
"(%s)" % (" OR ".join(local_clauses),)
|
||||
)
|
||||
clauses.append("(%s)" % (" OR ".join(local_clauses),))
|
||||
|
||||
count_args = args
|
||||
count_clauses = clauses
|
||||
@ -452,18 +456,13 @@ class SearchStore(BackgroundUpdateStore):
|
||||
# entire table from the database.
|
||||
sql += " ORDER BY rank DESC LIMIT 500"
|
||||
|
||||
results = yield self._execute(
|
||||
"search_msgs", self.cursor_to_dict, sql, *args
|
||||
)
|
||||
results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
|
||||
|
||||
results = list(filter(lambda row: row["room_id"] in room_ids, results))
|
||||
|
||||
events = yield self._get_events([r["event_id"] for r in results])
|
||||
|
||||
event_map = {
|
||||
ev.event_id: ev
|
||||
for ev in events
|
||||
}
|
||||
event_map = {ev.event_id: ev for ev in events}
|
||||
|
||||
highlights = None
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
@ -477,18 +476,17 @@ class SearchStore(BackgroundUpdateStore):
|
||||
|
||||
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
|
||||
|
||||
defer.returnValue({
|
||||
"results": [
|
||||
{
|
||||
"event": event_map[r["event_id"]],
|
||||
"rank": r["rank"],
|
||||
}
|
||||
for r in results
|
||||
if r["event_id"] in event_map
|
||||
],
|
||||
"highlights": highlights,
|
||||
"count": count,
|
||||
})
|
||||
defer.returnValue(
|
||||
{
|
||||
"results": [
|
||||
{"event": event_map[r["event_id"]], "rank": r["rank"]}
|
||||
for r in results
|
||||
if r["event_id"] in event_map
|
||||
],
|
||||
"highlights": highlights,
|
||||
"count": count,
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
|
||||
@ -513,9 +511,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
# Make sure we don't explode because the person is in too many rooms.
|
||||
# We filter the results below regardless.
|
||||
if len(room_ids) < 500:
|
||||
clauses.append(
|
||||
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
|
||||
)
|
||||
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
|
||||
args.extend(room_ids)
|
||||
|
||||
local_clauses = []
|
||||
@ -523,9 +519,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
local_clauses.append("key = ?")
|
||||
args.append(key)
|
||||
|
||||
clauses.append(
|
||||
"(%s)" % (" OR ".join(local_clauses),)
|
||||
)
|
||||
clauses.append("(%s)" % (" OR ".join(local_clauses),))
|
||||
|
||||
# take copies of the current args and clauses lists, before adding
|
||||
# pagination clauses to main query.
|
||||
@ -607,18 +601,13 @@ class SearchStore(BackgroundUpdateStore):
|
||||
|
||||
args.append(limit)
|
||||
|
||||
results = yield self._execute(
|
||||
"search_rooms", self.cursor_to_dict, sql, *args
|
||||
)
|
||||
results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
|
||||
|
||||
results = list(filter(lambda row: row["room_id"] in room_ids, results))
|
||||
|
||||
events = yield self._get_events([r["event_id"] for r in results])
|
||||
|
||||
event_map = {
|
||||
ev.event_id: ev
|
||||
for ev in events
|
||||
}
|
||||
event_map = {ev.event_id: ev for ev in events}
|
||||
|
||||
highlights = None
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
@ -632,21 +621,22 @@ class SearchStore(BackgroundUpdateStore):
|
||||
|
||||
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
|
||||
|
||||
defer.returnValue({
|
||||
"results": [
|
||||
{
|
||||
"event": event_map[r["event_id"]],
|
||||
"rank": r["rank"],
|
||||
"pagination_token": "%s,%s" % (
|
||||
r["origin_server_ts"], r["stream_ordering"]
|
||||
),
|
||||
}
|
||||
for r in results
|
||||
if r["event_id"] in event_map
|
||||
],
|
||||
"highlights": highlights,
|
||||
"count": count,
|
||||
})
|
||||
defer.returnValue(
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"event": event_map[r["event_id"]],
|
||||
"rank": r["rank"],
|
||||
"pagination_token": "%s,%s"
|
||||
% (r["origin_server_ts"], r["stream_ordering"]),
|
||||
}
|
||||
for r in results
|
||||
if r["event_id"] in event_map
|
||||
],
|
||||
"highlights": highlights,
|
||||
"count": count,
|
||||
}
|
||||
)
|
||||
|
||||
def _find_highlights_in_postgres(self, search_query, events):
|
||||
"""Given a list of events and a search term, return a list of words
|
||||
@ -662,6 +652,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
Returns:
|
||||
deferred : A set of strings.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
highlight_words = set()
|
||||
for event in events:
|
||||
@ -689,13 +680,15 @@ class SearchStore(BackgroundUpdateStore):
|
||||
stop_sel += ">"
|
||||
|
||||
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
|
||||
_to_postgres_options({
|
||||
"StartSel": start_sel,
|
||||
"StopSel": stop_sel,
|
||||
"MaxFragments": "50",
|
||||
})
|
||||
_to_postgres_options(
|
||||
{
|
||||
"StartSel": start_sel,
|
||||
"StopSel": stop_sel,
|
||||
"MaxFragments": "50",
|
||||
}
|
||||
)
|
||||
)
|
||||
txn.execute(query, (value, search_query,))
|
||||
txn.execute(query, (value, search_query))
|
||||
headline, = txn.fetchall()[0]
|
||||
|
||||
# Now we need to pick the possible highlights out of the haedline
|
||||
@ -714,9 +707,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||
|
||||
|
||||
def _to_postgres_options(options_dict):
|
||||
return "'%s'" % (
|
||||
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
|
||||
)
|
||||
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
|
||||
|
||||
|
||||
def _parse_query(database_engine, search_term):
|
||||
|
@ -39,8 +39,9 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||
# to use its cache
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="get_event_reference_hash",
|
||||
list_name="event_ids", num_args=1)
|
||||
@cachedList(
|
||||
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
|
||||
)
|
||||
def get_event_reference_hashes(self, event_ids):
|
||||
def f(txn):
|
||||
return {
|
||||
@ -48,21 +49,13 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||
for event_id in event_ids
|
||||
}
|
||||
|
||||
return self.runInteraction(
|
||||
"get_event_reference_hashes",
|
||||
f
|
||||
)
|
||||
return self.runInteraction("get_event_reference_hashes", f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_event_hashes(self, event_ids):
|
||||
hashes = yield self.get_event_reference_hashes(
|
||||
event_ids
|
||||
)
|
||||
hashes = yield self.get_event_reference_hashes(event_ids)
|
||||
hashes = {
|
||||
e_id: {
|
||||
k: encode_base64(v) for k, v in h.items()
|
||||
if k == "sha256"
|
||||
}
|
||||
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
||||
for e_id, h in hashes.items()
|
||||
}
|
||||
|
||||
@ -81,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||
" FROM event_reference_hashes"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(query, (event_id, ))
|
||||
txn.execute(query, (event_id,))
|
||||
return {k: v for k, v in txn}
|
||||
|
||||
|
||||
@ -98,14 +91,12 @@ class SignatureStore(SignatureWorkerStore):
|
||||
vals = []
|
||||
for event in events:
|
||||
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
|
||||
vals.append({
|
||||
"event_id": event.event_id,
|
||||
"algorithm": ref_alg,
|
||||
"hash": db_binary_type(ref_hash_bytes),
|
||||
})
|
||||
vals.append(
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"algorithm": ref_alg,
|
||||
"hash": db_binary_type(ref_hash_bytes),
|
||||
}
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_reference_hashes",
|
||||
values=vals,
|
||||
)
|
||||
self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
|
||||
|
@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
|
||||
MAX_STATE_DELTA_HOPS = 100
|
||||
|
||||
|
||||
class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
|
||||
class _GetStateGroupDelta(
|
||||
namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
|
||||
):
|
||||
"""Return type of get_state_group_delta that implements __len__, which lets
|
||||
us use the itrable flag when caching
|
||||
"""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __len__(self):
|
||||
@ -70,10 +73,7 @@ class StateFilter(object):
|
||||
# If `include_others` is set we canonicalise the filter by removing
|
||||
# wildcards from the types dictionary
|
||||
if self.include_others:
|
||||
self.types = {
|
||||
k: v for k, v in iteritems(self.types)
|
||||
if v is not None
|
||||
}
|
||||
self.types = {k: v for k, v in iteritems(self.types) if v is not None}
|
||||
|
||||
@staticmethod
|
||||
def all():
|
||||
@ -130,10 +130,7 @@ class StateFilter(object):
|
||||
Returns:
|
||||
StateFilter
|
||||
"""
|
||||
return StateFilter(
|
||||
types={EventTypes.Member: set(members)},
|
||||
include_others=True,
|
||||
)
|
||||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
|
||||
|
||||
def return_expanded(self):
|
||||
"""Creates a new StateFilter where type wild cards have been removed
|
||||
@ -243,9 +240,7 @@ class StateFilter(object):
|
||||
if where_clause:
|
||||
where_clause += " OR "
|
||||
|
||||
where_clause += "type NOT IN (%s)" % (
|
||||
",".join(["?"] * len(self.types)),
|
||||
)
|
||||
where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
|
||||
where_args.extend(self.types)
|
||||
|
||||
return where_clause, where_args
|
||||
@ -305,12 +300,8 @@ class StateFilter(object):
|
||||
bool
|
||||
"""
|
||||
|
||||
return (
|
||||
self.include_others
|
||||
or any(
|
||||
state_keys is None
|
||||
for state_keys in itervalues(self.types)
|
||||
)
|
||||
return self.include_others or any(
|
||||
state_keys is None for state_keys in itervalues(self.types)
|
||||
)
|
||||
|
||||
def concrete_types(self):
|
||||
@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
self._state_group_cache = DictionaryCache(
|
||||
"*stateGroupCache*",
|
||||
# TODO: this hasn't been tuned yet
|
||||
50000 * get_cache_factor_for("stateGroupCache")
|
||||
50000 * get_cache_factor_for("stateGroupCache"),
|
||||
)
|
||||
self._state_group_members_cache = DictionaryCache(
|
||||
"*stateGroupMembersCache*",
|
||||
500000 * get_cache_factor_for("stateGroupMembersCache")
|
||||
500000 * get_cache_factor_for("stateGroupMembersCache"),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Returns:
|
||||
deferred: dict of (type, state_key) -> event_id
|
||||
"""
|
||||
|
||||
def _get_current_state_ids_txn(txn):
|
||||
txn.execute(
|
||||
"""SELECT type, state_key, event_id FROM current_state_events
|
||||
WHERE room_id = ?
|
||||
""",
|
||||
(room_id,)
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
return {
|
||||
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
|
||||
}
|
||||
|
||||
return self.runInteraction(
|
||||
"get_current_state_ids",
|
||||
_get_current_state_ids_txn,
|
||||
)
|
||||
return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
|
||||
|
||||
# FIXME: how should this be cached?
|
||||
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
|
||||
@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return results
|
||||
|
||||
return self.runInteraction(
|
||||
"get_filtered_current_state_ids",
|
||||
_get_filtered_current_state_ids_txn,
|
||||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Deferred[str|None]: The canonical alias, if any
|
||||
"""
|
||||
|
||||
state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types(
|
||||
[(EventTypes.CanonicalAlias, "")]
|
||||
))
|
||||
state = yield self.get_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||
)
|
||||
|
||||
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
||||
if not event_id:
|
||||
@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Returns:
|
||||
(prev_group, delta_ids), where both may be None.
|
||||
"""
|
||||
|
||||
def _get_state_group_delta_txn(txn):
|
||||
prev_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={
|
||||
"state_group": state_group,
|
||||
},
|
||||
keyvalues={"state_group": state_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
delta_ids = self._simple_select_list_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
keyvalues={
|
||||
"state_group": state_group,
|
||||
},
|
||||
retcols=("type", "state_key", "event_id",)
|
||||
keyvalues={"state_group": state_group},
|
||||
retcols=("type", "state_key", "event_id"),
|
||||
)
|
||||
|
||||
return _GetStateGroupDelta(prev_group, {
|
||||
(row["type"], row["state_key"]): row["event_id"]
|
||||
for row in delta_ids
|
||||
})
|
||||
return self.runInteraction(
|
||||
"get_state_group_delta",
|
||||
_get_state_group_delta_txn,
|
||||
)
|
||||
return _GetStateGroupDelta(
|
||||
prev_group,
|
||||
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
|
||||
)
|
||||
|
||||
return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups_ids(self, _room_id, event_ids):
|
||||
@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
if not event_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
event_to_groups = yield self._get_state_group_for_events(
|
||||
event_ids,
|
||||
)
|
||||
event_to_groups = yield self._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(itervalues(event_to_groups))
|
||||
group_to_state = yield self._get_state_for_groups(groups)
|
||||
@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
state_event_map = yield self.get_events(
|
||||
[
|
||||
ev_id for group_ids in itervalues(group_to_ids)
|
||||
ev_id
|
||||
for group_ids in itervalues(group_to_ids)
|
||||
for ev_id in itervalues(group_ids)
|
||||
],
|
||||
get_prev_content=False
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
group: [
|
||||
state_event_map[v] for v in itervalues(event_id_map)
|
||||
if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in iteritems(group_to_ids)
|
||||
})
|
||||
defer.returnValue(
|
||||
{
|
||||
group: [
|
||||
state_event_map[v]
|
||||
for v in itervalues(event_id_map)
|
||||
if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in iteritems(group_to_ids)
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_groups_from_groups(self, groups, state_filter):
|
||||
@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
results = {}
|
||||
|
||||
chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
|
||||
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
|
||||
for chunk in chunks:
|
||||
res = yield self.runInteraction(
|
||||
"_get_state_groups_from_groups",
|
||||
self._get_state_groups_from_groups_txn, chunk, state_filter,
|
||||
self._get_state_groups_from_groups_txn,
|
||||
chunk,
|
||||
state_filter,
|
||||
)
|
||||
results.update(res)
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
def _get_state_groups_from_groups_txn(
|
||||
self, txn, groups, state_filter=StateFilter.all(),
|
||||
self, txn, groups, state_filter=StateFilter.all()
|
||||
):
|
||||
results = {group: {} for group in groups}
|
||||
|
||||
@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
txn.execute(
|
||||
"SELECT type, state_key, event_id FROM state_groups_state"
|
||||
" WHERE state_group = ? " + where_clause,
|
||||
args
|
||||
args,
|
||||
)
|
||||
results[group].update(
|
||||
((typ, state_key), event_id)
|
||||
@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
# wildcards (i.e. Nones) in which case we have to do an exhaustive
|
||||
# search
|
||||
if (
|
||||
max_entries_returned is not None and
|
||||
len(results[group]) == max_entries_returned
|
||||
max_entries_returned is not None
|
||||
and len(results[group]) == max_entries_returned
|
||||
):
|
||||
break
|
||||
|
||||
@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Returns:
|
||||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
"""
|
||||
event_to_groups = yield self._get_state_group_for_events(
|
||||
event_ids,
|
||||
)
|
||||
event_to_groups = yield self._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(itervalues(event_to_groups))
|
||||
group_to_state = yield self._get_state_for_groups(groups, state_filter)
|
||||
|
||||
state_event_map = yield self.get_events(
|
||||
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
|
||||
get_prev_content=False
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Returns:
|
||||
A deferred dict from event_id -> (type, state_key) -> event_id
|
||||
"""
|
||||
event_to_groups = yield self._get_state_group_for_events(
|
||||
event_ids,
|
||||
)
|
||||
event_to_groups = yield self._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(itervalues(event_to_groups))
|
||||
group_to_state = yield self._get_state_for_groups(groups, state_filter)
|
||||
@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
def _get_state_group_for_event(self, event_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="event_to_state_groups",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
},
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="state_group",
|
||||
allow_none=True,
|
||||
desc="_get_state_group_for_event",
|
||||
)
|
||||
|
||||
@cachedList(cached_method_name="_get_state_group_for_event",
|
||||
list_name="event_ids", num_args=1, inlineCallbacks=True)
|
||||
@cachedList(
|
||||
cached_method_name="_get_state_group_for_event",
|
||||
list_name="event_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_state_group_for_events(self, event_ids):
|
||||
"""Returns mapping event_id -> state_group
|
||||
"""
|
||||
@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
keyvalues={},
|
||||
retcols=("event_id", "state_group",),
|
||||
retcols=("event_id", "state_group"),
|
||||
desc="_get_state_group_for_events",
|
||||
)
|
||||
|
||||
@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
# Now we look them up in the member and non-member caches
|
||||
non_member_state, incomplete_groups_nm, = (
|
||||
yield self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_cache,
|
||||
state_filter=non_member_filter,
|
||||
groups, self._state_group_cache, state_filter=non_member_filter
|
||||
)
|
||||
)
|
||||
|
||||
member_state, incomplete_groups_m, = (
|
||||
yield self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_members_cache,
|
||||
state_filter=member_filter,
|
||||
groups, self._state_group_members_cache, state_filter=member_filter
|
||||
)
|
||||
)
|
||||
|
||||
@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
db_state_filter = state_filter.return_expanded()
|
||||
|
||||
group_to_state_dict = yield self._get_state_groups_from_groups(
|
||||
list(incomplete_groups),
|
||||
state_filter=db_state_filter,
|
||||
list(incomplete_groups), state_filter=db_state_filter
|
||||
)
|
||||
|
||||
# Now lets update the caches
|
||||
@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
defer.returnValue(state)
|
||||
|
||||
def _get_state_for_groups_using_cache(
|
||||
self, groups, cache, state_filter,
|
||||
):
|
||||
def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key, querying from a specific cache.
|
||||
|
||||
@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return results, incomplete_groups
|
||||
|
||||
def _insert_into_cache(self, group_to_state_dict, state_filter,
|
||||
cache_seq_num_members, cache_seq_num_non_members):
|
||||
def _insert_into_cache(
|
||||
self,
|
||||
group_to_state_dict,
|
||||
state_filter,
|
||||
cache_seq_num_members,
|
||||
cache_seq_num_non_members,
|
||||
):
|
||||
"""Inserts results from querying the database into the relevant cache.
|
||||
|
||||
Args:
|
||||
@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
fetched_keys=non_member_types,
|
||||
)
|
||||
|
||||
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
|
||||
current_state_ids):
|
||||
def store_state_group(
|
||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
):
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Returns:
|
||||
Deferred[int]: The state group ID
|
||||
"""
|
||||
|
||||
def _store_state_group_txn(txn):
|
||||
if current_state_ids is None:
|
||||
# AFAIK, this can never happen
|
||||
@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
values={
|
||||
"id": state_group,
|
||||
"room_id": room_id,
|
||||
"event_id": event_id,
|
||||
},
|
||||
values={"id": state_group, "room_id": room_id, "event_id": event_id},
|
||||
)
|
||||
|
||||
# We persist as a delta if we can, while also ensuring the chain
|
||||
@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
% (prev_group,)
|
||||
)
|
||||
|
||||
potential_hops = self._count_state_group_hops_txn(
|
||||
txn, prev_group
|
||||
)
|
||||
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
|
||||
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
values={
|
||||
"state_group": state_group,
|
||||
"prev_state_group": prev_group,
|
||||
},
|
||||
values={"state_group": state_group, "prev_state_group": prev_group},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
This is used to ensure the delta chains don't get too long.
|
||||
"""
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = ("""
|
||||
sql = """
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT count(*) FROM state;
|
||||
""")
|
||||
"""
|
||||
|
||||
txn.execute(sql, (state_group,))
|
||||
row = txn.fetchone()
|
||||
@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
self._background_deduplicate_state,
|
||||
)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_INDEX_UPDATE_NAME,
|
||||
self._background_index_state,
|
||||
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
|
||||
)
|
||||
self.register_background_index_update(
|
||||
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
||||
@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
values=[
|
||||
{
|
||||
"state_group": state_group_id,
|
||||
"event_id": event_id,
|
||||
}
|
||||
{"state_group": state_group_id, "event_id": event_id}
|
||||
for event_id, state_group_id in iteritems(state_groups)
|
||||
],
|
||||
)
|
||||
|
||||
for event_id, state_group_id in iteritems(state_groups):
|
||||
txn.call_after(
|
||||
self._get_state_group_for_event.prefill,
|
||||
(event_id,), state_group_id
|
||||
self._get_state_group_for_event.prefill, (event_id,), state_group_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
|
||||
if max_group is None:
|
||||
rows = yield self._execute(
|
||||
"_background_deduplicate_state", None,
|
||||
"_background_deduplicate_state",
|
||||
None,
|
||||
"SELECT coalesce(max(id), 0) FROM state_groups",
|
||||
)
|
||||
max_group = rows[0][0]
|
||||
@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
" WHERE ? < id AND id <= ?"
|
||||
" ORDER BY id ASC"
|
||||
" LIMIT 1",
|
||||
(new_last_state_group, max_group,)
|
||||
(new_last_state_group, max_group),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
txn.execute(
|
||||
"SELECT state_group FROM state_group_edges"
|
||||
" WHERE state_group = ?",
|
||||
(state_group,)
|
||||
(state_group,),
|
||||
)
|
||||
|
||||
# If we reach a point where we've already started inserting
|
||||
@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
txn.execute(
|
||||
"SELECT coalesce(max(id), 0) FROM state_groups"
|
||||
" WHERE id < ? AND room_id = ?",
|
||||
(state_group, room_id,)
|
||||
(state_group, room_id),
|
||||
)
|
||||
prev_group, = txn.fetchone()
|
||||
new_last_state_group = state_group
|
||||
|
||||
if prev_group:
|
||||
potential_hops = self._count_state_group_hops_txn(
|
||||
txn, prev_group
|
||||
)
|
||||
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
|
||||
if potential_hops >= MAX_STATE_DELTA_HOPS:
|
||||
# We want to ensure chains are at most this long,#
|
||||
# otherwise read performance degrades.
|
||||
continue
|
||||
|
||||
prev_state = self._get_state_groups_from_groups_txn(
|
||||
txn, [prev_group],
|
||||
txn, [prev_group]
|
||||
)
|
||||
prev_state = prev_state[prev_group]
|
||||
|
||||
curr_state = self._get_state_groups_from_groups_txn(
|
||||
txn, [state_group],
|
||||
txn, [state_group]
|
||||
)
|
||||
curr_state = curr_state[state_group]
|
||||
|
||||
@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
# of keys
|
||||
|
||||
delta_state = {
|
||||
key: value for key, value in iteritems(curr_state)
|
||||
key: value
|
||||
for key, value in iteritems(curr_state)
|
||||
if prev_state.get(key, None) != value
|
||||
}
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={
|
||||
"state_group": state_group,
|
||||
}
|
||||
keyvalues={"state_group": state_group},
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
values={
|
||||
"state_group": state_group,
|
||||
"prev_state_group": prev_group,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
keyvalues={
|
||||
"state_group": state_group,
|
||||
}
|
||||
keyvalues={"state_group": state_group},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
)
|
||||
|
||||
if finished:
|
||||
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
|
||||
yield self._end_background_update(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
|
||||
)
|
||||
|
||||
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
|
||||
|
||||
@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
|
||||
" ON state_groups_state(state_group, type, state_key)"
|
||||
)
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS state_groups_state_id"
|
||||
)
|
||||
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
|
||||
finally:
|
||||
conn.set_session(autocommit=False)
|
||||
else:
|
||||
@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
"CREATE INDEX state_groups_state_type_idx"
|
||||
" ON state_groups_state(state_group, type, state_key)"
|
||||
)
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS state_groups_state_id"
|
||||
)
|
||||
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
|
||||
|
||||
yield self.runWithConnection(reindex_txn)
|
||||
|
||||
|
@ -21,10 +21,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateDeltasStore(SQLBaseStore):
|
||||
|
||||
def get_current_state_deltas(self, prev_stream_id):
|
||||
prev_stream_id = int(prev_stream_id)
|
||||
if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
|
||||
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
|
||||
prev_stream_id
|
||||
):
|
||||
return []
|
||||
|
||||
def get_current_state_deltas_txn(txn):
|
||||
@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
"""
|
||||
txn.execute(sql, (prev_stream_id, max_stream_id,))
|
||||
txn.execute(sql, (prev_stream_id, max_stream_id))
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
return self.runInteraction(
|
||||
|
@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological"
|
||||
|
||||
|
||||
# Used as return values for pagination APIs
|
||||
_EventDictReturn = namedtuple("_EventDictReturn", (
|
||||
"event_id", "topological_ordering", "stream_ordering",
|
||||
))
|
||||
_EventDictReturn = namedtuple(
|
||||
"_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
|
||||
)
|
||||
|
||||
|
||||
def lower_bound(token, engine, inclusive=False):
|
||||
@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False):
|
||||
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
|
||||
# use the later form when running against postgres.
|
||||
return "((%d,%d) <%s (%s,%s))" % (
|
||||
token.topological, token.stream, inclusive,
|
||||
"topological_ordering", "stream_ordering",
|
||||
token.topological,
|
||||
token.stream,
|
||||
inclusive,
|
||||
"topological_ordering",
|
||||
"stream_ordering",
|
||||
)
|
||||
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
|
||||
token.topological, "topological_ordering",
|
||||
token.topological, "topological_ordering",
|
||||
token.stream, inclusive, "stream_ordering",
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.stream,
|
||||
inclusive,
|
||||
"stream_ordering",
|
||||
)
|
||||
|
||||
|
||||
@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True):
|
||||
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
|
||||
# use the later form when running against postgres.
|
||||
return "((%d,%d) >%s (%s,%s))" % (
|
||||
token.topological, token.stream, inclusive,
|
||||
"topological_ordering", "stream_ordering",
|
||||
token.topological,
|
||||
token.stream,
|
||||
inclusive,
|
||||
"topological_ordering",
|
||||
"stream_ordering",
|
||||
)
|
||||
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
|
||||
token.topological, "topological_ordering",
|
||||
token.topological, "topological_ordering",
|
||||
token.stream, inclusive, "stream_ordering",
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.stream,
|
||||
inclusive,
|
||||
"stream_ordering",
|
||||
)
|
||||
|
||||
|
||||
@ -116,9 +130,7 @@ def filter_to_clause(event_filter):
|
||||
args = []
|
||||
|
||||
if event_filter.types:
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
|
||||
)
|
||||
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
|
||||
args.extend(event_filter.types)
|
||||
|
||||
for typ in event_filter.not_types:
|
||||
@ -126,9 +138,7 @@ def filter_to_clause(event_filter):
|
||||
args.append(typ)
|
||||
|
||||
if event_filter.senders:
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
|
||||
)
|
||||
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
|
||||
args.extend(event_filter.senders)
|
||||
|
||||
for sender in event_filter.not_senders:
|
||||
@ -136,9 +146,7 @@ def filter_to_clause(event_filter):
|
||||
args.append(sender)
|
||||
|
||||
if event_filter.rooms:
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
|
||||
)
|
||||
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
|
||||
args.extend(event_filter.rooms)
|
||||
|
||||
for room_id in event_filter.not_rooms:
|
||||
@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
events_max = self.get_room_max_stream_ordering()
|
||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
||||
db_conn, "events",
|
||||
db_conn,
|
||||
"events",
|
||||
entity_column="room_id",
|
||||
stream_column="stream_ordering",
|
||||
max_value=events_max,
|
||||
)
|
||||
self._events_stream_cache = StreamChangeCache(
|
||||
"EventsRoomStreamChangeCache", min_event_val,
|
||||
"EventsRoomStreamChangeCache",
|
||||
min_event_val,
|
||||
prefilled_cache=event_cache_prefill,
|
||||
)
|
||||
self._membership_stream_cache = StreamChangeCache(
|
||||
"MembershipStreamChangeCache", events_max,
|
||||
"MembershipStreamChangeCache", events_max
|
||||
)
|
||||
|
||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||
@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
|
||||
order='DESC'):
|
||||
def get_room_events_stream_for_rooms(
|
||||
self, room_ids, from_key, to_key, limit=0, order='DESC'
|
||||
):
|
||||
"""Get new room events in stream ordering since `from_key`.
|
||||
|
||||
Args:
|
||||
@ -221,14 +232,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
results = {}
|
||||
room_ids = list(room_ids)
|
||||
for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
|
||||
res = yield make_deferred_yieldable(defer.gatherResults([
|
||||
run_in_background(
|
||||
self.get_room_events_stream_for_room,
|
||||
room_id, from_key, to_key, limit, order=order,
|
||||
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
|
||||
res = yield make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
self.get_room_events_stream_for_room,
|
||||
room_id,
|
||||
from_key,
|
||||
to_key,
|
||||
limit,
|
||||
order=order,
|
||||
)
|
||||
for room_id in rm_ids
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
for room_id in rm_ids
|
||||
], consumeErrors=True))
|
||||
)
|
||||
results.update(dict(zip(rm_ids, res)))
|
||||
|
||||
defer.returnValue(results)
|
||||
@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
from_key = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
return set(
|
||||
room_id for room_id in room_ids
|
||||
room_id
|
||||
for room_id in room_ids
|
||||
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
|
||||
order='DESC'):
|
||||
def get_room_events_stream_for_room(
|
||||
self, room_id, from_key, to_key, limit=0, order='DESC'
|
||||
):
|
||||
|
||||
"""Get new room events in stream ordering since `from_key`.
|
||||
|
||||
@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
|
||||
|
||||
ret = yield self._get_events(
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
|
||||
|
||||
self._set_before_and_after(ret, rows, topo_order=from_id is None)
|
||||
|
||||
@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
)
|
||||
txn.execute(sql, (user_id, from_id, to_id,))
|
||||
txn.execute(sql, (user_id, from_id, to_id))
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
|
||||
@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
rows = yield self.runInteraction("get_membership_changes_for_user", f)
|
||||
|
||||
ret = yield self._get_events(
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
|
||||
|
||||
self._set_before_and_after(ret, rows, topo_order=False)
|
||||
|
||||
@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
|
||||
rows, token = yield self.get_recent_event_ids_for_room(
|
||||
room_id, limit, end_token,
|
||||
room_id, limit, end_token
|
||||
)
|
||||
|
||||
logger.debug("stream before")
|
||||
events = yield self._get_events(
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
logger.debug("stream after")
|
||||
|
||||
@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
end_token = RoomStreamToken.parse(end_token)
|
||||
|
||||
rows, token = yield self.runInteraction(
|
||||
"get_recent_event_ids_for_room", self._paginate_room_events_txn,
|
||||
room_id, from_token=end_token, limit=limit,
|
||||
"get_recent_event_ids_for_room",
|
||||
self._paginate_room_events_txn,
|
||||
room_id,
|
||||
from_token=end_token,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# We want to return the results in ascending order.
|
||||
@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Deferred[(int, int, str)]:
|
||||
(stream ordering, topological ordering, event_id)
|
||||
"""
|
||||
|
||||
def _f(txn):
|
||||
sql = (
|
||||
"SELECT stream_ordering, topological_ordering, event_id"
|
||||
@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
" ORDER BY stream_ordering"
|
||||
" LIMIT 1"
|
||||
)
|
||||
txn.execute(sql, (room_id, stream_ordering, ))
|
||||
txn.execute(sql, (room_id, stream_ordering))
|
||||
return txn.fetchone()
|
||||
|
||||
return self.runInteraction(
|
||||
"get_room_event_after_stream_ordering", _f,
|
||||
)
|
||||
return self.runInteraction("get_room_event_after_stream_ordering", _f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_events_max_id(self, room_id=None):
|
||||
@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
defer.returnValue("s%d" % (token,))
|
||||
else:
|
||||
topo = yield self.runInteraction(
|
||||
"_get_max_topological_txn", self._get_max_topological_txn,
|
||||
room_id,
|
||||
"_get_max_topological_txn", self._get_max_topological_txn, room_id
|
||||
)
|
||||
defer.returnValue("t%d-%d" % (topo, token))
|
||||
|
||||
@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
A deferred "s%d" stream token.
|
||||
"""
|
||||
return self._simple_select_one_onecol(
|
||||
table="events",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="stream_ordering",
|
||||
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
|
||||
).addCallback(lambda row: "s%d" % (row,))
|
||||
|
||||
def get_topological_token_for_event(self, event_id):
|
||||
@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
keyvalues={"event_id": event_id},
|
||||
retcols=("stream_ordering", "topological_ordering"),
|
||||
desc="get_topological_token_for_event",
|
||||
).addCallback(lambda row: "t%d-%d" % (
|
||||
row["topological_ordering"], row["stream_ordering"],)
|
||||
).addCallback(
|
||||
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
|
||||
)
|
||||
|
||||
def get_max_topological_token(self, room_id, stream_key):
|
||||
@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
" WHERE room_id = ? AND stream_ordering < ?"
|
||||
)
|
||||
return self._execute(
|
||||
"get_max_topological_token", None,
|
||||
sql, room_id, stream_key,
|
||||
).addCallback(
|
||||
lambda r: r[0][0] if r else 0
|
||||
)
|
||||
"get_max_topological_token", None, sql, room_id, stream_key
|
||||
).addCallback(lambda r: r[0][0] if r else 0)
|
||||
|
||||
def _get_max_topological_txn(self, txn, room_id):
|
||||
txn.execute(
|
||||
"SELECT MAX(topological_ordering) FROM events"
|
||||
" WHERE room_id = ?",
|
||||
(room_id,)
|
||||
"SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
rows = txn.fetchall()
|
||||
@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
internal = event.internal_metadata
|
||||
internal.before = str(RoomStreamToken(topo, stream - 1))
|
||||
internal.after = str(RoomStreamToken(topo, stream))
|
||||
internal.order = (
|
||||
int(topo) if topo else 0,
|
||||
int(stream),
|
||||
)
|
||||
internal.order = (int(topo) if topo else 0, int(stream))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_around(
|
||||
self, room_id, event_id, before_limit, after_limit, event_filter=None,
|
||||
self, room_id, event_id, before_limit, after_limit, event_filter=None
|
||||
):
|
||||
"""Retrieve events and pagination tokens around a given event in a
|
||||
room.
|
||||
@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
|
||||
results = yield self.runInteraction(
|
||||
"get_events_around", self._get_events_around_txn,
|
||||
room_id, event_id, before_limit, after_limit, event_filter,
|
||||
"get_events_around",
|
||||
self._get_events_around_txn,
|
||||
room_id,
|
||||
event_id,
|
||||
before_limit,
|
||||
after_limit,
|
||||
event_filter,
|
||||
)
|
||||
|
||||
events_before = yield self._get_events(
|
||||
[e for e in results["before"]["event_ids"]],
|
||||
get_prev_content=True
|
||||
[e for e in results["before"]["event_ids"]], get_prev_content=True
|
||||
)
|
||||
|
||||
events_after = yield self._get_events(
|
||||
[e for e in results["after"]["event_ids"]],
|
||||
get_prev_content=True
|
||||
[e for e in results["after"]["event_ids"]], get_prev_content=True
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"events_before": events_before,
|
||||
"events_after": events_after,
|
||||
"start": results["before"]["token"],
|
||||
"end": results["after"]["token"],
|
||||
})
|
||||
defer.returnValue(
|
||||
{
|
||||
"events_before": events_before,
|
||||
"events_after": events_after,
|
||||
"start": results["before"]["token"],
|
||||
"end": results["after"]["token"],
|
||||
}
|
||||
)
|
||||
|
||||
def _get_events_around_txn(
|
||||
self, txn, room_id, event_id, before_limit, after_limit, event_filter,
|
||||
self, txn, room_id, event_id, before_limit, after_limit, event_filter
|
||||
):
|
||||
"""Retrieves event_ids and pagination tokens around a given event in a
|
||||
room.
|
||||
@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
results = self._simple_select_one_txn(
|
||||
txn,
|
||||
"events",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
keyvalues={"event_id": event_id, "room_id": room_id},
|
||||
retcols=["stream_ordering", "topological_ordering"],
|
||||
)
|
||||
|
||||
# Paginating backwards includes the event at the token, but paginating
|
||||
# forward doesn't.
|
||||
before_token = RoomStreamToken(
|
||||
results["topological_ordering"] - 1,
|
||||
results["stream_ordering"],
|
||||
results["topological_ordering"] - 1, results["stream_ordering"]
|
||||
)
|
||||
|
||||
after_token = RoomStreamToken(
|
||||
results["topological_ordering"],
|
||||
results["stream_ordering"],
|
||||
results["topological_ordering"], results["stream_ordering"]
|
||||
)
|
||||
|
||||
rows, start_token = self._paginate_room_events_txn(
|
||||
txn, room_id, before_token, direction='b', limit=before_limit,
|
||||
txn,
|
||||
room_id,
|
||||
before_token,
|
||||
direction='b',
|
||||
limit=before_limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
events_before = [r.event_id for r in rows]
|
||||
|
||||
rows, end_token = self._paginate_room_events_txn(
|
||||
txn, room_id, after_token, direction='f', limit=after_limit,
|
||||
txn,
|
||||
room_id,
|
||||
after_token,
|
||||
direction='f',
|
||||
limit=after_limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
events_after = [r.event_id for r in rows]
|
||||
|
||||
return {
|
||||
"before": {
|
||||
"event_ids": events_before,
|
||||
"token": start_token,
|
||||
},
|
||||
"after": {
|
||||
"event_ids": events_after,
|
||||
"token": end_token,
|
||||
},
|
||||
"before": {"event_ids": events_before, "token": start_token},
|
||||
"after": {"event_ids": events_after, "token": end_token},
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.runInteraction(
|
||||
"get_all_new_events_stream", get_all_new_events_stream_txn,
|
||||
"get_all_new_events_stream", get_all_new_events_stream_txn
|
||||
)
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
table="federation_stream_position",
|
||||
retcol="stream_id",
|
||||
keyvalues={"type": typ},
|
||||
desc="get_federation_out_pos"
|
||||
desc="get_federation_out_pos",
|
||||
)
|
||||
|
||||
def update_federation_out_pos(self, typ, stream_id):
|
||||
@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
def has_room_changed_since(self, room_id, stream_id):
|
||||
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
|
||||
|
||||
def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
|
||||
direction='b', limit=-1, event_filter=None):
|
||||
def _paginate_room_events_txn(
|
||||
self,
|
||||
txn,
|
||||
room_id,
|
||||
from_token,
|
||||
to_token=None,
|
||||
direction='b',
|
||||
limit=-1,
|
||||
event_filter=None,
|
||||
):
|
||||
"""Returns list of events before or after a given token.
|
||||
|
||||
Args:
|
||||
@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
args = [False, room_id]
|
||||
if direction == 'b':
|
||||
order = "DESC"
|
||||
bounds = upper_bound(
|
||||
from_token, self.database_engine
|
||||
)
|
||||
bounds = upper_bound(from_token, self.database_engine)
|
||||
if to_token:
|
||||
bounds = "%s AND %s" % (bounds, lower_bound(
|
||||
to_token, self.database_engine
|
||||
))
|
||||
bounds = "%s AND %s" % (
|
||||
bounds,
|
||||
lower_bound(to_token, self.database_engine),
|
||||
)
|
||||
else:
|
||||
order = "ASC"
|
||||
bounds = lower_bound(
|
||||
from_token, self.database_engine
|
||||
)
|
||||
bounds = lower_bound(from_token, self.database_engine)
|
||||
if to_token:
|
||||
bounds = "%s AND %s" % (bounds, upper_bound(
|
||||
to_token, self.database_engine
|
||||
))
|
||||
bounds = "%s AND %s" % (
|
||||
bounds,
|
||||
upper_bound(to_token, self.database_engine),
|
||||
)
|
||||
|
||||
filter_clause, filter_args = filter_to_clause(event_filter)
|
||||
|
||||
@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
|
||||
" ORDER BY topological_ordering %(order)s,"
|
||||
" stream_ordering %(order)s LIMIT ?"
|
||||
) % {
|
||||
"bounds": bounds,
|
||||
"order": order,
|
||||
}
|
||||
) % {"bounds": bounds, "order": order}
|
||||
|
||||
txn.execute(sql, args)
|
||||
|
||||
@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
# TODO (erikj): We should work out what to do here instead.
|
||||
next_token = to_token if to_token else from_token
|
||||
|
||||
return rows, str(next_token),
|
||||
return rows, str(next_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def paginate_room_events(self, room_id, from_key, to_key=None,
|
||||
direction='b', limit=-1, event_filter=None):
|
||||
def paginate_room_events(
|
||||
self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
|
||||
):
|
||||
"""Returns list of events before or after a given token.
|
||||
|
||||
Args:
|
||||
@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
to_key = RoomStreamToken.parse(to_key)
|
||||
|
||||
rows, token = yield self.runInteraction(
|
||||
"paginate_room_events", self._paginate_room_events_txn,
|
||||
room_id, from_key, to_key, direction, limit, event_filter,
|
||||
"paginate_room_events",
|
||||
self._paginate_room_events_txn,
|
||||
room_id,
|
||||
from_key,
|
||||
to_key,
|
||||
direction,
|
||||
limit,
|
||||
event_filter,
|
||||
)
|
||||
|
||||
events = yield self._get_events(
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
self._set_before_and_after(events, rows)
|
||||
|
@ -84,9 +84,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
|
||||
def get_tag_content(txn, tag_ids):
|
||||
sql = (
|
||||
"SELECT tag, content"
|
||||
" FROM room_tags"
|
||||
" WHERE user_id=? AND room_id=?"
|
||||
"SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
|
||||
)
|
||||
results = []
|
||||
for stream_id, user_id, room_id in tag_ids:
|
||||
@ -105,7 +103,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
tags = yield self.runInteraction(
|
||||
"get_all_updated_tag_content",
|
||||
get_tag_content,
|
||||
tag_ids[i:i + batch_size],
|
||||
tag_ids[i : i + batch_size],
|
||||
)
|
||||
results.extend(tags)
|
||||
|
||||
@ -123,6 +121,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
A deferred dict mapping from room_id strings to lists of tag
|
||||
strings for all the rooms that changed since the stream_id token.
|
||||
"""
|
||||
|
||||
def get_updated_tags_txn(txn):
|
||||
sql = (
|
||||
"SELECT room_id from room_tags_revisions"
|
||||
@ -138,9 +137,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
if not changed:
|
||||
defer.returnValue({})
|
||||
|
||||
room_ids = yield self.runInteraction(
|
||||
"get_updated_tags", get_updated_tags_txn
|
||||
)
|
||||
room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
|
||||
|
||||
results = {}
|
||||
if room_ids:
|
||||
@ -163,9 +160,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||
retcols=("tag", "content"),
|
||||
desc="get_tags_for_room",
|
||||
).addCallback(lambda rows: {
|
||||
row["tag"]: json.loads(row["content"]) for row in rows
|
||||
})
|
||||
).addCallback(
|
||||
lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
|
||||
)
|
||||
|
||||
|
||||
class TagsStore(TagsWorkerStore):
|
||||
@ -186,14 +183,8 @@ class TagsStore(TagsWorkerStore):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="room_tags",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"room_id": room_id,
|
||||
"tag": tag,
|
||||
},
|
||||
values={
|
||||
"content": content_json,
|
||||
}
|
||||
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
|
||||
values={"content": content_json},
|
||||
)
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
@ -211,6 +202,7 @@ class TagsStore(TagsWorkerStore):
|
||||
Returns:
|
||||
A deferred that completes once the tag has been removed
|
||||
"""
|
||||
|
||||
def remove_tag_txn(txn, next_id):
|
||||
sql = (
|
||||
"DELETE FROM room_tags "
|
||||
@ -238,8 +230,7 @@ class TagsStore(TagsWorkerStore):
|
||||
"""
|
||||
|
||||
txn.call_after(
|
||||
self._account_data_stream_cache.entity_has_changed,
|
||||
user_id, next_id
|
||||
self._account_data_stream_cache.entity_has_changed, user_id, next_id
|
||||
)
|
||||
|
||||
update_max_id_sql = (
|
||||
|
@ -38,16 +38,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TransactionRow = namedtuple(
|
||||
"_TransactionRow", (
|
||||
"id", "transaction_id", "destination", "ts", "response_code",
|
||||
"response_json",
|
||||
)
|
||||
"_TransactionRow",
|
||||
("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
|
||||
)
|
||||
|
||||
_UpdateTransactionRow = namedtuple(
|
||||
"_TransactionRow", (
|
||||
"response_code", "response_json",
|
||||
)
|
||||
"_TransactionRow", ("response_code", "response_json")
|
||||
)
|
||||
|
||||
SENTINEL = object()
|
||||
@ -84,19 +80,22 @@ class TransactionStore(SQLBaseStore):
|
||||
|
||||
return self.runInteraction(
|
||||
"get_received_txn_response",
|
||||
self._get_received_txn_response, transaction_id, origin
|
||||
self._get_received_txn_response,
|
||||
transaction_id,
|
||||
origin,
|
||||
)
|
||||
|
||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
||||
result = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="received_transactions",
|
||||
keyvalues={
|
||||
"transaction_id": transaction_id,
|
||||
"origin": origin,
|
||||
},
|
||||
keyvalues={"transaction_id": transaction_id, "origin": origin},
|
||||
retcols=(
|
||||
"transaction_id", "origin", "ts", "response_code", "response_json",
|
||||
"transaction_id",
|
||||
"origin",
|
||||
"ts",
|
||||
"response_code",
|
||||
"response_json",
|
||||
"has_been_referenced",
|
||||
),
|
||||
allow_none=True,
|
||||
@ -108,8 +107,7 @@ class TransactionStore(SQLBaseStore):
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_received_txn_response(self, transaction_id, origin, code,
|
||||
response_dict):
|
||||
def set_received_txn_response(self, transaction_id, origin, code, response_dict):
|
||||
"""Persist the response we returened for an incoming transaction, and
|
||||
should return for subsequent transactions with the same transaction_id
|
||||
and origin.
|
||||
@ -135,8 +133,7 @@ class TransactionStore(SQLBaseStore):
|
||||
desc="set_received_txn_response",
|
||||
)
|
||||
|
||||
def prep_send_transaction(self, transaction_id, destination,
|
||||
origin_server_ts):
|
||||
def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
|
||||
"""Persists an outgoing transaction and calculates the values for the
|
||||
previous transaction id list.
|
||||
|
||||
@ -182,7 +179,9 @@ class TransactionStore(SQLBaseStore):
|
||||
|
||||
result = yield self.runInteraction(
|
||||
"get_destination_retry_timings",
|
||||
self._get_destination_retry_timings, destination)
|
||||
self._get_destination_retry_timings,
|
||||
destination,
|
||||
)
|
||||
|
||||
# We don't hugely care about race conditions between getting and
|
||||
# invalidating the cache, since we time out fairly quickly anyway.
|
||||
@ -193,9 +192,7 @@ class TransactionStore(SQLBaseStore):
|
||||
result = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="destinations",
|
||||
keyvalues={
|
||||
"destination": destination,
|
||||
},
|
||||
keyvalues={"destination": destination},
|
||||
retcols=("destination", "retry_last_ts", "retry_interval"),
|
||||
allow_none=True,
|
||||
)
|
||||
@ -205,8 +202,7 @@ class TransactionStore(SQLBaseStore):
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_destination_retry_timings(self, destination,
|
||||
retry_last_ts, retry_interval):
|
||||
def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
|
||||
"""Sets the current retry timings for a given destination.
|
||||
Both timings should be zero if retrying is no longer occuring.
|
||||
|
||||
@ -225,8 +221,9 @@ class TransactionStore(SQLBaseStore):
|
||||
retry_interval,
|
||||
)
|
||||
|
||||
def _set_destination_retry_timings(self, txn, destination,
|
||||
retry_last_ts, retry_interval):
|
||||
def _set_destination_retry_timings(
|
||||
self, txn, destination, retry_last_ts, retry_interval
|
||||
):
|
||||
self.database_engine.lock_table(txn, "destinations")
|
||||
|
||||
# We need to be careful here as the data may have changed from under us
|
||||
@ -235,9 +232,7 @@ class TransactionStore(SQLBaseStore):
|
||||
prev_row = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="destinations",
|
||||
keyvalues={
|
||||
"destination": destination,
|
||||
},
|
||||
keyvalues={"destination": destination},
|
||||
retcols=("retry_last_ts", "retry_interval"),
|
||||
allow_none=True,
|
||||
)
|
||||
@ -250,15 +245,13 @@ class TransactionStore(SQLBaseStore):
|
||||
"destination": destination,
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval,
|
||||
}
|
||||
},
|
||||
)
|
||||
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
"destinations",
|
||||
keyvalues={
|
||||
"destination": destination,
|
||||
},
|
||||
keyvalues={"destination": destination},
|
||||
updatevalues={
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval,
|
||||
@ -273,8 +266,7 @@ class TransactionStore(SQLBaseStore):
|
||||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
"get_destinations_needing_retry",
|
||||
self._get_destinations_needing_retry
|
||||
"get_destinations_needing_retry", self._get_destinations_needing_retry
|
||||
)
|
||||
|
||||
def _get_destinations_needing_retry(self, txn):
|
||||
@ -288,7 +280,7 @@ class TransactionStore(SQLBaseStore):
|
||||
|
||||
def _start_cleanup_transactions(self):
|
||||
return run_as_background_process(
|
||||
"cleanup_transactions", self._cleanup_transactions,
|
||||
"cleanup_transactions", self._cleanup_transactions
|
||||
)
|
||||
|
||||
def _cleanup_transactions(self):
|
||||
|
@ -40,9 +40,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
).addCallback(operator.truth)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="is_user_erased",
|
||||
list_name="user_ids",
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
|
||||
)
|
||||
def are_users_erased(self, user_ids):
|
||||
"""
|
||||
@ -61,16 +59,13 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
|
||||
def _get_erased_users(txn):
|
||||
txn.execute(
|
||||
"SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
|
||||
",".join("?" * len(user_ids))
|
||||
),
|
||||
"SELECT user_id FROM erased_users WHERE user_id IN (%s)"
|
||||
% (",".join("?" * len(user_ids))),
|
||||
user_ids,
|
||||
)
|
||||
return set(r[0] for r in txn)
|
||||
|
||||
erased_users = yield self.runInteraction(
|
||||
"are_users_erased", _get_erased_users,
|
||||
)
|
||||
erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
|
||||
res = dict((u, u in erased_users) for u in user_ids)
|
||||
defer.returnValue(res)
|
||||
|
||||
@ -82,22 +77,16 @@ class UserErasureStore(UserErasureWorkerStore):
|
||||
Args:
|
||||
user_id (str): full user_id to be erased
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
# first check if they are already in the list
|
||||
txn.execute(
|
||||
"SELECT 1 FROM erased_users WHERE user_id = ?",
|
||||
(user_id, )
|
||||
)
|
||||
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
|
||||
if txn.fetchone():
|
||||
return
|
||||
|
||||
# they are not already there: do the insert.
|
||||
txn.execute(
|
||||
"INSERT INTO erased_users (user_id) VALUES (?)",
|
||||
(user_id, )
|
||||
)
|
||||
txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,))
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.is_user_erased, (user_id,)
|
||||
)
|
||||
return self.runInteraction("mark_user_erased", f)
|
||||
|
@ -43,9 +43,9 @@ def _load_current_id(db_conn, table, column, step=1):
|
||||
"""
|
||||
cur = db_conn.cursor()
|
||||
if step == 1:
|
||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
|
||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
|
||||
else:
|
||||
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
|
||||
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
|
||||
val, = cur.fetchone()
|
||||
cur.close()
|
||||
current_id = int(val) if val else step
|
||||
@ -77,6 +77,7 @@ class StreamIdGenerator(object):
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
|
||||
assert step != 0
|
||||
self._lock = threading.Lock()
|
||||
@ -84,8 +85,7 @@ class StreamIdGenerator(object):
|
||||
self._current = _load_current_id(db_conn, table, column, step)
|
||||
for table, column in extra_tables:
|
||||
self._current = (max if step > 0 else min)(
|
||||
self._current,
|
||||
_load_current_id(db_conn, table, column, step)
|
||||
self._current, _load_current_id(db_conn, table, column, step)
|
||||
)
|
||||
self._unfinished_ids = deque()
|
||||
|
||||
@ -121,7 +121,7 @@ class StreamIdGenerator(object):
|
||||
next_ids = range(
|
||||
self._current + self._step,
|
||||
self._current + self._step * (n + 1),
|
||||
self._step
|
||||
self._step,
|
||||
)
|
||||
self._current += n * self._step
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user