Run black on the rest of the storage module (#4996)

This commit is contained in:
Amber Brown 2019-04-03 20:07:29 +11:00 committed by Richard van der Hoff
parent 3039d61baf
commit 7efd1d87c2
42 changed files with 2129 additions and 2453 deletions

1
changelog.d/4996.misc Normal file
View File

@ -0,0 +1 @@
Run `black` on the remainder of `synapse/storage/`.

View File

@ -61,10 +61,18 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DataStore(RoomMemberStore, RoomStore, class DataStore(
RegistrationStore, StreamStore, ProfileStore, RoomMemberStore,
PresenceStore, TransactionStore, RoomStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, RegistrationStore,
StreamStore,
ProfileStore,
PresenceStore,
TransactionStore,
DirectoryStore,
KeyStore,
StateStore,
SignatureStore,
ApplicationServiceStore, ApplicationServiceStore,
EventsStore, EventsStore,
EventFederationStore, EventFederationStore,
@ -90,19 +98,23 @@ class DataStore(RoomMemberStore, RoomStore,
UserErasureStore, UserErasureStore,
MonthlyActiveUsersStore, MonthlyActiveUsersStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", db_conn,
extra_tables=[("local_invites", "stream_id")] "events",
"stream_ordering",
extra_tables=[("local_invites", "stream_id")],
) )
self._backfill_id_gen = StreamIdGenerator( self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1, db_conn,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")] "events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
) )
self._presence_id_gen = StreamIdGenerator( self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id" db_conn, "presence_stream", "stream_id"
@ -114,7 +126,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"
) )
self._device_list_id_gen = StreamIdGenerator( self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id", db_conn, "device_lists_stream", "stream_id"
) )
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
@ -125,16 +137,15 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
) )
self._pushers_id_gen = StreamIdGenerator( self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
extra_tables=[("deleted_pushers", "stream_id")],
) )
self._group_updates_id_gen = StreamIdGenerator( self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id", db_conn, "local_group_updates", "stream_id"
) )
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator( self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id", db_conn, "cache_invalidation_stream", "stream_id"
) )
else: else:
self._cache_id_gen = None self._cache_id_gen = None
@ -142,72 +153,82 @@ class DataStore(RoomMemberStore, RoomStore,
self._presence_on_startup = self._get_active_presence(db_conn) self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict( presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream", db_conn,
"presence_stream",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=self._presence_id_gen.get_current_token(), max_value=self._presence_id_gen.get_current_token(),
) )
self.presence_stream_cache = StreamChangeCache( self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val, "PresenceStreamChangeCache",
prefilled_cache=presence_cache_prefill min_presence_val,
prefilled_cache=presence_cache_prefill,
) )
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox", db_conn,
"device_inbox",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=max_device_inbox_id, max_value=max_device_inbox_id,
limit=1000, limit=1000,
) )
self._device_inbox_stream_cache = StreamChangeCache( self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id, "DeviceInboxStreamChangeCache",
min_device_inbox_id,
prefilled_cache=device_inbox_prefill, prefilled_cache=device_inbox_prefill,
) )
# The federation outbox and the local device inbox uses the same # The federation outbox and the local device inbox uses the same
# stream_id generator. # stream_id generator.
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
db_conn, "device_federation_outbox", db_conn,
"device_federation_outbox",
entity_column="destination", entity_column="destination",
stream_column="stream_id", stream_column="stream_id",
max_value=max_device_inbox_id, max_value=max_device_inbox_id,
limit=1000, limit=1000,
) )
self._device_federation_outbox_stream_cache = StreamChangeCache( self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, "DeviceFederationOutboxStreamChangeCache",
min_device_outbox_id,
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
device_list_max = self._device_list_id_gen.get_current_token() device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache( self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max, "DeviceListStreamChangeCache", device_list_max
) )
self._device_list_federation_stream_cache = StreamChangeCache( self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max, "DeviceListFederationStreamChangeCache", device_list_max
) )
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream", db_conn,
"current_state_delta_stream",
entity_column="room_id", entity_column="room_id",
stream_column="stream_id", stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token max_value=events_max, # As we share the stream id with events token
limit=1000, limit=1000,
) )
self._curr_state_delta_stream_cache = StreamChangeCache( self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id, "_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill, prefilled_cache=curr_state_delta_prefill,
) )
_group_updates_prefill, min_group_updates_id = self._get_cache_dict( _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
db_conn, "local_group_updates", db_conn,
"local_group_updates",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(), max_value=self._group_updates_id_gen.get_current_token(),
limit=1000, limit=1000,
) )
self._group_updates_stream_cache = StreamChangeCache( self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", min_group_updates_id, "_group_updates_stream_cache",
min_group_updates_id,
prefilled_cache=_group_updates_prefill, prefilled_cache=_group_updates_prefill,
) )
@ -250,6 +271,7 @@ class DataStore(RoomMemberStore, RoomStore,
""" """
Counts the number of users who used this homeserver in the last 24 hours. Counts the number of users who used this homeserver in the last 24 hours.
""" """
def _count_users(txn): def _count_users(txn):
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
@ -277,6 +299,7 @@ class DataStore(RoomMemberStore, RoomStore,
Returns counts globaly for a given user as well as breaking Returns counts globaly for a given user as well as breaking
by platform by platform
""" """
def _count_r30_users(txn): def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30 thirty_days_in_secs = 86400 * 30
now = int(self._clock.time()) now = int(self._clock.time())
@ -313,8 +336,7 @@ class DataStore(RoomMemberStore, RoomStore,
""" """
results = {} results = {}
txn.execute(sql, (thirty_days_ago_in_secs, txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
thirty_days_ago_in_secs))
for row in txn: for row in txn:
if row[0] == 'unknown': if row[0] == 'unknown':
@ -341,8 +363,7 @@ class DataStore(RoomMemberStore, RoomStore,
) u ) u
""" """
txn.execute(sql, (thirty_days_ago_in_secs, txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
thirty_days_ago_in_secs))
count, = txn.fetchone() count, = txn.fetchone()
results['all'] = count results['all'] = count
@ -356,15 +377,14 @@ class DataStore(RoomMemberStore, RoomStore,
Returns millisecond unixtime for start of UTC day. Returns millisecond unixtime for start of UTC day.
""" """
now = time.gmtime() now = time.gmtime()
today_start = calendar.timegm(( today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0,
))
return today_start * 1000 return today_start * 1000
def generate_user_daily_visits(self): def generate_user_daily_visits(self):
""" """
Generates daily visit data for use in cohort/ retention analysis Generates daily visit data for use in cohort/ retention analysis
""" """
def _generate_user_daily_visits(txn): def _generate_user_daily_visits(txn):
logger.info("Calling _generate_user_daily_visits") logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day() today_start = self._get_start_of_day()
@ -395,25 +415,29 @@ class DataStore(RoomMemberStore, RoomStore,
# often to minimise this case. # often to minimise this case.
if today_start > self._last_user_visit_update: if today_start > self._last_user_visit_update:
yesterday_start = today_start - a_day_in_milliseconds yesterday_start = today_start - a_day_in_milliseconds
txn.execute(sql, ( txn.execute(
yesterday_start, yesterday_start, sql,
self._last_user_visit_update, today_start (
)) yesterday_start,
yesterday_start,
self._last_user_visit_update,
today_start,
),
)
self._last_user_visit_update = today_start self._last_user_visit_update = today_start
txn.execute(sql, ( txn.execute(
today_start, today_start, sql, (today_start, today_start, self._last_user_visit_update, now)
self._last_user_visit_update, )
now
))
# Update _last_user_visit_update to now. The reason to do this # Update _last_user_visit_update to now. The reason to do this
# rather just clamping to the beginning of the day is to limit # rather just clamping to the beginning of the day is to limit
# the size of the join - meaning that the query can be run more # the size of the join - meaning that the query can be run more
# frequently # frequently
self._last_user_visit_update = now self._last_user_visit_update = now
return self.runInteraction("generate_user_daily_visits", return self.runInteraction(
_generate_user_daily_visits) "generate_user_daily_visits", _generate_user_daily_visits
)
def get_users(self): def get_users(self):
"""Function to reterive a list of users in users table. """Function to reterive a list of users in users table.
@ -425,12 +449,7 @@ class DataStore(RoomMemberStore, RoomStore,
return self._simple_select_list( return self._simple_select_list(
table="users", table="users",
keyvalues={}, keyvalues={},
retcols=[ retcols=["name", "password_hash", "is_guest", "admin"],
"name",
"password_hash",
"is_guest",
"admin"
],
desc="get_users", desc="get_users",
) )
@ -451,20 +470,9 @@ class DataStore(RoomMemberStore, RoomStore,
i_limit = (int)(limit) i_limit = (int)(limit)
return self.get_user_list_paginate( return self.get_user_list_paginate(
table="users", table="users",
keyvalues={ keyvalues={"is_guest": is_guest},
"is_guest": is_guest pagevalues=[order, i_limit, i_start],
}, retcols=["name", "password_hash", "is_guest", "admin"],
pagevalues=[
order,
i_limit,
i_start
],
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
desc="get_users_paginate", desc="get_users_paginate",
) )
@ -482,12 +490,7 @@ class DataStore(RoomMemberStore, RoomStore,
table="users", table="users",
term=term, term=term,
col="name", col="name",
retcols=[ retcols=["name", "password_hash", "is_guest", "admin"],
"name",
"password_hash",
"is_guest",
"admin"
],
desc="search_users", desc="search_users",
) )

View File

@ -76,12 +76,18 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = [ __slots__ = [
"txn", "name", "database_engine", "after_callbacks", "exception_callbacks", "txn",
"name",
"database_engine",
"after_callbacks",
"exception_callbacks",
] ]
def __init__(self, txn, name, database_engine, after_callbacks, def __init__(
exception_callbacks): self, txn, name, database_engine, after_callbacks, exception_callbacks
):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
@ -110,6 +116,7 @@ class LoggingTransaction(object):
def execute_batch(self, sql, args): def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else: else:
for val in args: for val in args:
@ -134,10 +141,7 @@ class LoggingTransaction(object):
sql = self.database_engine.convert_param_style(sql) sql = self.database_engine.convert_param_style(sql)
if args: if args:
try: try:
sql_logger.debug( sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
"[SQL values] {%s} %r",
self.name, args[0]
)
except Exception: except Exception:
# Don't let logging failures stop SQL from working # Don't let logging failures stop SQL from working
pass pass
@ -145,9 +149,7 @@ class LoggingTransaction(object):
start = time.time() start = time.time()
try: try:
return func( return func(sql, *args)
sql, *args
)
except Exception as e: except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e) logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise raise
@ -176,11 +178,9 @@ class PerformanceCounters(object):
counters = [] counters = []
for name, (count, cum_time) in iteritems(self.current_counters): for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(( counters.append(
(cum_time - prev_time) / interval_duration, ((cum_time - prev_time) / interval_duration, count - prev_count, name)
count - prev_count, )
name
))
self.previous_counters = dict(self.current_counters) self.previous_counters = dict(self.current_counters)
@ -212,8 +212,9 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters() self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3, self._get_event_cache = Cache(
max_entries=hs.config.event_cache_size) "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
)
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
@ -239,7 +240,7 @@ class SQLBaseStore(object):
0.0, 0.0,
run_as_background_process, run_as_background_process,
"upsert_safety_check", "upsert_safety_check",
self._check_safe_to_upsert self._check_safe_to_upsert,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -271,7 +272,7 @@ class SQLBaseStore(object):
15.0, 15.0,
run_as_background_process, run_as_background_process,
"upsert_safety_check", "upsert_safety_check",
self._check_safe_to_upsert self._check_safe_to_upsert,
) )
def start_profiling(self): def start_profiling(self):
@ -298,13 +299,16 @@ class SQLBaseStore(object):
perf_logger.info( perf_logger.info(
"Total database time: %.3f%% {%s} {%s}", "Total database time: %.3f%% {%s} {%s}",
ratio * 100, top_three_counters, top_3_event_counters ratio * 100,
top_three_counters,
top_3_event_counters,
) )
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, def _new_transaction(
func, *args, **kwargs): self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
start = time.time() start = time.time()
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -312,7 +316,7 @@ class SQLBaseStore(object):
# growing really large. # growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, ) name = "%s-%x" % (desc, txn_id)
transaction_logger.debug("[TXN START] {%s}", name) transaction_logger.debug("[TXN START] {%s}", name)
@ -323,7 +327,10 @@ class SQLBaseStore(object):
try: try:
txn = conn.cursor() txn = conn.cursor()
txn = LoggingTransaction( txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks, txn,
name,
self.database_engine,
after_callbacks,
exception_callbacks, exception_callbacks,
) )
r = func(txn, *args, **kwargs) r = func(txn, *args, **kwargs)
@ -334,7 +341,10 @@ class SQLBaseStore(object):
# transaction. # transaction.
logger.warning( logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", "[TXN OPERROR] {%s} %s %d/%d",
name, exception_to_unicode(e), i, N name,
exception_to_unicode(e),
i,
N,
) )
if i < N: if i < N:
i += 1 i += 1
@ -342,8 +352,7 @@ class SQLBaseStore(object):
conn.rollback() conn.rollback()
except self.database_engine.module.Error as e1: except self.database_engine.module.Error as e1:
logger.warning( logger.warning(
"[TXN EROLL] {%s} %s", "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
name, exception_to_unicode(e1),
) )
continue continue
raise raise
@ -357,7 +366,8 @@ class SQLBaseStore(object):
except self.database_engine.module.Error as e1: except self.database_engine.module.Error as e1:
logger.warning( logger.warning(
"[TXN EROLL] {%s} %s", "[TXN EROLL] {%s} %s",
name, exception_to_unicode(e1), name,
exception_to_unicode(e1),
) )
continue continue
raise raise
@ -396,16 +406,17 @@ class SQLBaseStore(object):
exception_callbacks = [] exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel: if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warn( logger.warn("Starting db txn '%s' from sentinel context", desc)
"Starting db txn '%s' from sentinel context",
desc,
)
try: try:
result = yield self.runWithConnection( result = yield self.runWithConnection(
self._new_transaction, self._new_transaction,
desc, after_callbacks, exception_callbacks, func, desc,
*args, **kwargs after_callbacks,
exception_callbacks,
func,
*args,
**kwargs
) )
for after_callback, after_args, after_kwargs in after_callbacks: for after_callback, after_args, after_kwargs in after_callbacks:
@ -434,7 +445,7 @@ class SQLBaseStore(object):
parent_context = LoggingContext.current_context() parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel: if parent_context == LoggingContext.sentinel:
logger.warn( logger.warn(
"Starting db connection from sentinel context: metrics will be lost", "Starting db connection from sentinel context: metrics will be lost"
) )
parent_context = None parent_context = None
@ -453,9 +464,7 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs) return func(conn, *args, **kwargs)
with PreserveLoggingContext(): with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
inner_func, *args, **kwargs
)
defer.returnValue(result) defer.returnValue(result)
@ -469,9 +478,7 @@ class SQLBaseStore(object):
A list of dicts where the key is the column header. A list of dicts where the key is the column header.
""" """
col_headers = list(intern(str(column[0])) for column in cursor.description) col_headers = list(intern(str(column[0])) for column in cursor.description)
results = list( results = list(dict(zip(col_headers, row)) for row in cursor)
dict(zip(col_headers, row)) for row in cursor
)
return results return results
def _execute(self, desc, decoder, query, *args): def _execute(self, desc, decoder, query, *args):
@ -485,6 +492,7 @@ class SQLBaseStore(object):
Returns: Returns:
The result of decoder(results) The result of decoder(results)
""" """
def interaction(txn): def interaction(txn):
txn.execute(query, args) txn.execute(query, args)
if decoder: if decoder:
@ -498,8 +506,7 @@ class SQLBaseStore(object):
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks @defer.inlineCallbacks
def _simple_insert(self, table, values, or_ignore=False, def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
desc="_simple_insert"):
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
@ -511,10 +518,7 @@ class SQLBaseStore(object):
`or_ignore` is True `or_ignore` is True
""" """
try: try:
yield self.runInteraction( yield self.runInteraction(desc, self._simple_insert_txn, table, values)
desc,
self._simple_insert_txn, table, values,
)
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse # We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db. # a cursor after we receive an error from the db.
@ -530,15 +534,13 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES(%s)" % ( sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table, table,
", ".join(k for k in keys), ", ".join(k for k in keys),
", ".join("?" for _ in keys) ", ".join("?" for _ in keys),
) )
txn.execute(sql, vals) txn.execute(sql, vals)
def _simple_insert_many(self, table, values, desc): def _simple_insert_many(self, table, values, desc):
return self.runInteraction( return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
desc, self._simple_insert_many_txn, table, values
)
@staticmethod @staticmethod
def _simple_insert_many_txn(txn, table, values): def _simple_insert_many_txn(txn, table, values):
@ -553,24 +555,18 @@ class SQLBaseStore(object):
# #
# The sort is to ensure that we don't rely on dictionary iteration # The sort is to ensure that we don't rely on dictionary iteration
# order. # order.
keys, vals = zip(*[ keys, vals = zip(
zip( *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
*(sorted(i.items(), key=lambda kv: kv[0]))
) )
for i in values
if i
])
for k in keys: for k in keys:
if k != keys[0]: if k != keys[0]:
raise RuntimeError( raise RuntimeError("All items must have the same keys")
"All items must have the same keys"
)
sql = "INSERT INTO %s (%s) VALUES(%s)" % ( sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table, table,
", ".join(k for k in keys[0]), ", ".join(k for k in keys[0]),
", ".join("?" for _ in keys[0]) ", ".join("?" for _ in keys[0]),
) )
txn.executemany(sql, vals) txn.executemany(sql, vals)
@ -583,7 +579,7 @@ class SQLBaseStore(object):
values, values,
insertion_values={}, insertion_values={},
desc="_simple_upsert", desc="_simple_upsert",
lock=True lock=True,
): ):
""" """
@ -635,13 +631,7 @@ class SQLBaseStore(object):
) )
def _simple_upsert_txn( def _simple_upsert_txn(
self, self, txn, table, keyvalues, values, insertion_values={}, lock=True
txn,
table,
keyvalues,
values,
insertion_values={},
lock=True,
): ):
""" """
Pick the UPSERT method which works best on the platform. Either the Pick the UPSERT method which works best on the platform. Either the
@ -665,11 +655,7 @@ class SQLBaseStore(object):
and table not in self._unsafe_to_upsert_tables and table not in self._unsafe_to_upsert_tables
): ):
return self._simple_upsert_txn_native_upsert( return self._simple_upsert_txn_native_upsert(
txn, txn, table, keyvalues, values, insertion_values=insertion_values
table,
keyvalues,
values,
insertion_values=insertion_values,
) )
else: else:
return self._simple_upsert_txn_emulated( return self._simple_upsert_txn_emulated(
@ -714,7 +700,7 @@ class SQLBaseStore(object):
# SELECT instead to see if it exists. # SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % ( sql = "SELECT 1 FROM %s WHERE %s" % (
table, table,
" AND ".join(_getwhere(k) for k in keyvalues) " AND ".join(_getwhere(k) for k in keyvalues),
) )
sqlargs = list(keyvalues.values()) sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs) txn.execute(sql, sqlargs)
@ -726,7 +712,7 @@ class SQLBaseStore(object):
sql = "UPDATE %s SET %s WHERE %s" % ( sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k,) for k in values), ", ".join("%s = ?" % (k,) for k in values),
" AND ".join(_getwhere(k) for k in keyvalues) " AND ".join(_getwhere(k) for k in keyvalues),
) )
sqlargs = list(values.values()) + list(keyvalues.values()) sqlargs = list(values.values()) + list(keyvalues.values())
@ -773,19 +759,14 @@ class SQLBaseStore(object):
latter = "NOTHING" latter = "NOTHING"
else: else:
allvalues.update(values) allvalues.update(values)
latter = ( latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
)
sql = ( sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
"INSERT INTO %s (%s) VALUES (%s) "
"ON CONFLICT (%s) DO %s"
) % (
table, table,
", ".join(k for k in allvalues), ", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues), ", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues), ", ".join(k for k in keyvalues),
latter latter,
) )
txn.execute(sql, list(allvalues.values())) txn.execute(sql, list(allvalues.values()))
@ -870,8 +851,8 @@ class SQLBaseStore(object):
latter = "NOTHING" latter = "NOTHING"
value_values = [() for x in range(len(key_values))] value_values = [() for x in range(len(key_values))]
else: else:
latter = ( latter = "UPDATE SET " + ", ".join(
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names) k + "=EXCLUDED." + k for k in value_names
) )
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
@ -889,8 +870,9 @@ class SQLBaseStore(object):
return txn.execute_batch(sql, args) return txn.execute_batch(sql, args)
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(
allow_none=False, desc="_simple_select_one"): self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it. return a single row, returning multiple columns from it.
@ -903,14 +885,17 @@ class SQLBaseStore(object):
statement returns no rows statement returns no rows
""" """
return self.runInteraction( return self.runInteraction(
desc, desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
self._simple_select_one_txn,
table, keyvalues, retcols, allow_none,
) )
def _simple_select_one_onecol(self, table, keyvalues, retcol, def _simple_select_one_onecol(
self,
table,
keyvalues,
retcol,
allow_none=False, allow_none=False,
desc="_simple_select_one_onecol"): desc="_simple_select_one_onecol",
):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it. return a single row, returning a single column from it.
@ -922,17 +907,18 @@ class SQLBaseStore(object):
return self.runInteraction( return self.runInteraction(
desc, desc,
self._simple_select_one_onecol_txn, self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none, table,
keyvalues,
retcol,
allow_none=allow_none,
) )
@classmethod @classmethod
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol, def _simple_select_one_onecol_txn(
allow_none=False): cls, txn, table, keyvalues, retcol, allow_none=False
):
ret = cls._simple_select_onecol_txn( ret = cls._simple_select_onecol_txn(
txn, txn, table=table, keyvalues=keyvalues, retcol=retcol
table=table,
keyvalues=keyvalues,
retcol=retcol,
) )
if ret: if ret:
@ -945,12 +931,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol): def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = ( sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
"SELECT %(retcol)s FROM %(table)s"
) % {
"retcol": retcol,
"table": table,
}
if keyvalues: if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
@ -960,8 +941,9 @@ class SQLBaseStore(object):
return [r[0] for r in txn] return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol, def _simple_select_onecol(
desc="_simple_select_onecol"): self, table, keyvalues, retcol, desc="_simple_select_onecol"
):
"""Executes a SELECT query on the named table, which returns a list """Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows. comprising of the values of the named column from the selected rows.
@ -974,13 +956,12 @@ class SQLBaseStore(object):
Deferred: Results in a list Deferred: Results in a list
""" """
return self.runInteraction( return self.runInteraction(
desc, desc, self._simple_select_onecol_txn, table, keyvalues, retcol
self._simple_select_onecol_txn,
table, keyvalues, retcol
) )
def _simple_select_list(self, table, keyvalues, retcols, def _simple_select_list(
desc="_simple_select_list"): self, table, keyvalues, retcols, desc="_simple_select_list"
):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -994,9 +975,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
return self.runInteraction( return self.runInteraction(
desc, desc, self._simple_select_list_txn, table, keyvalues, retcols
self._simple_select_list_txn,
table, keyvalues, retcols
) )
@classmethod @classmethod
@ -1016,22 +995,26 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % ( sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues),
) )
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
else: else:
sql = "SELECT %s FROM %s" % ( sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
", ".join(retcols),
table
)
txn.execute(sql) txn.execute(sql)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks @defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols, def _simple_select_many_batch(
keyvalues={}, desc="_simple_select_many_batch", self,
batch_size=100): table,
column,
iterable,
retcols,
keyvalues={},
desc="_simple_select_many_batch",
batch_size=100,
):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
it_list = list(iterable) it_list = list(iterable)
chunks = [ chunks = [
it_list[i:i + batch_size] it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
for i in range(0, len(it_list), batch_size)
] ]
for chunk in chunks: for chunk in chunks:
rows = yield self.runInteraction( rows = yield self.runInteraction(
desc, desc,
self._simple_select_many_txn, self._simple_select_many_txn,
table, column, chunk, keyvalues, retcols table,
column,
chunk,
keyvalues,
retcols,
) )
results.extend(rows) results.extend(rows)
@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
clauses = [] clauses = []
values = [] values = []
clauses.append( clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
)
values.extend(iterable) values.extend(iterable)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
values.append(value) values.append(value)
if clauses: if clauses:
sql = "%s WHERE %s" % ( sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
sql,
" AND ".join(clauses),
)
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def _simple_update(self, table, keyvalues, updatevalues, desc): def _simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction( return self.runInteraction(
desc, desc, self._simple_update_txn, table, keyvalues, updatevalues
self._simple_update_txn,
table, keyvalues, updatevalues,
) )
@staticmethod @staticmethod
@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
where, where,
) )
txn.execute( txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
update_sql,
list(updatevalues.values()) + list(keyvalues.values())
)
return txn.rowcount return txn.rowcount
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(
desc="_simple_update_one"): self, table, keyvalues, updatevalues, desc="_simple_update_one"
):
"""Executes an UPDATE query on the named table, setting new values for """Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values. columns in a row matching the key values.
@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
the update column in the 'keyvalues' dict as well. the update column in the 'keyvalues' dict as well.
""" """
return self.runInteraction( return self.runInteraction(
desc, desc, self._simple_update_one_txn, table, keyvalues, updatevalues
self._simple_update_one_txn,
table, keyvalues, updatevalues,
) )
@classmethod @classmethod
@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,)) raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod @staticmethod
def _simple_select_one_txn(txn, table, keyvalues, retcols, def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % ( select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k,) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues),
) )
txn.execute(select_sql, list(keyvalues.values())) txn.execute(select_sql, list(keyvalues.values()))
@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
table : string giving the table name table : string giving the table name
keyvalues : dict of column names and values to select the row with keyvalues : dict of column names and values to select the row with
""" """
return self.runInteraction( return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
desc, self._simple_delete_one_txn, table, keyvalues
)
@staticmethod @staticmethod
def _simple_delete_one_txn(txn, table, keyvalues): def _simple_delete_one_txn(txn, table, keyvalues):
@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
""" """
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues),
) )
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,)) raise StoreError(500, "More than one row matched (%s)" % (table,))
def _simple_delete(self, table, keyvalues, desc): def _simple_delete(self, table, keyvalues, desc):
return self.runInteraction( return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
desc, self._simple_delete_txn, table, keyvalues
)
@staticmethod @staticmethod
def _simple_delete_txn(txn, table, keyvalues): def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues),
) )
return txn.execute(sql, list(keyvalues.values())) return txn.execute(sql, list(keyvalues.values()))
@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
clauses = [] clauses = []
values = [] values = []
clauses.append( clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
)
values.extend(iterable) values.extend(iterable)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
values.append(value) values.append(value)
if clauses: if clauses:
sql = "%s WHERE %s" % ( sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
sql,
" AND ".join(clauses),
)
return txn.execute(sql, values) return txn.execute(sql, values)
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, def _get_cache_dict(
max_value, limit=100000): self, db_conn, table, entity_column, stream_column, max_value, limit=100000
):
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will # It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache. # do the right thing to ensure it respects the max size of cache.
@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
cache = { cache = {row[0]: int(row[1]) for row in txn}
row[0]: int(row[1])
for row in txn
}
txn.close() txn.close()
@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
# be safe. # be safe.
for chunk in batch_iter(members_changed, 50): for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk) keys = itertools.chain([room_id], chunk)
self._send_invalidation_to_replication( self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
txn, _CURRENT_STATE_CACHE_NAME, keys,
)
def _invalidate_state_caches(self, room_id, members_changed): def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does """Invalidates caches that are based on the current state, but does
@ -1356,22 +1317,12 @@ class SQLBaseStore(object):
changed changed
""" """
for host in set(get_domain_from_id(u) for u in members_changed): for host in set(get_domain_from_id(u) for u in members_changed):
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
"is_host_joined", (room_id, host,), self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
)
self._attempt_to_invalidate_cache(
"was_host_joined", (room_id, host,),
)
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
"get_users_in_room", (room_id,), self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
) self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
self._attempt_to_invalidate_cache(
"get_room_summary", (room_id,),
)
self._attempt_to_invalidate_cache(
"get_current_state_ids", (room_id,),
)
def _attempt_to_invalidate_cache(self, cache_name, key): def _attempt_to_invalidate_cache(self, cache_name, key):
"""Attempts to invalidate the cache of the given name, ignoring if the """Attempts to invalidate the cache of the given name, ignoring if the
@ -1419,7 +1370,7 @@ class SQLBaseStore(object):
"cache_func": cache_name, "cache_func": cache_name,
"keys": list(keys), "keys": list(keys),
"invalidation_ts": self.clock.time_msec(), "invalidation_ts": self.clock.time_msec(),
} },
) )
def get_all_updated_caches(self, last_id, current_id, limit): def get_all_updated_caches(self, last_id, current_id, limit):
@ -1435,11 +1386,10 @@ class SQLBaseStore(object):
" FROM cache_invalidation_stream" " FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
) )
txn.execute(sql, (last_id, limit,)) txn.execute(sql, (last_id, limit))
return txn.fetchall() return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
)
def get_cache_stream_token(self): def get_cache_stream_token(self):
if self._cache_id_gen: if self._cache_id_gen:
@ -1447,8 +1397,9 @@ class SQLBaseStore(object):
else: else:
return 0 return 0
def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols, def _simple_select_list_paginate(
desc="_simple_select_list_paginate"): self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
):
"""Executes a SELECT query on the named table with start and limit, """Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts. returning the result as a list of dicts.
@ -1468,11 +1419,16 @@ class SQLBaseStore(object):
return self.runInteraction( return self.runInteraction(
desc, desc,
self._simple_select_list_paginate_txn, self._simple_select_list_paginate_txn,
table, keyvalues, pagevalues, retcols table,
keyvalues,
pagevalues,
retcols,
) )
@classmethod @classmethod
def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols): def _simple_select_list_paginate_txn(
cls, txn, table, keyvalues, pagevalues, retcols
):
"""Executes a SELECT query on the named table with start and limit, """Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts. returning the result as a list of dicts.
@ -1497,22 +1453,23 @@ class SQLBaseStore(object):
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k,) for k in keyvalues), " AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?" " ? ASC LIMIT ? OFFSET ?",
) )
txn.execute(sql, list(keyvalues.values()) + list(pagevalues)) txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else: else:
sql = "SELECT %s FROM %s ORDER BY %s" % ( sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" ? ASC LIMIT ? OFFSET ?" " ? ASC LIMIT ? OFFSET ?",
) )
txn.execute(sql, pagevalues) txn.execute(sql, pagevalues)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols, def get_user_list_paginate(
desc="get_user_list_paginate"): self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
):
"""Get a list of users from start row to a limit number of rows. This will """Get a list of users from start row to a limit number of rows. This will
return a json object with users and total number of users in users list. return a json object with users and total number of users in users list.
@ -1532,16 +1489,13 @@ class SQLBaseStore(object):
users = yield self.runInteraction( users = yield self.runInteraction(
desc, desc,
self._simple_select_list_paginate_txn, self._simple_select_list_paginate_txn,
table, keyvalues, pagevalues, retcols table,
keyvalues,
pagevalues,
retcols,
) )
count = yield self.runInteraction( count = yield self.runInteraction(desc, self.get_user_count_txn)
desc, retval = {"users": users, "total": count}
self.get_user_count_txn
)
retval = {
"users": users,
"total": count
}
defer.returnValue(retval) defer.returnValue(retval)
def get_user_count_txn(self, txn): def get_user_count_txn(self, txn):
@ -1556,8 +1510,9 @@ class SQLBaseStore(object):
txn.execute(sql_count) txn.execute(sql_count)
return txn.fetchone()[0] return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols, def _simple_search_list(
desc="_simple_search_list"): self, table, term, col, retcols, desc="_simple_search_list"
):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -1572,9 +1527,7 @@ class SQLBaseStore(object):
""" """
return self.runInteraction( return self.runInteraction(
desc, desc, self._simple_search_list_txn, table, term, col, retcols
self._simple_search_list_txn,
table, term, col, retcols
) )
@classmethod @classmethod
@ -1593,11 +1546,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]] or None defer.Deferred: resolves to list[dict[str, Any]] or None
""" """
if term: if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % ( sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
", ".join(retcols),
table,
col
)
termvalues = ["%%" + term + "%%"] termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues) txn.execute(sql, termvalues)
else: else:
@ -1618,6 +1567,7 @@ class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying
something went wrong. something went wrong.
""" """
pass pass

View File

@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
account_max = self.get_max_account_data_stream_id() account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max, "AccountDataAndTagsChangeCache", account_max
) )
super(AccountDataWorkerStore, self).__init__(db_conn, hs) super(AccountDataWorkerStore, self).__init__(db_conn, hs)
@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore):
def get_account_data_for_user_txn(txn): def get_account_data_for_user_txn(txn):
rows = self._simple_select_list_txn( rows = self._simple_select_list_txn(
txn, "account_data", {"user_id": user_id}, txn,
["account_data_type", "content"] "account_data",
{"user_id": user_id},
["account_data_type", "content"],
) )
global_account_data = { global_account_data = {
@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore):
} }
rows = self._simple_select_list_txn( rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id}, txn,
["room_id", "account_data_type", "content"] "room_account_data",
{"user_id": user_id},
["room_id", "account_data_type", "content"],
) )
by_room = {} by_room = {}
@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
result = yield self._simple_select_one_onecol( result = yield self._simple_select_one_onecol(
table="account_data", table="account_data",
keyvalues={ keyvalues={"user_id": user_id, "account_data_type": data_type},
"user_id": user_id,
"account_data_type": data_type,
},
retcol="content", retcol="content",
desc="get_global_account_data_by_type_for_user", desc="get_global_account_data_by_type_for_user",
allow_none=True, allow_none=True,
@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns: Returns:
A deferred dict of the room account_data A deferred dict of the room account_data
""" """
def get_account_data_for_room_txn(txn): def get_account_data_for_room_txn(txn):
rows = self._simple_select_list_txn( rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, txn,
["account_data_type", "content"] "room_account_data",
{"user_id": user_id, "room_id": room_id},
["account_data_type", "content"],
) )
return { return {
@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore):
A deferred of the room account_data for that type, or None if A deferred of the room account_data for that type, or None if
there isn't any set. there isn't any set.
""" """
def get_account_data_for_room_and_type_txn(txn): def get_account_data_for_room_and_type_txn(txn):
content_json = self._simple_select_one_onecol_txn( content_json = self._simple_select_one_onecol_txn(
txn, txn,
@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore):
"account_data_type": account_data_type, "account_data_type": account_data_type,
}, },
retcol="content", retcol="content",
allow_none=True allow_none=True,
) )
return json.loads(content_json) if content_json else None return json.loads(content_json) if content_json else None
return self.runInteraction( return self.runInteraction(
"get_account_data_for_room_and_type", "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
get_account_data_for_room_and_type_txn,
) )
def get_all_updated_account_data(self, last_global_id, last_room_id, def get_all_updated_account_data(
current_id, limit): self, last_global_id, last_room_id, current_id, limit
):
"""Get all the client account_data that has changed on the server """Get all the client account_data that has changed on the server
Args: Args:
last_global_id(int): The position to fetch from for top level data last_global_id(int): The position to fetch from for top level data
@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_room_id, current_id, limit)) txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall() room_results = txn.fetchall()
return (global_results, room_results) return (global_results, room_results)
return self.runInteraction( return self.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn "get_all_updated_account_data_txn", get_updated_account_data_txn
) )
@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
global_account_data = { global_account_data = {row[0]: json.loads(row[1]) for row in txn}
row[0]: json.loads(row[1]) for row in txn
}
sql = ( sql = (
"SELECT room_id, account_data_type, content FROM room_account_data" "SELECT room_id, account_data_type, content FROM room_account_data"
@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user( ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id, "m.ignored_user_list",
ignorer_user_id,
on_invalidate=cache_context.invalidate, on_invalidate=cache_context.invalidate,
) )
if not ignored_account_data: if not ignored_account_data:
@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore):
"room_id": room_id, "room_id": room_id,
"account_data_type": account_data_type, "account_data_type": account_data_type,
}, },
values={ values={"stream_id": next_id, "content": content_json},
"stream_id": next_id,
"content": content_json,
},
lock=False, lock=False,
) )
@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id,)) self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill( self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type,), content, (user_id, room_id, account_data_type), content
) )
result = self._account_data_id_gen.get_current_token() result = self._account_data_id_gen.get_current_token()
@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore):
yield self._simple_upsert( yield self._simple_upsert(
desc="add_user_account_data", desc="add_user_account_data",
table="account_data", table="account_data",
keyvalues={ keyvalues={"user_id": user_id, "account_data_type": account_data_type},
"user_id": user_id, values={"stream_id": next_id, "content": content_json},
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
},
lock=False, lock=False,
) )
@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore):
# transaction. # transaction.
yield self._update_max_stream_id(next_id) yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed( self._account_data_stream_cache.entity_has_changed(user_id, next_id)
user_id, next_id,
)
self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate( self.get_global_account_data_by_type_for_user.invalidate(
(account_data_type, user_id,) (account_data_type, user_id)
) )
result = self._account_data_id_gen.get_current_token() result = self._account_data_id_gen.get_current_token()
@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore):
Args: Args:
next_id(int): The the revision to advance to. next_id(int): The the revision to advance to.
""" """
def _update(txn): def _update(txn):
update_max_id_sql = ( update_max_id_sql = (
"UPDATE account_data_max_stream_id" "UPDATE account_data_max_stream_id"
@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore):
" WHERE stream_id < ?" " WHERE stream_id < ?"
) )
txn.execute(update_max_id_sql, (next_id, next_id)) txn.execute(update_max_id_sql, (next_id, next_id))
return self.runInteraction(
"update_account_data_max_stream_id", return self.runInteraction("update_account_data_max_stream_id", _update)
_update,
)

View File

@ -51,8 +51,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore): class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self.services_cache = load_appservices( self.services_cache = load_appservices(
hs.hostname, hs.hostname, hs.config.app_service_config_files
hs.config.app_service_config_files
) )
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
@ -122,8 +121,9 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
pass pass
class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, class ApplicationServiceTransactionWorkerStore(
EventsWorkerStore): ApplicationServiceWorkerStore, EventsWorkerStore
):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservices_by_state(self, state): def get_appservices_by_state(self, state):
"""Get a list of application services based on their state. """Get a list of application services based on their state.
@ -135,9 +135,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
may be empty. may be empty.
""" """
results = yield self._simple_select_list( results = yield self._simple_select_list(
"application_services_state", "application_services_state", dict(state=state), ["as_id"]
dict(state=state),
["as_id"]
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services() as_list = self.get_app_services()
@ -180,9 +178,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
A Deferred which resolves when the state was set successfully. A Deferred which resolves when the state was set successfully.
""" """
return self._simple_upsert( return self._simple_upsert(
"application_services_state", "application_services_state", dict(as_id=service.id), dict(state=state)
dict(as_id=service.id),
dict(state=state)
) )
def create_appservice_txn(self, service, events): def create_appservice_txn(self, service, events):
@ -195,6 +191,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
Returns: Returns:
AppServiceTransaction: A new transaction. AppServiceTransaction: A new transaction.
""" """
def _create_appservice_txn(txn): def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1) # work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn) # The highest id may be the last one sent (in which case it is last_txn)
@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute( txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,) (service.id,),
) )
highest_txn_id = txn.fetchone()[0] highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None: if highest_txn_id is None:
@ -217,16 +214,11 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute( txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)", "VALUES(?,?,?)",
(service.id, new_txn_id, event_ids) (service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
) )
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.runInteraction( return self.runInteraction("create_appservice_txn", _create_appservice_txn)
"create_appservice_txn",
_create_appservice_txn,
)
def complete_appservice_txn(self, txn_id, service): def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction. """Completes an application service transaction.
@ -252,26 +244,26 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
"appservice: Completing a transaction which has an ID > 1 from " "appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or " "the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s " "sent it to the AS out of order. FIX ME. last_txn=%s "
"completing_txn=%s service_id=%s", last_txn_id, txn_id, "completing_txn=%s service_id=%s",
service.id last_txn_id,
txn_id,
service.id,
) )
# Set current txn_id for AS to 'txn_id' # Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn( self._simple_upsert_txn(
txn, "application_services_state", dict(as_id=service.id), txn,
dict(last_txn=txn_id) "application_services_state",
dict(as_id=service.id),
dict(last_txn=txn_id),
) )
# Delete txn # Delete txn
self._simple_delete_txn( self._simple_delete_txn(
txn, "application_services_txns", txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
dict(txn_id=txn_id, as_id=service.id)
) )
return self.runInteraction( return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
"complete_appservice_txn",
_complete_appservice_txn,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_oldest_unsent_txn(self, service): def get_oldest_unsent_txn(self, service):
@ -284,13 +276,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
A Deferred which resolves to an AppServiceTransaction or A Deferred which resolves to an AppServiceTransaction or
None. None.
""" """
def _get_oldest_unsent_txn(txn): def _get_oldest_unsent_txn(txn):
# Monotonically increasing txn ids, so just select the smallest # Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent) # one in the txns table (we delete them when they are sent)
txn.execute( txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?" "SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1", " ORDER BY txn_id ASC LIMIT 1",
(service.id,) (service.id,),
) )
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
if not rows: if not rows:
@ -301,8 +294,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
return entry return entry
entry = yield self.runInteraction( entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn", "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
_get_oldest_unsent_txn,
) )
if not entry: if not entry:
@ -312,14 +304,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
events = yield self._get_events(event_ids) events = yield self._get_events(event_ids)
defer.returnValue(AppServiceTransaction( defer.returnValue(
service=service, id=entry["txn_id"], events=events AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
)) )
def _get_last_txn(self, txn, service_id): def _get_last_txn(self, txn, service_id):
txn.execute( txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?", "SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,) (service_id,),
) )
last_txn_id = txn.fetchone() last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists if last_txn_id is None or last_txn_id[0] is None: # no row exists
@ -332,6 +324,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute( txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
) )
return self.runInteraction( return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn "set_appservice_last_pos", set_appservice_last_pos_txn
) )
@ -362,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
return upper_bound, [row[1] for row in rows] return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction( upper_bound, event_ids = yield self.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn, "get_new_events_for_appservice", get_new_events_for_appservice_txn
) )
events = yield self._get_events(event_ids) events = yield self._get_events(event_ids)

View File

@ -94,16 +94,13 @@ class BackgroundUpdateStore(SQLBaseStore):
self._all_done = False self._all_done = False
def start_doing_background_updates(self): def start_doing_background_updates(self):
run_as_background_process( run_as_background_process("background_updates", self._run_background_updates)
"background_updates", self._run_background_updates,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _run_background_updates(self): def _run_background_updates(self):
logger.info("Starting background schema updates") logger.info("Starting background schema updates")
while True: while True:
yield self.hs.get_clock().sleep( yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try: try:
result = yield self.do_next_background_update( result = yield self.do_next_background_update(
@ -187,8 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms): def _do_background_update(self, update_name, desired_duration_ms):
logger.info("Starting update batch on background update '%s'", logger.info("Starting update batch on background update '%s'", update_name)
update_name)
update_handler = self._background_update_handlers[update_name] update_handler = self._background_update_handlers[update_name]
@ -210,7 +206,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = yield self._simple_select_one_onecol( progress_json = yield self._simple_select_one_onecol(
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},
retcol="progress_json" retcol="progress_json",
) )
progress = json.loads(progress_json) progress = json.loads(progress_json)
@ -224,7 +220,9 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info( logger.info(
"Updating %r. Updated %r items in %rms." "Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name, items_updated, duration_ms, update_name,
items_updated,
duration_ms,
performance.total_items_per_ms(), performance.total_items_per_ms(),
performance.average_items_per_ms(), performance.average_items_per_ms(),
performance.total_item_count, performance.total_item_count,
@ -264,6 +262,7 @@ class BackgroundUpdateStore(SQLBaseStore):
Args: Args:
update_name (str): Name of update update_name (str): Name of update
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def noop_update(progress, batch_size): def noop_update(progress, batch_size):
yield self._end_background_update(update_name) yield self._end_background_update(update_name)
@ -271,10 +270,16 @@ class BackgroundUpdateStore(SQLBaseStore):
self.register_background_update_handler(update_name, noop_update) self.register_background_update_handler(update_name, noop_update)
def register_background_index_update(self, update_name, index_name, def register_background_index_update(
table, columns, where_clause=None, self,
update_name,
index_name,
table,
columns,
where_clause=None,
unique=False, unique=False,
psql_only=False): psql_only=False,
):
"""Helper for store classes to do a background index addition """Helper for store classes to do a background index addition
To use: To use:
@ -320,7 +325,7 @@ class BackgroundUpdateStore(SQLBaseStore):
"name": index_name, "name": index_name,
"table": table, "table": table,
"columns": ", ".join(columns), "columns": ", ".join(columns),
"where_clause": "WHERE " + where_clause if where_clause else "" "where_clause": "WHERE " + where_clause if where_clause else "",
} }
logger.debug("[SQL] %s", sql) logger.debug("[SQL] %s", sql)
c.execute(sql) c.execute(sql)
@ -387,7 +392,7 @@ class BackgroundUpdateStore(SQLBaseStore):
return self._simple_insert( return self._simple_insert(
"background_updates", "background_updates",
{"update_name": update_name, "progress_json": progress_json} {"update_name": update_name, "progress_json": progress_json},
) )
def _end_background_update(self, update_name): def _end_background_update(self, update_name):

View File

@ -37,9 +37,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
keylen=4,
max_entries=50000 * CACHE_SIZE_FACTOR,
) )
super(ClientIpStore, self).__init__(db_conn, hs) super(ClientIpStore, self).__init__(db_conn, hs)
@ -66,13 +64,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
) )
self.register_background_update_handler( self.register_background_update_handler(
"user_ips_analyze", "user_ips_analyze", self._analyze_user_ip
self._analyze_user_ip,
) )
self.register_background_update_handler( self.register_background_update_handler(
"user_ips_remove_dupes", "user_ips_remove_dupes", self._remove_user_ip_dupes
self._remove_user_ip_dupes,
) )
# Register a unique index # Register a unique index
@ -86,8 +82,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Drop the old non-unique index # Drop the old non-unique index
self.register_background_update_handler( self.register_background_update_handler(
"user_ips_drop_nonunique_index", "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
self._remove_user_ip_nonunique,
) )
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
@ -104,9 +99,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _remove_user_ip_nonunique(self, progress, batch_size): def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn): def f(conn):
txn = conn.cursor() txn = conn.cursor()
txn.execute( txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
"DROP INDEX IF EXISTS user_ips_user_ip"
)
txn.close() txn.close()
yield self.runWithConnection(f) yield self.runWithConnection(f)
@ -124,9 +117,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn): def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips") txn.execute("ANALYZE user_ips")
yield self.runInteraction( yield self.runInteraction("user_ips_analyze", user_ips_analyze)
"user_ips_analyze", user_ips_analyze
)
yield self._end_background_update("user_ips_analyze") yield self._end_background_update("user_ips_analyze")
@ -151,7 +142,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
LIMIT 1 LIMIT 1
OFFSET ? OFFSET ?
""", """,
(begin_last_seen, batch_size) (begin_last_seen, batch_size),
) )
row = txn.fetchone() row = txn.fetchone()
if row: if row:
@ -169,7 +160,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
logger.info( logger.info(
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
begin_last_seen, end_last_seen, begin_last_seen,
end_last_seen,
) )
def remove(txn): def remove(txn):
@ -207,8 +199,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
INNER JOIN user_ips USING (user_id, access_token, ip) INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip GROUP BY user_id, access_token, ip
HAVING count(*) > 1 HAVING count(*) > 1
""".format(clause), """.format(
args clause
),
args,
) )
res = txn.fetchall() res = txn.fetchall()
@ -254,7 +248,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
DELETE FROM user_ips DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ? WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
""", """,
(user_id, access_token, ip, last_seen) (user_id, access_token, ip, last_seen),
) )
if txn.rowcount == count - 1: if txn.rowcount == count - 1:
# We deleted all but one of the duplicate rows, i.e. there # We deleted all but one of the duplicate rows, i.e. there
@ -263,7 +257,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
continue continue
elif txn.rowcount >= count: elif txn.rowcount >= count:
raise Exception( raise Exception(
"We deleted more duplicate rows from 'user_ips' than expected", "We deleted more duplicate rows from 'user_ips' than expected"
) )
# The previous step didn't delete enough rows, so we fallback to # The previous step didn't delete enough rows, so we fallback to
@ -275,7 +269,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
DELETE FROM user_ips DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ? WHERE user_id = ? AND access_token = ? AND ip = ?
""", """,
(user_id, access_token, ip) (user_id, access_token, ip),
) )
# Add in one to be the last_seen # Add in one to be the last_seen
@ -285,7 +279,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen) (user_id, access_token, ip, device_id, user_agent, last_seen)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""", """,
(user_id, access_token, ip, device_id, user_agent, last_seen) (user_id, access_token, ip, device_id, user_agent, last_seen),
) )
self._background_update_progress_txn( self._background_update_progress_txn(
@ -300,8 +294,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
defer.returnValue(batch_size) defer.returnValue(batch_size)
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, def insert_client_ip(
now=None): self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now: if not now:
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
key = (user_id, access_token, ip) key = (user_id, access_token, ip)
@ -329,13 +324,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
to_update = self._batch_row_update to_update = self._batch_row_update
self._batch_row_update = {} self._batch_row_update = {}
return self.runInteraction( return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
to_update,
) )
return run_as_background_process( return run_as_background_process("update_client_ips", update)
"update_client_ips", update,
)
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or ( if "user_ips" in self._unsafe_to_upsert_tables or (
@ -383,7 +375,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction( res = yield self.runInteraction(
"get_last_client_ip_by_device", "get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn, self._get_last_client_ip_by_device_txn,
user_id, device_id, user_id,
device_id,
retcols=( retcols=(
"user_id", "user_id",
"access_token", "access_token",
@ -428,9 +421,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
"WHERE %(where)s " "WHERE %(where)s "
"GROUP BY user_id, device_id" "GROUP BY user_id, device_id"
) % { ) % {"where": " OR ".join(where_clauses)}
"where": " OR ".join(where_clauses),
}
sql = ( sql = (
"SELECT %(retcols)s FROM user_ips " "SELECT %(retcols)s FROM user_ips "
@ -462,9 +453,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=[ retcols=["access_token", "ip", "user_agent", "last_seen"],
"access_token", "ip", "user_agent", "last_seen"
],
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
@ -472,7 +461,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows for row in rows
) )
defer.returnValue(list( defer.returnValue(
list(
{ {
"access_token": access_token, "access_token": access_token,
"ip": ip, "ip": ip,
@ -480,4 +470,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"last_seen": last_seen, "last_seen": last_seen,
} }
for (access_token, ip), (user_agent, last_seen) in iteritems(results) for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)) )
)

View File

@ -57,9 +57,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
" LIMIT ?" " LIMIT ?"
) )
txn.execute(sql, ( txn.execute(
user_id, device_id, last_stream_id, current_stream_id, limit sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
)) )
messages = [] messages = []
for row in txn: for row in txn:
stream_pos = row[0] stream_pos = row[0]
@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return (messages, stream_pos) return (messages, stream_pos)
return self.runInteraction( return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn, "get_new_messages_for_device", get_new_messages_for_device_txn
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -146,9 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
" LIMIT ?" " LIMIT ?"
) )
txn.execute(sql, ( txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
destination, last_stream_id, current_stream_id, limit
))
messages = [] messages = []
for row in txn: for row in txn:
stream_pos = row[0] stream_pos = row[0]
@ -172,6 +170,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
Returns: Returns:
A deferred that resolves when the messages have been deleted. A deferred that resolves when the messages have been deleted.
""" """
def delete_messages_for_remote_destination_txn(txn): def delete_messages_for_remote_destination_txn(txn):
sql = ( sql = (
"DELETE FROM device_federation_outbox" "DELETE FROM device_federation_outbox"
@ -181,8 +180,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (destination, up_to_stream_id)) txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction( return self.runInteraction(
"delete_device_msgs_for_remote", "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
delete_messages_for_remote_destination_txn
) )
@ -200,8 +198,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
) )
self.register_background_update_handler( self.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
self._background_drop_index_device_inbox,
) )
# Map of (user_id, device_id) to the last stream_id that has been # Map of (user_id, device_id) to the last stream_id that has been
@ -214,8 +211,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device, def add_messages_to_device_inbox(
remote_messages_by_destination): self, local_messages_by_user_then_device, remote_messages_by_destination
):
"""Used to send messages from this server. """Used to send messages from this server.
Args: Args:
@ -252,15 +250,10 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id: with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
yield self.runInteraction( yield self.runInteraction(
"add_messages_to_device_inbox", "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
add_messages_txn,
now_ms,
stream_id,
) )
for user_id in local_messages_by_user_then_device.keys(): for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed( self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
user_id, stream_id
)
for destination in remote_messages_by_destination.keys(): for destination in remote_messages_by_destination.keys():
self._device_federation_outbox_stream_cache.entity_has_changed( self._device_federation_outbox_stream_cache.entity_has_changed(
destination, stream_id destination, stream_id
@ -277,7 +270,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# origin. This can happen if the origin doesn't receive our # origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message. # acknowledgement from the first time we received the message.
already_inserted = self._simple_select_one_txn( already_inserted = self._simple_select_one_txn(
txn, table="device_federation_inbox", txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id}, keyvalues={"origin": origin, "message_id": message_id},
retcols=("message_id",), retcols=("message_id",),
allow_none=True, allow_none=True,
@ -288,7 +282,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Add an entry for this message_id so that we know we've processed # Add an entry for this message_id so that we know we've processed
# it. # it.
self._simple_insert_txn( self._simple_insert_txn(
txn, table="device_federation_inbox", txn,
table="device_federation_inbox",
values={ values={
"origin": origin, "origin": origin,
"message_id": message_id, "message_id": message_id,
@ -311,19 +306,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
stream_id, stream_id,
) )
for user_id in local_messages_by_user_then_device.keys(): for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed( self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
user_id, stream_id
)
defer.returnValue(stream_id) defer.returnValue(stream_id)
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, def _add_messages_to_local_device_inbox_txn(
messages_by_user_then_device): self, txn, stream_id, messages_by_user_then_device
sql = ( ):
"UPDATE device_max_stream_id" sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(sql, (stream_id, stream_id)) txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {} local_by_user_then_device = {}
@ -332,10 +322,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
devices = list(messages_by_device.keys()) devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*": if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids. # Handle wildcard device_ids.
sql = ( sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
"SELECT device_id FROM devices"
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"]) message_json = json.dumps(messages_by_device["*"])
for row in txn: for row in txn:
@ -428,9 +415,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
def _background_drop_index_device_inbox(self, progress, batch_size): def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn): def reindex_txn(conn):
txn = conn.cursor() txn = conn.cursor()
txn.execute( txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
"DROP INDEX IF EXISTS device_inbox_stream_id"
)
txn.close() txn.close()
yield self.runWithConnection(reindex_txn) yield self.runWithConnection(reindex_txn)

View File

@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore):
table="devices", table="devices",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user" desc="get_devices_by_user",
) )
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
@ -87,21 +87,23 @@ class DeviceWorkerStore(SQLBaseStore):
return (now_stream_id, []) return (now_stream_id, [])
return self.runInteraction( return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn, "get_devices_by_remote",
destination, from_stream_id, now_stream_id, self._get_devices_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
) )
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, def _get_devices_by_remote_txn(
now_stream_id): self, txn, destination, from_stream_id, now_stream_id
):
sql = """ sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id GROUP BY user_id, device_id
LIMIT 20 LIMIT 20
""" """
txn.execute( txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id # maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn} query_map = {(r[0], r[1]): r[2] for r in txn}
@ -112,7 +114,10 @@ class DeviceWorkerStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map)) now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
) )
prev_sent_id_sql = """ prev_sent_id_sql = """
@ -157,8 +162,10 @@ class DeviceWorkerStore(SQLBaseStore):
"""Mark that updates have successfully been sent to the destination. """Mark that updates have successfully been sent to the destination.
""" """
return self.runInteraction( return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, "mark_as_sent_devices_by_remote",
destination, stream_id, self._mark_as_sent_devices_by_remote_txn,
destination,
stream_id,
) )
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
@ -173,7 +180,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE destination = ? AND o.stream_id <= ? WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id GROUP BY user_id
""" """
txn.execute(sql, (destination, stream_id,)) txn.execute(sql, (destination, stream_id))
rows = txn.fetchall() rows = txn.fetchall()
sql = """ sql = """
@ -181,16 +188,14 @@ class DeviceWorkerStore(SQLBaseStore):
SET stream_id = ? SET stream_id = ?
WHERE destination = ? AND user_id = ? WHERE destination = ? AND user_id = ?
""" """
txn.executemany( txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """ sql = """
INSERT INTO device_lists_outbound_last_success INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?) (destination, user_id, stream_id) VALUES (?, ?, ?)
""" """
txn.executemany( txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2]) sql, ((destination, row[0], row[1]) for row in rows if not row[2])
) )
# Delete all sent outbound pokes # Delete all sent outbound pokes
@ -198,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_pokes DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ? WHERE destination = ? AND stream_id <= ?
""" """
txn.execute(sql, (destination, stream_id,)) txn.execute(sql, (destination, stream_id))
def get_device_stream_token(self): def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
@ -240,10 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
def _get_cached_user_device(self, user_id, device_id): def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol( content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id,
"device_id": device_id,
},
retcol="content", retcol="content",
desc="_get_cached_user_device", desc="_get_cached_user_device",
) )
@ -253,16 +255,13 @@ class DeviceWorkerStore(SQLBaseStore):
def _get_cached_devices_for_user(self, user_id): def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list( devices = yield self._simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id,
},
retcols=("device_id", "content"), retcols=("device_id", "content"),
desc="_get_cached_devices_for_user", desc="_get_cached_devices_for_user",
) )
defer.returnValue({ defer.returnValue(
device["device_id"]: db_to_json(device["content"]) {device["device_id"]: db_to_json(device["content"]) for device in devices}
for device in devices )
})
def get_devices_with_keys_by_user(self, user_id): def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user """Get all devices (with any device keys) for a user
@ -272,7 +271,8 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
return self.runInteraction( return self.runInteraction(
"get_devices_with_keys_by_user", "get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id, self._get_devices_with_keys_by_user_txn,
user_id,
) )
def _get_devices_with_keys_by_user_txn(self, txn, user_id): def _get_devices_with_keys_by_user_txn(self, txn, user_id):
@ -286,9 +286,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_devices = devices[user_id] user_devices = devices[user_id]
results = [] results = []
for device_id, device in iteritems(user_devices): for device_id, device in iteritems(user_devices):
result = { result = {"device_id": device_id}
"device_id": device_id,
}
key_json = device.get("key_json", None) key_json = device.get("key_json", None)
if key_json: if key_json:
@ -315,7 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """ sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
""" """
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) rows = yield self._execute(
"get_user_whose_devices_changed", None, sql, from_key
)
defer.returnValue(set(row[0] for row in rows)) defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key): def get_all_device_list_changes_for_remotes(self, from_key, to_key):
@ -333,8 +333,7 @@ class DeviceWorkerStore(SQLBaseStore):
GROUP BY user_id, destination GROUP BY user_id, destination
""" """
return self._execute( return self._execute(
"get_all_device_list_changes_for_remotes", None, "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
sql, from_key, to_key
) )
@cached(max_entries=10000) @cached(max_entries=10000)
@ -350,21 +349,22 @@ class DeviceWorkerStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", @cachedList(
list_name="user_ids", inlineCallbacks=True) cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids): def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
retcols=("user_id", "stream_id",), retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes", desc="get_device_list_last_stream_id_for_remotes",
) )
results = {user_id: None for user_id in user_ids} results = {user_id: None for user_id in user_ids}
results.update({ results.update({row["user_id"]: row["stream_id"] for row in rows})
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results) defer.returnValue(results)
@ -376,14 +376,10 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists. # the device exists.
self.device_id_exists_cache = Cache( self.device_id_exists_cache = Cache(
name="device_id_exists", name="device_id_exists", keylen=2, max_entries=10000
keylen=2,
max_entries=10000,
) )
self._clock.looping_call( self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
self.register_background_index_update( self.register_background_index_update(
"device_lists_stream_idx", "device_lists_stream_idx",
@ -417,8 +413,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, def store_device(self, user_id, device_id, initial_device_display_name):
initial_device_display_name):
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
Args: Args:
@ -440,7 +435,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
values={ values={
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
"display_name": initial_device_display_name "display_name": initial_device_display_name,
}, },
desc="store_device", desc="store_device",
or_ignore=True, or_ignore=True,
@ -448,12 +443,17 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self.device_id_exists_cache.prefill(key, True) self.device_id_exists_cache.prefill(key, True)
defer.returnValue(inserted) defer.returnValue(inserted)
except Exception as e: except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s", " display_name=%s(%r) failed: %s",
type(device_id).__name__, device_id, type(device_id).__name__,
type(user_id).__name__, user_id, device_id,
type(user_id).__name__,
user_id,
type(initial_device_display_name).__name__, type(initial_device_display_name).__name__,
initial_device_display_name, e) initial_device_display_name,
e,
)
raise StoreError(500, "Problem storing device.") raise StoreError(500, "Problem storing device.")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -525,15 +525,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
""" """
yield self._simple_delete( yield self._simple_delete(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed", desc="mark_remote_user_device_list_as_unsubscribed",
) )
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content, def update_remote_device_list_cache_entry(
stream_id): self, user_id, device_id, content, stream_id
):
"""Updates a single device in the cache of a remote user's devicelist. """Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's Note: assumes that we are the only thread that can be updating this user's
@ -551,42 +550,35 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return self.runInteraction( return self.runInteraction(
"update_remote_device_list_cache_entry", "update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn, self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id, user_id,
device_id,
content,
stream_id,
) )
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, def _update_remote_device_list_cache_entry_txn(
content, stream_id): self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"): if content.get("deleted"):
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id,
"device_id": device_id,
},
) )
txn.call_after( txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
self.device_id_exists_cache.invalidate, (user_id, device_id,)
)
else: else:
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id, values={"content": json.dumps(content)},
"device_id": device_id,
},
values={
"content": json.dumps(content),
},
# we don't need to lock, because we assume we are the only thread # we don't need to lock, because we assume we are the only thread
# updating this user's devices. # updating this user's devices.
lock=False, lock=False,
) )
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after( txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
@ -595,13 +587,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id, values={"stream_id": stream_id},
},
values={
"stream_id": stream_id,
},
# again, we can assume we are the only thread updating this user's # again, we can assume we are the only thread updating this user's
# extremity. # extremity.
lock=False, lock=False,
@ -624,17 +611,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return self.runInteraction( return self.runInteraction(
"update_remote_device_list_cache", "update_remote_device_list_cache",
self._update_remote_device_list_cache_txn, self._update_remote_device_list_cache_txn,
user_id, devices, stream_id, user_id,
devices,
stream_id,
) )
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
stream_id):
self._simple_delete_txn( self._simple_delete_txn(
txn, txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
) )
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -647,7 +631,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"content": json.dumps(content), "content": json.dumps(content),
} }
for content in devices for content in devices
] ],
) )
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@ -659,13 +643,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id, values={"stream_id": stream_id},
},
values={
"stream_id": stream_id,
},
# we don't need to lock, because we can assume we are the only thread # we don't need to lock, because we can assume we are the only thread
# updating this user's extremity. # updating this user's extremity.
lock=False, lock=False,
@ -678,8 +657,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
""" """
with self._device_list_id_gen.get_next() as stream_id: with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction( yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn, "add_device_change_to_streams",
user_id, device_ids, hosts, stream_id, self._add_device_change_txn,
user_id,
device_ids,
hosts,
stream_id,
) )
defer.returnValue(stream_id) defer.returnValue(stream_id)
@ -687,13 +670,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
now = self._clock.time_msec() now = self._clock.time_msec()
txn.call_after( txn.call_after(
self._device_list_stream_cache.entity_has_changed, self._device_list_stream_cache.entity_has_changed, user_id, stream_id
user_id, stream_id,
) )
for host in hosts: for host in hosts:
txn.call_after( txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed, self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id, host,
stream_id,
) )
# Delete older entries in the table, as we really only care about # Delete older entries in the table, as we really only care about
@ -703,20 +686,16 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
DELETE FROM device_lists_stream DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ? WHERE user_id = ? AND device_id = ? AND stream_id < ?
""", """,
[(user_id, device_id, stream_id) for device_id in device_ids] [(user_id, device_id, stream_id) for device_id in device_ids],
) )
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="device_lists_stream", table="device_lists_stream",
values=[ values=[
{ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids for device_id in device_ids
] ],
) )
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -733,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
} }
for destination in hosts for destination in hosts
for device_id in device_ids for device_id in device_ids
] ],
) )
def _prune_old_outbound_device_pokes(self): def _prune_old_outbound_device_pokes(self):
@ -764,11 +743,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
""" """
txn.executemany( txn.executemany(
delete_sql, delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
(
(yesterday, row[0], row[1], row[2])
for row in rows
)
) )
# Since we've deleted unsent deltas, we need to remove the entry # Since we've deleted unsent deltas, we need to remove the entry
@ -792,12 +767,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn): def f(conn):
txn = conn.cursor() txn = conn.cursor()
txn.execute( txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
"DROP INDEX IF EXISTS device_lists_remote_cache_id" txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
)
txn.execute(
"DROP INDEX IF EXISTS device_lists_remote_extremeties_id"
)
txn.close() txn.close()
yield self.runWithConnection(f) yield self.runWithConnection(f)

View File

@ -22,10 +22,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore from ._base import SQLBaseStore
RoomAliasMapping = namedtuple( RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
"RoomAliasMapping",
("room_id", "room_alias", "servers",)
)
class DirectoryWorkerStore(SQLBaseStore): class DirectoryWorkerStore(SQLBaseStore):
@ -63,16 +60,12 @@ class DirectoryWorkerStore(SQLBaseStore):
defer.returnValue(None) defer.returnValue(None)
return return
defer.returnValue( defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
RoomAliasMapping(room_id, room_alias.to_string(), servers)
)
def get_room_alias_creator(self, room_alias): def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="room_aliases", table="room_aliases",
keyvalues={ keyvalues={"room_alias": room_alias},
"room_alias": room_alias,
},
retcol="creator", retcol="creator",
desc="get_room_alias_creator", desc="get_room_alias_creator",
) )
@ -101,6 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
Returns: Returns:
Deferred Deferred
""" """
def alias_txn(txn): def alias_txn(txn):
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -115,10 +109,10 @@ class DirectoryStore(DirectoryWorkerStore):
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="room_alias_servers", table="room_alias_servers",
values=[{ values=[
"room_alias": room_alias.to_string(), {"room_alias": room_alias.to_string(), "server": server}
"server": server, for server in servers
} for server in servers], ],
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
@ -126,9 +120,7 @@ class DirectoryStore(DirectoryWorkerStore):
) )
try: try:
ret = yield self.runInteraction( ret = yield self.runInteraction("create_room_alias_association", alias_txn)
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise SynapseError( raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string() 409, "Room alias %s already exists" % room_alias.to_string()
@ -138,9 +130,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction( room_id = yield self.runInteraction(
"delete_room_alias", "delete_room_alias", self._delete_room_alias_txn, room_alias
self._delete_room_alias_txn,
room_alias,
) )
defer.returnValue(room_id) defer.returnValue(room_id)
@ -148,7 +138,7 @@ class DirectoryStore(DirectoryWorkerStore):
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):
txn.execute( txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?", "SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),) (room_alias.to_string(),),
) )
res = txn.fetchone() res = txn.fetchone()
@ -158,31 +148,29 @@ class DirectoryStore(DirectoryWorkerStore):
return None return None
txn.execute( txn.execute(
"DELETE FROM room_aliases WHERE room_alias = ?", "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),)
(room_alias.to_string(),)
) )
txn.execute( txn.execute(
"DELETE FROM room_alias_servers WHERE room_alias = ?", "DELETE FROM room_alias_servers WHERE room_alias = ?",
(room_alias.to_string(),) (room_alias.to_string(),),
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,))
txn, self.get_aliases_for_room, (room_id,)
)
return room_id return room_id
def update_aliases_for_room(self, old_room_id, new_room_id, creator): def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn): def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
txn.execute(sql, (new_room_id, creator, old_room_id,)) txn.execute(sql, (new_room_id, creator, old_room_id))
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,) txn, self.get_aliases_for_room, (old_room_id,)
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (new_room_id,) txn, self.get_aliases_for_room, (new_room_id,)
) )
return self.runInteraction( return self.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn "_update_aliases_for_room_txn", _update_aliases_for_room_txn
) )

View File

@ -23,7 +23,6 @@ from ._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore): class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_e2e_room_key(self, user_id, version, room_id, session_id): def get_e2e_room_key(self, user_id, version, room_id, session_id):
"""Get the encrypted E2E room key for a given session from a given """Get the encrypted E2E room key for a given session from a given
@ -97,9 +96,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_e2e_room_keys( def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
self, user_id, version, room_id=None, session_id=None
):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given """Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session. room, or a given session.
@ -123,10 +120,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
except ValueError: except ValueError:
defer.returnValue({'rooms': {}}) defer.returnValue({'rooms': {}})
keyvalues = { keyvalues = {"user_id": user_id, "version": version}
"user_id": user_id,
"version": version,
}
if room_id: if room_id:
keyvalues['room_id'] = room_id keyvalues['room_id'] = room_id
if session_id: if session_id:
@ -160,9 +154,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
defer.returnValue(sessions) defer.returnValue(sessions)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_e2e_room_keys( def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
self, user_id, version, room_id=None, session_id=None
):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session. room or a given session.
@ -180,19 +172,14 @@ class EndToEndRoomKeyStore(SQLBaseStore):
A deferred of the deletion transaction A deferred of the deletion transaction
""" """
keyvalues = { keyvalues = {"user_id": user_id, "version": int(version)}
"user_id": user_id,
"version": int(version),
}
if room_id: if room_id:
keyvalues['room_id'] = room_id keyvalues['room_id'] = room_id
if session_id: if session_id:
keyvalues['session_id'] = session_id keyvalues['session_id'] = session_id
yield self._simple_delete( yield self._simple_delete(
table="e2e_room_keys", table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
keyvalues=keyvalues,
desc="delete_e2e_room_keys",
) )
@staticmethod @staticmethod
@ -200,7 +187,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn.execute( txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions " "SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0", "WHERE user_id=? AND deleted=0",
(user_id,) (user_id,),
) )
row = txn.fetchone() row = txn.fetchone()
if not row: if not row:
@ -238,24 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result = self._simple_select_one_txn( result = self._simple_select_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={ keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
"user_id": user_id, retcols=("version", "algorithm", "auth_data"),
"version": this_version,
"deleted": 0,
},
retcols=(
"version",
"algorithm",
"auth_data",
),
) )
result["auth_data"] = json.loads(result["auth_data"]) result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"]) result["version"] = str(result["version"])
return result return result
return self.runInteraction( return self.runInteraction(
"get_e2e_room_keys_version_info", "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
_get_e2e_room_keys_version_info_txn
) )
def create_e2e_room_keys_version(self, user_id, info): def create_e2e_room_keys_version(self, user_id, info):
@ -273,7 +251,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _create_e2e_room_keys_version_txn(txn): def _create_e2e_room_keys_version_txn(txn):
txn.execute( txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,) (user_id,),
) )
current_version = txn.fetchone()[0] current_version = txn.fetchone()[0]
if current_version is None: if current_version is None:
@ -309,14 +287,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return self._simple_update( return self._simple_update(
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={ keyvalues={"user_id": user_id, "version": version},
"user_id": user_id, updatevalues={"auth_data": json.dumps(info["auth_data"])},
"version": version, desc="update_e2e_room_keys_version",
},
updatevalues={
"auth_data": json.dumps(info["auth_data"]),
},
desc="update_e2e_room_keys_version"
) )
def delete_e2e_room_keys_version(self, user_id, version=None): def delete_e2e_room_keys_version(self, user_id, version=None):
@ -341,16 +314,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return self._simple_update_one_txn( return self._simple_update_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={ keyvalues={"user_id": user_id, "version": this_version},
"user_id": user_id, updatevalues={"deleted": 1},
"version": this_version,
},
updatevalues={
"deleted": 1,
}
) )
return self.runInteraction( return self.runInteraction(
"delete_e2e_room_keys_version", "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
_delete_e2e_room_keys_version_txn
) )

View File

@ -26,8 +26,7 @@ from ._base import SQLBaseStore, db_to_json
class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_e2e_device_keys( def get_e2e_device_keys(
self, query_list, include_all_devices=False, self, query_list, include_all_devices=False, include_deleted_devices=False
include_deleted_devices=False,
): ):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
@ -45,8 +44,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
defer.returnValue({}) defer.returnValue({})
results = yield self.runInteraction( results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, "get_e2e_device_keys",
query_list, include_all_devices, include_deleted_devices, self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
) )
for user_id, device_keys in iteritems(results): for user_id, device_keys in iteritems(results):
@ -56,8 +58,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def _get_e2e_device_keys_txn( def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, self, txn, query_list, include_all_devices=False, include_deleted_devices=False
include_deleted_devices=False,
): ):
query_clauses = [] query_clauses = []
query_params = [] query_params = []
@ -87,7 +88,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
" WHERE %s" " WHERE %s"
) % ( ) % (
"LEFT" if include_all_devices else "INNER", "LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses) " OR ".join("(" + q + ")" for q in query_clauses),
) )
txn.execute(sql, query_params) txn.execute(sql, query_params)
@ -124,17 +125,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
column="key_id", column="key_id",
iterable=key_ids, iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",), retcols=("algorithm", "key_id", "key_json"),
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id,
"device_id": device_id,
},
desc="add_e2e_one_time_keys_check", desc="add_e2e_one_time_keys_check",
) )
defer.returnValue({ defer.returnValue(
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
}) )
@defer.inlineCallbacks @defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
@ -155,7 +153,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# `add_e2e_one_time_keys` then they'll conflict and we will only # `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set. # insert one set.
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, table="e2e_one_time_keys_json", txn,
table="e2e_one_time_keys_json",
values=[ values=[
{ {
"user_id": user_id, "user_id": user_id,
@ -169,8 +168,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
], ],
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
yield self.runInteraction( yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
) )
@ -181,6 +181,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns: Returns:
Dict mapping from algorithm to number of keys for that algorithm. Dict mapping from algorithm to number of keys for that algorithm.
""" """
def _count_e2e_one_time_keys(txn): def _count_e2e_one_time_keys(txn):
sql = ( sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
@ -192,9 +193,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for algorithm, key_count in txn: for algorithm, key_count in txn:
result[algorithm] = key_count result[algorithm] = key_count
return result return result
return self.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
@ -202,14 +202,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""Stores device keys for a device. Returns whether there was a change """Stores device keys for a device. Returns whether there was a change
or the keys were already in the database. or the keys were already in the database.
""" """
def _set_e2e_device_keys_txn(txn): def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn( old_key_json = self._simple_select_one_onecol_txn(
txn, txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json", retcol="key_json",
allow_none=True, allow_none=True,
) )
@ -224,24 +222,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id, values={"ts_added_ms": time_now, "key_json": new_key_json},
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
) )
return True return True
return self.runInteraction( return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list): def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database""" """Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn): def _claim_e2e_one_time_keys(txn):
sql = ( sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json" "SELECT key_id, key_json FROM e2e_one_time_keys_json"
@ -265,12 +256,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for user_id, device_id, algorithm, key_id in delete: for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id)) txn.execute(sql, (user_id, device_id, algorithm, key_id))
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
return result return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
)
def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn): def delete_e2e_keys_by_device_txn(txn):
@ -285,8 +275,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
return self.runInteraction( return self.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
) )

View File

@ -20,10 +20,7 @@ from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = { SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
"sqlite3": Sqlite3Engine,
"psycopg2": PostgresEngine,
}
def create_engine(database_config): def create_engine(database_config):
@ -32,15 +29,12 @@ def create_engine(database_config):
if engine_class: if engine_class:
# pypy requires psycopg2cffi rather than psycopg2 # pypy requires psycopg2cffi rather than psycopg2
if (name == "psycopg2" and if name == "psycopg2" and platform.python_implementation() == "PyPy":
platform.python_implementation() == "PyPy"):
name = "psycopg2cffi" name = "psycopg2cffi"
module = importlib.import_module(name) module = importlib.import_module(name)
return engine_class(module, database_config) return engine_class(module, database_config)
raise RuntimeError( raise RuntimeError("Unsupported database engine '%s'" % (name,))
"Unsupported database engine '%s'" % (name,)
)
__all__ = ["create_engine", "IncorrectDatabaseSetup"] __all__ = ["create_engine", "IncorrectDatabaseSetup"]

View File

@ -31,8 +31,7 @@ class PostgresEngine(object):
if rows and rows[0][0] != "UTF8": if rows and rows[0][0] != "UTF8":
raise IncorrectDatabaseSetup( raise IncorrectDatabaseSetup(
"Database has incorrect encoding: '%s' instead of 'UTF8'\n" "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
"See docs/postgres.rst for more information." "See docs/postgres.rst for more information." % (rows[0][0],)
% (rows[0][0],)
) )
def convert_param_style(self, sql): def convert_param_style(self, sql):
@ -103,12 +102,6 @@ class PostgresEngine(object):
# https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION # https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
if numver >= 100000: if numver >= 100000:
return "%i.%i" % ( return "%i.%i" % (numver / 10000, numver % 10000)
numver / 10000, numver % 10000,
)
else: else:
return "%i.%i.%i" % ( return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
numver / 10000,
(numver % 10000) / 100,
numver % 100,
)

View File

@ -82,6 +82,7 @@ class Sqlite3Engine(object):
# Following functions taken from: https://github.com/coleifer/peewee # Following functions taken from: https://github.com/coleifer/peewee
def _parse_match_info(buf): def _parse_match_info(buf):
bufsize = len(buf) bufsize = len(buf)
return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)] return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]

View File

@ -32,8 +32,7 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False): def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events. """Get auth events for given event_ids. The events *must* be state events.
@ -45,7 +44,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
list of events list of events
""" """
return self.get_auth_chain_ids( return self.get_auth_chain_ids(
event_ids, include_given=include_given, event_ids, include_given=include_given
).addCallback(self._get_events) ).addCallback(self._get_events)
def get_auth_chain_ids(self, event_ids, include_given=False): def get_auth_chain_ids(self, event_ids, include_given=False):
@ -59,9 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
list of event_ids list of event_ids
""" """
return self.runInteraction( return self.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
self._get_auth_chain_ids_txn,
event_ids, include_given
) )
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
@ -70,23 +67,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
else: else:
results = set() results = set()
base_sql = ( base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
)
front = set(event_ids) front = set(event_ids)
while front: while front:
new_front = set() new_front = set()
front_list = list(front) front_list = list(front)
chunks = [ chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
front_list[x:x + 100]
for x in range(0, len(front), 100)
]
for chunk in chunks: for chunk in chunks:
txn.execute( txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
new_front.update([r[0] for r in txn]) new_front.update([r[0] for r in txn])
new_front -= results new_front -= results
@ -98,9 +87,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
def get_oldest_events_in_room(self, room_id): def get_oldest_events_in_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_oldest_events_in_room", "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
self._get_oldest_events_in_room_txn,
room_id,
) )
def get_oldest_events_with_depth_in_room(self, room_id): def get_oldest_events_with_depth_in_room(self, room_id):
@ -121,7 +108,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
" GROUP BY b.event_id" " GROUP BY b.event_id"
) )
txn.execute(sql, (room_id, False,)) txn.execute(sql, (room_id, False))
return dict(txn) return dict(txn)
@ -152,9 +139,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
txn, txn,
table="event_backward_extremities", table="event_backward_extremities",
keyvalues={ keyvalues={"room_id": room_id},
"room_id": room_id,
},
retcol="event_id", retcol="event_id",
) )
@ -209,9 +194,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
def get_latest_event_ids_in_room(self, room_id): def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol( return self._simple_select_onecol(
table="event_forward_extremities", table="event_forward_extremities",
keyvalues={ keyvalues={"room_id": room_id},
"room_id": room_id,
},
retcol="event_id", retcol="event_id",
desc="get_latest_event_ids_in_room", desc="get_latest_event_ids_in_room",
) )
@ -231,8 +214,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
for event_id, depth in txn.fetchall(): for event_id, depth in txn.fetchall():
hashes = self._get_event_reference_hashes_txn(txn, event_id) hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
if k == "sha256"
} }
results.append((event_id, prev_hashes, depth)) results.append((event_id, prev_hashes, depth))
@ -242,9 +224,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
""" For hte given room, get the minimum depth we have seen for it. """ For hte given room, get the minimum depth we have seen for it.
""" """
return self.runInteraction( return self.runInteraction(
"get_min_depth", "get_min_depth", self._get_min_depth_interaction, room_id
self._get_min_depth_interaction,
room_id,
) )
def _get_min_depth_interaction(self, txn, room_id): def _get_min_depth_interaction(self, txn, room_id):
@ -300,7 +280,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
if stream_ordering <= self.stream_ordering_month_ago: if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old") raise StoreError(400, "stream_ordering too old")
sql = (""" sql = """
SELECT event_id FROM stream_ordering_to_exterm SELECT event_id FROM stream_ordering_to_exterm
INNER JOIN ( INNER JOIN (
SELECT room_id, MAX(stream_ordering) AS stream_ordering SELECT room_id, MAX(stream_ordering) AS stream_ordering
@ -308,15 +288,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
WHERE stream_ordering <= ? GROUP BY room_id WHERE stream_ordering <= ? GROUP BY room_id
) AS rms USING (room_id, stream_ordering) ) AS rms USING (room_id, stream_ordering)
WHERE room_id = ? WHERE room_id = ?
""") """
def get_forward_extremeties_for_room_txn(txn): def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn] return [event_id for event_id, in txn]
return self.runInteraction( return self.runInteraction(
"get_forward_extremeties_for_room", "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
get_forward_extremeties_for_room_txn
) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
@ -329,19 +308,21 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_list (list) event_list (list)
limit (int) limit (int)
""" """
return self.runInteraction( return (
self.runInteraction(
"get_backfill_events", "get_backfill_events",
self._get_backfill_events, room_id, event_list, limit self._get_backfill_events,
).addCallback( room_id,
self._get_events event_list,
).addCallback( limit,
lambda l: sorted(l, key=lambda e: -e.depth) )
.addCallback(self._get_events)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
) )
def _get_backfill_events(self, txn, room_id, event_list, limit): def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug( logger.debug(
"_get_backfill_events: %s, %s, %s", "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
room_id, repr(event_list), limit
) )
event_results = set() event_results = set()
@ -364,10 +345,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
depth = self._simple_select_one_onecol_txn( depth = self._simple_select_one_onecol_txn(
txn, txn,
table="events", table="events",
keyvalues={ keyvalues={"event_id": event_id, "room_id": room_id},
"event_id": event_id,
"room_id": room_id,
},
retcol="depth", retcol="depth",
allow_none=True, allow_none=True,
) )
@ -386,10 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.add(event_id) event_results.add(event_id)
txn.execute( txn.execute(query, (event_id, False, limit - len(event_results)))
query,
(event_id, False, limit - len(event_results))
)
for row in txn: for row in txn:
if row[1] not in event_results: if row[1] not in event_results:
@ -398,18 +373,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return event_results return event_results
@defer.inlineCallbacks @defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, def get_missing_events(self, room_id, earliest_events, latest_events, limit):
limit):
ids = yield self.runInteraction( ids = yield self.runInteraction(
"get_missing_events", "get_missing_events",
self._get_missing_events, self._get_missing_events,
room_id, earliest_events, latest_events, limit, room_id,
earliest_events,
latest_events,
limit,
) )
events = yield self._get_events(ids) events = yield self._get_events(ids)
defer.returnValue(events) defer.returnValue(events)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
limit):
seen_events = set(earliest_events) seen_events = set(earliest_events)
front = set(latest_events) - seen_events front = set(latest_events) - seen_events
@ -425,8 +401,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
new_front = set() new_front = set()
for event_id in front: for event_id in front:
txn.execute( txn.execute(
query, query, (room_id, event_id, False, limit - len(event_results))
(room_id, event_id, False, limit - len(event_results))
) )
new_results = set(t[0] for t in txn) - seen_events new_results = set(t[0] for t in txn) - seen_events
@ -457,12 +432,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
column="prev_event_id", column="prev_event_id",
iterable=event_ids, iterable=event_ids,
retcols=("event_id",), retcols=("event_id",),
desc="get_successor_events" desc="get_successor_events",
) )
defer.returnValue([ defer.returnValue([row["event_id"] for row in rows])
row["event_id"] for row in rows
])
class EventFederationStore(EventFederationWorkerStore): class EventFederationStore(EventFederationWorkerStore):
@ -481,12 +454,11 @@ class EventFederationStore(EventFederationWorkerStore):
super(EventFederationStore, self).__init__(db_conn, hs) super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler( self.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
self._background_delete_non_state_event_auth,
) )
hs.get_clock().looping_call( hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000, self._delete_old_forward_extrem_cache, 60 * 60 * 1000
) )
def _update_min_depth_for_room_txn(self, txn, room_id, depth): def _update_min_depth_for_room_txn(self, txn, room_id, depth):
@ -498,12 +470,8 @@ class EventFederationStore(EventFederationWorkerStore):
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="room_depth", table="room_depth",
keyvalues={ keyvalues={"room_id": room_id},
"room_id": room_id, values={"min_depth": depth},
},
values={
"min_depth": depth,
},
) )
def _handle_mult_prev_events(self, txn, events): def _handle_mult_prev_events(self, txn, events):
@ -553,11 +521,15 @@ class EventFederationStore(EventFederationWorkerStore):
" )" " )"
) )
txn.executemany(query, [ txn.executemany(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id in ev.prev_event_ids() for ev in events
for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier() if not ev.internal_metadata.is_outlier()
]) ],
)
query = ( query = (
"DELETE FROM event_backward_extremities" "DELETE FROM event_backward_extremities"
@ -566,16 +538,17 @@ class EventFederationStore(EventFederationWorkerStore):
txn.executemany( txn.executemany(
query, query,
[ [
(ev.event_id, ev.room_id) for ev in events (ev.event_id, ev.room_id)
for ev in events
if not ev.internal_metadata.is_outlier() if not ev.internal_metadata.is_outlier()
] ],
) )
def _delete_old_forward_extrem_cache(self): def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn): def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete # Delete entries older than a month, while making sure we don't delete
# the only entries for a room. # the only entries for a room.
sql = (""" sql = """
DELETE FROM stream_ordering_to_exterm DELETE FROM stream_ordering_to_exterm
WHERE WHERE
room_id IN ( room_id IN (
@ -583,11 +556,11 @@ class EventFederationStore(EventFederationWorkerStore):
FROM stream_ordering_to_exterm FROM stream_ordering_to_exterm
WHERE stream_ordering > ? WHERE stream_ordering > ?
) AND stream_ordering < ? ) AND stream_ordering < ?
""") """
txn.execute( txn.execute(
sql, sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
) )
return run_as_background_process( return run_as_background_process(
"delete_old_forward_extrem_cache", "delete_old_forward_extrem_cache",
self.runInteraction, self.runInteraction,
@ -597,9 +570,7 @@ class EventFederationStore(EventFederationWorkerStore):
def clean_room_for_join(self, room_id): def clean_room_for_join(self, room_id):
return self.runInteraction( return self.runInteraction(
"clean_room_for_join", "clean_room_for_join", self._clean_room_for_join_txn, room_id
self._clean_room_for_join_txn,
room_id,
) )
def _clean_room_for_join_txn(self, txn, room_id): def _clean_room_for_join_txn(self, txn, room_id):
@ -635,7 +606,7 @@ class EventFederationStore(EventFederationWorkerStore):
) )
""" """
txn.execute(sql, (min_stream_id, max_stream_id,)) txn.execute(sql, (min_stream_id, max_stream_id))
new_progress = { new_progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,

View File

@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
DEFAULT_HIGHLIGHT_ACTION = [ DEFAULT_HIGHLIGHT_ACTION = [
"notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} "notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
] ]
@ -96,20 +98,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
ret = yield self.runInteraction( ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room", "get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn, self._get_unread_counts_by_receipt_txn,
room_id, user_id, last_read_event_id room_id,
user_id,
last_read_event_id,
) )
defer.returnValue(ret) defer.returnValue(ret)
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, def _get_unread_counts_by_receipt_txn(
last_read_event_id): self, txn, room_id, user_id, last_read_event_id
):
sql = ( sql = (
"SELECT stream_ordering" "SELECT stream_ordering"
" FROM events" " FROM events"
" WHERE room_id = ? AND event_id = ?" " WHERE room_id = ? AND event_id = ?"
) )
txn.execute( txn.execute(sql, (room_id, last_read_event_id))
sql, (room_id, last_read_event_id)
)
results = txn.fetchall() results = txn.fetchall()
if len(results) == 0: if len(results) == 0:
return {"notify_count": 0, "highlight_count": 0} return {"notify_count": 0, "highlight_count": 0}
@ -138,10 +141,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
row = txn.fetchone() row = txn.fetchone()
notify_count = row[0] if row else 0 notify_count = row[0] if row else 0
txn.execute(""" txn.execute(
"""
SELECT notif_count FROM event_push_summary SELECT notif_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ? WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""", (room_id, user_id, stream_ordering,)) """,
(room_id, user_id, stream_ordering),
)
rows = txn.fetchall() rows = txn.fetchall()
if rows: if rows:
notify_count += rows[0][0] notify_count += rows[0][0]
@ -161,10 +167,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
row = txn.fetchone() row = txn.fetchone()
highlight_count = row[0] if row else 0 highlight_count = row[0] if row else 0
return { return {"notify_count": notify_count, "highlight_count": highlight_count}
"notify_count": notify_count,
"highlight_count": highlight_count,
}
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@ -175,6 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn] return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f) ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret) defer.returnValue(ret)
@ -223,12 +227,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?" " ORDER BY ep.stream_ordering ASC LIMIT ?"
) )
args = [ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.runInteraction( after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
) )
@ -253,12 +255,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?" " ORDER BY ep.stream_ordering ASC LIMIT ?"
) )
args = [ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = yield self.runInteraction( no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
) )
@ -269,7 +269,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"room_id": row[1], "room_id": row[1],
"stream_ordering": row[2], "stream_ordering": row[2],
"actions": _deserialize_action(row[3], row[4]), "actions": _deserialize_action(row[3], row[4]),
} for row in after_read_receipt + no_read_receipt }
for row in after_read_receipt + no_read_receipt
] ]
# Now sort it so it's ordered correctly, since currently it will # Now sort it so it's ordered correctly, since currently it will
@ -326,12 +327,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?" " ORDER BY ep.stream_ordering DESC LIMIT ?"
) )
args = [ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.runInteraction( after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
) )
@ -356,12 +355,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?" " ORDER BY ep.stream_ordering DESC LIMIT ?"
) )
args = [ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = yield self.runInteraction( no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
) )
@ -374,7 +371,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"stream_ordering": row[2], "stream_ordering": row[2],
"actions": _deserialize_action(row[3], row[4]), "actions": _deserialize_action(row[3], row[4]),
"received_ts": row[5], "received_ts": row[5],
} for row in after_read_receipt + no_read_receipt }
for row in after_read_receipt + no_read_receipt
] ]
# Now sort it so it's ordered correctly, since currently it will # Now sort it so it's ordered correctly, since currently it will
@ -408,7 +406,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
LIMIT 1 LIMIT 1
""" """
txn.execute(sql, (user_id, min_stream_ordering,)) txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone()) return bool(txn.fetchone())
return self.runInteraction( return self.runInteraction(
@ -454,10 +452,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
""" """
txn.executemany(sql, ( txn.executemany(
sql,
(
_gen_entry(user_id, actions) _gen_entry(user_id, actions)
for user_id, actions in iteritems(user_id_actions) for user_id, actions in iteritems(user_id_actions)
)) ),
)
return self.runInteraction( return self.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn "add_push_actions_to_staging", _add_push_actions_to_staging_txn
@ -475,9 +476,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
try: try:
res = yield self._simple_delete( res = yield self._simple_delete(
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={ keyvalues={"event_id": event_id},
"event_id": event_id,
},
desc="remove_push_actions_from_staging", desc="remove_push_actions_from_staging",
) )
defer.returnValue(res) defer.returnValue(res)
@ -486,7 +485,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# another exception here really isn't helpful - there's nothing # another exception here really isn't helpful - there's nothing
# the caller can do about it. Just log the exception and move on. # the caller can do about it. Just log the exception and move on.
logger.exception( logger.exception(
"Error removing push actions after event persistence failure", "Error removing push actions after event persistence failure"
) )
def _find_stream_orderings_for_times(self): def _find_stream_orderings_for_times(self):
@ -503,16 +502,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
) )
logger.info( logger.info(
"Found stream ordering 1 month ago: it's %d", "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago
self.stream_ordering_month_ago
) )
logger.info("Searching for stream ordering 1 day ago") logger.info("Searching for stream ordering 1 day ago")
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
) )
logger.info( logger.info(
"Found stream ordering 1 day ago: it's %d", "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
self.stream_ordering_day_ago
) )
def find_first_stream_ordering_after_ts(self, ts): def find_first_stream_ordering_after_ts(self, ts):
@ -631,16 +628,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
index_name="event_push_actions_highlights_index", index_name="event_push_actions_highlights_index",
table="event_push_actions", table="event_push_actions",
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
where_clause="highlight=1" where_clause="highlight=1",
) )
self._doing_notif_rotation = False self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call( self._rotate_notif_loop = self._clock.looping_call(
self._start_rotate_notifs, 30 * 60 * 1000, self._start_rotate_notifs, 30 * 60 * 1000
) )
def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, def _set_push_actions_for_event_and_users_txn(
all_events_and_contexts): self, txn, events_and_contexts, all_events_and_contexts
):
"""Handles moving push actions from staging table to main """Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`. event_push_actions table for all events in `events_and_contexts`.
@ -667,43 +665,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
""" """
if events_and_contexts: if events_and_contexts:
txn.executemany(sql, ( txn.executemany(
sql,
( (
event.room_id, event.internal_metadata.stream_ordering, (
event.depth, event.event_id, event.room_id,
event.internal_metadata.stream_ordering,
event.depth,
event.event_id,
) )
for event, _ in events_and_contexts for event, _ in events_and_contexts
)) ),
)
for event, _ in events_and_contexts: for event, _ in events_and_contexts:
user_ids = self._simple_select_onecol_txn( user_ids = self._simple_select_onecol_txn(
txn, txn,
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={ keyvalues={"event_id": event.event_id},
"event_id": event.event_id,
},
retcol="user_id", retcol="user_id",
) )
for uid in user_ids: for uid in user_ids:
txn.call_after( txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid,) (event.room_id, uid),
) )
# Now we delete the staging area for *all* events that were being # Now we delete the staging area for *all* events that were being
# persisted. # persisted.
txn.executemany( txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?", "DELETE FROM event_push_actions_staging WHERE event_id = ?",
( ((event.event_id,) for event, _ in all_events_and_contexts),
(event.event_id,)
for event, _ in all_events_and_contexts
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50, def get_push_actions_for_user(
only_highlight=False): self, user_id, before=None, limit=50, only_highlight=False
):
def f(txn): def f(txn):
before_clause = "" before_clause = ""
if before: if before:
@ -727,15 +726,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" WHERE epa.event_id = e.event_id" " WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s" " AND epa.user_id = ? %s"
" ORDER BY epa.stream_ordering DESC" " ORDER BY epa.stream_ordering DESC"
" LIMIT ?" " LIMIT ?" % (before_clause,)
% (before_clause,)
) )
txn.execute(sql, args) txn.execute(sql, args)
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
push_actions = yield self.runInteraction( push_actions = yield self.runInteraction("get_push_actions_for_user", f)
"get_push_actions_for_user", f
)
for pa in push_actions: for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
defer.returnValue(push_actions) defer.returnValue(push_actions)
@ -753,6 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
) )
txn.execute(sql, (stream_ordering,)) txn.execute(sql, (stream_ordering,))
return txn.fetchone() return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f) result = yield self.runInteraction("get_time_of_last_push_action_before", f)
defer.returnValue(result[0] if result else None) defer.returnValue(result[0] if result else None)
@ -761,24 +758,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def f(txn): def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone() return txn.fetchone()
result = yield self.runInteraction(
"get_latest_push_action_stream_ordering", f result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
)
defer.returnValue(result[0] or 0) defer.returnValue(result[0] or 0)
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here # Sad that we have to blow away the cache for the whole room here
txn.call_after( txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id,) (room_id,),
) )
txn.execute( txn.execute(
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
(room_id, event_id) (room_id, event_id),
) )
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, def _remove_old_push_actions_before_txn(
stream_ordering): self, txn, room_id, user_id, stream_ordering
):
""" """
Purges old push actions for a user and room before a given Purges old push actions for a user and room before a given
stream_ordering. stream_ordering.
@ -795,7 +792,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
""" """
txn.call_after( txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id, ) (room_id, user_id),
) )
# We need to join on the events table to get the received_ts for # We need to join on the events table to get the received_ts for
@ -811,13 +808,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" WHERE user_id = ? AND room_id = ? AND " " WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?" " stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago) (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
) )
txn.execute(""" txn.execute(
"""
DELETE FROM event_push_summary DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""", (room_id, user_id, stream_ordering)) """,
(room_id, user_id, stream_ordering),
)
def _start_rotate_notifs(self): def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs) return run_as_background_process("rotate_notifs", self._rotate_notifs)
@ -833,8 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
logger.info("Rotating notifications") logger.info("Rotating notifications")
caught_up = yield self.runInteraction( caught_up = yield self.runInteraction(
"_rotate_notifs", "_rotate_notifs", self._rotate_notifs_txn
self._rotate_notifs_txn
) )
if caught_up: if caught_up:
break break
@ -856,11 +855,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# We don't to try and rotate millions of rows at once, so we cap the # We don't to try and rotate millions of rows at once, so we cap the
# maximum stream ordering we'll rotate before. # maximum stream ordering we'll rotate before.
txn.execute(""" txn.execute(
"""
SELECT stream_ordering FROM event_push_actions SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ? WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
""", (old_rotate_stream_ordering, self._rotate_count)) """,
(old_rotate_stream_ordering, self._rotate_count),
)
stream_row = txn.fetchone() stream_row = txn.fetchone()
if stream_row: if stream_row:
offset_stream_ordering, = stream_row offset_stream_ordering, = stream_row
@ -904,7 +906,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
LEFT JOIN event_push_summary AS old USING (user_id, room_id) LEFT JOIN event_push_summary AS old USING (user_id, room_id)
""" """
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,)) txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
rows = txn.fetchall() rows = txn.fetchall()
logger.info("Rotating notifications, handling %d rows", len(rows)) logger.info("Rotating notifications, handling %d rows", len(rows))
@ -922,8 +924,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"notif_count": row[2], "notif_count": row[2],
"stream_ordering": row[3], "stream_ordering": row[3],
} }
for row in rows if row[4] is None for row in rows
] if row[4] is None
],
) )
txn.executemany( txn.executemany(
@ -931,20 +934,20 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ? WHERE user_id = ? AND room_id = ?
""", """,
((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None) ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
) )
txn.execute( txn.execute(
"DELETE FROM event_push_actions" "DELETE FROM event_push_actions"
" WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
(old_rotate_stream_ordering, rotate_to_stream_ordering,) (old_rotate_stream_ordering, rotate_to_stream_ordering),
) )
logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
txn.execute( txn.execute(
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
(rotate_to_stream_ordering,) (rotate_to_stream_ordering,),
) )

View File

@ -71,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore):
""" """
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="events", table="events",
keyvalues={ keyvalues={"event_id": event_id},
"event_id": event_id,
},
retcol="received_ts", retcol="received_ts",
desc="get_received_ts", desc="get_received_ts",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(
get_prev_content=False, allow_rejected=False, self,
allow_none=False, check_room_id=None): event_id,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
allow_none=False,
check_room_id=None,
):
"""Get an event from the database by event_id. """Get an event from the database by event_id.
Args: Args:
@ -118,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True, def get_events(
get_prev_content=False, allow_rejected=False): self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
"""Get events from the database """Get events from the database
Args: Args:
@ -143,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True, def _get_events(
get_prev_content=False, allow_rejected=False): self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
if not event_ids: if not event_ids:
defer.returnValue([]) defer.returnValue([])
@ -152,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids = set(event_ids) event_ids = set(event_ids)
event_entry_map = self._get_events_from_cache( event_entry_map = self._get_events_from_cache(
event_ids, event_ids, allow_rejected=allow_rejected
allow_rejected=allow_rejected,
) )
missing_events_ids = [e for e in event_ids if e not in event_entry_map] missing_events_ids = [e for e in event_ids if e not in event_entry_map]
@ -169,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore):
# #
# _enqueue_events is a bit of a rubbish name but naming is hard. # _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events( missing_events = yield self._enqueue_events(
missing_events_ids, missing_events_ids, allow_rejected=allow_rejected
allow_rejected=allow_rejected,
) )
event_entry_map.update(missing_events) event_entry_map.update(missing_events)
@ -214,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore):
) )
expected_domain = get_domain_from_id(entry.event.sender) expected_domain = get_domain_from_id(entry.event.sender)
if orig_sender and get_domain_from_id(orig_sender) == expected_domain: if (
orig_sender
and get_domain_from_id(orig_sender) == expected_domain
):
# This redaction event is allowed. Mark as not needing a # This redaction event is allowed. Mark as not needing a
# recheck. # recheck.
entry.event.internal_metadata.recheck_redaction = False entry.event.internal_metadata.recheck_redaction = False
@ -267,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events: for event_id in events:
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
(event_id,), None, (event_id,), None, update_metrics=update_metrics
update_metrics=update_metrics,
) )
if not ret: if not ret:
continue continue
@ -318,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore):
with Measure(self._clock, "_fetch_event_list"): with Measure(self._clock, "_fetch_event_list"):
try: try:
event_id_lists = list(zip(*event_list))[0] event_id_lists = list(zip(*event_list))[0]
event_ids = [ event_ids = [item for sublist in event_id_lists for item in sublist]
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction( rows = self._new_transaction(
conn, "do_fetch", [], [], conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
self._fetch_event_rows, event_ids,
) )
row_dict = { row_dict = {r["event_id"]: r for r in rows}
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread # We only want to resolve deferreds from the main thread
def fire(lst, res): def fire(lst, res):
@ -338,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore):
if not d.called: if not d.called:
try: try:
with PreserveLoggingContext(): with PreserveLoggingContext():
d.callback([ d.callback([res[i] for i in ids if i in res])
res[i]
for i in ids
if i in res
])
except Exception: except Exception:
logger.exception("Failed to callback") logger.exception("Failed to callback")
with PreserveLoggingContext(): with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, row_dict) self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e: except Exception as e:
@ -371,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore):
events_d = defer.Deferred() events_d = defer.Deferred()
with self._event_fetch_lock: with self._event_fetch_lock:
self._event_fetch_list.append( self._event_fetch_list.append((events, events_d))
(events, events_d)
)
self._event_fetch_lock.notify() self._event_fetch_lock.notify()
@ -385,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start: if should_start:
run_as_background_process( run_as_background_process(
"fetch_events", "fetch_events", self.runWithConnection, self._do_fetch
self.runWithConnection,
self._do_fetch,
) )
logger.debug("Loading %d events", len(events)) logger.debug("Loading %d events", len(events))
@ -398,23 +399,24 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
res = yield make_deferred_yieldable(defer.gatherResults( res = yield make_deferred_yieldable(
defer.gatherResults(
[ [
run_in_background( run_in_background(
self._get_event_from_row, self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"],
row["json"],
row["redacts"],
rejected_reason=row["rejects"], rejected_reason=row["rejects"],
format_version=row["format_version"], format_version=row["format_version"],
) )
for row in rows for row in rows
], ],
consumeErrors=True consumeErrors=True,
)) )
)
defer.returnValue({ defer.returnValue({e.event.event_id: e for e in res if e})
e.event.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events): def _fetch_event_rows(self, txn, events):
rows = [] rows = []
@ -444,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore):
return rows return rows
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted, def _get_event_from_row(
format_version, rejected_reason=None): self, internal_metadata, js, redacted, format_version, rejected_reason=None
):
with Measure(self._clock, "_get_event_from_row"): with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js) d = json.loads(js)
internal_metadata = json.loads(internal_metadata) internal_metadata = json.loads(internal_metadata)
@ -484,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore):
# Get the redaction event. # Get the redaction event.
because = yield self.get_event( because = yield self.get_event(
redaction_id, redaction_id, check_redacted=False, allow_none=True
check_redacted=False,
allow_none=True,
) )
if because: if because:
@ -508,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore):
redacted_event = None redacted_event = None
cache_entry = _EventCacheEntry( cache_entry = _EventCacheEntry(
event=original_ev, event=original_ev, redacted_event=redacted_event
redacted_event=redacted_event,
) )
self._get_event_cache.prefill((original_ev.event_id,), cache_entry) self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
@ -545,9 +545,8 @@ class EventsWorkerStore(SQLBaseStore):
results = set() results = set()
def have_seen_events_txn(txn, chunk): def have_seen_events_txn(txn, chunk):
sql = ( sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)" ",".join("?" * len(chunk)),
% (",".join("?" * len(chunk)), )
) )
txn.execute(sql, chunk) txn.execute(sql, chunk)
for (event_id,) in txn: for (event_id,) in txn:
@ -555,13 +554,8 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100 # break the input up into chunks of 100
input_iterator = iter(event_ids) input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
[]): yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
defer.returnValue(results) defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids): def get_seen_events_with_rejections(self, event_ids):

View File

@ -35,10 +35,7 @@ class FilteringStore(SQLBaseStore):
def_json = yield self._simple_select_one_onecol( def_json = yield self._simple_select_one_onecol(
table="user_filters", table="user_filters",
keyvalues={ keyvalues={"user_id": user_localpart, "filter_id": filter_id},
"user_id": user_localpart,
"filter_id": filter_id,
},
retcol="filter_json", retcol="filter_json",
allow_none=False, allow_none=False,
desc="get_user_filter", desc="get_user_filter",
@ -61,10 +58,7 @@ class FilteringStore(SQLBaseStore):
if filter_id_response is not None: if filter_id_response is not None:
return filter_id_response[0] return filter_id_response[0]
sql = ( sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
txn.execute(sql, (user_localpart,)) txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0] max_id = txn.fetchone()[0]
if max_id is None: if max_id is None:

View File

@ -38,24 +38,22 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_update_one( return self._simple_update_one(
table="groups", table="groups",
keyvalues={ keyvalues={"group_id": group_id},
"group_id": group_id, updatevalues={"join_policy": join_policy},
},
updatevalues={
"join_policy": join_policy,
},
desc="set_group_join_policy", desc="set_group_join_policy",
) )
def get_group(self, group_id): def get_group(self, group_id):
return self._simple_select_one( return self._simple_select_one(
table="groups", table="groups",
keyvalues={ keyvalues={"group_id": group_id},
"group_id": group_id,
},
retcols=( retcols=(
"name", "short_description", "long_description", "name",
"avatar_url", "is_public", "join_policy", "short_description",
"long_description",
"avatar_url",
"is_public",
"join_policy",
), ),
allow_none=True, allow_none=True,
desc="get_group", desc="get_group",
@ -64,16 +62,14 @@ class GroupServerStore(SQLBaseStore):
def get_users_in_group(self, group_id, include_private=False): def get_users_in_group(self, group_id, include_private=False):
# TODO: Pagination # TODO: Pagination
keyvalues = { keyvalues = {"group_id": group_id}
"group_id": group_id,
}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
return self._simple_select_list( return self._simple_select_list(
table="group_users", table="group_users",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin",), retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group", desc="get_users_in_group",
) )
@ -82,9 +78,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_select_onecol( return self._simple_select_onecol(
table="group_invites", table="group_invites",
keyvalues={ keyvalues={"group_id": group_id},
"group_id": group_id,
},
retcol="user_id", retcol="user_id",
desc="get_invited_users_in_group", desc="get_invited_users_in_group",
) )
@ -92,16 +86,14 @@ class GroupServerStore(SQLBaseStore):
def get_rooms_in_group(self, group_id, include_private=False): def get_rooms_in_group(self, group_id, include_private=False):
# TODO: Pagination # TODO: Pagination
keyvalues = { keyvalues = {"group_id": group_id}
"group_id": group_id,
}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
return self._simple_select_list( return self._simple_select_list(
table="group_rooms", table="group_rooms",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("room_id", "is_public",), retcols=("room_id", "is_public"),
desc="get_rooms_in_group", desc="get_rooms_in_group",
) )
@ -110,10 +102,9 @@ class GroupServerStore(SQLBaseStore):
Returns ([rooms], [categories]) Returns ([rooms], [categories])
""" """
def _get_rooms_for_summary_txn(txn): def _get_rooms_for_summary_txn(txn):
keyvalues = { keyvalues = {"group_id": group_id}
"group_id": group_id,
}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -162,18 +153,23 @@ class GroupServerStore(SQLBaseStore):
} }
return rooms, categories return rooms, categories
return self.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
)
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
return self.runInteraction( return self.runInteraction(
"add_room_to_summary", self._add_room_to_summary_txn, "add_room_to_summary",
group_id, room_id, category_id, order, is_public, self._add_room_to_summary_txn,
group_id,
room_id,
category_id,
order,
is_public,
) )
def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order, def _add_room_to_summary_txn(
is_public): self, txn, group_id, room_id, category_id, order, is_public
):
"""Add (or update) room's entry in summary. """Add (or update) room's entry in summary.
Args: Args:
@ -188,10 +184,7 @@ class GroupServerStore(SQLBaseStore):
room_in_group = self._simple_select_one_onecol_txn( room_in_group = self._simple_select_one_onecol_txn(
txn, txn,
table="group_rooms", table="group_rooms",
keyvalues={ keyvalues={"group_id": group_id, "room_id": room_id},
"group_id": group_id,
"room_id": room_id,
},
retcol="room_id", retcol="room_id",
allow_none=True, allow_none=True,
) )
@ -204,10 +197,7 @@ class GroupServerStore(SQLBaseStore):
cat_exists = self._simple_select_one_onecol_txn( cat_exists = self._simple_select_one_onecol_txn(
txn, txn,
table="group_room_categories", table="group_room_categories",
keyvalues={ keyvalues={"group_id": group_id, "category_id": category_id},
"group_id": group_id,
"category_id": category_id,
},
retcol="group_id", retcol="group_id",
allow_none=True, allow_none=True,
) )
@ -218,22 +208,22 @@ class GroupServerStore(SQLBaseStore):
cat_exists = self._simple_select_one_onecol_txn( cat_exists = self._simple_select_one_onecol_txn(
txn, txn,
table="group_summary_room_categories", table="group_summary_room_categories",
keyvalues={ keyvalues={"group_id": group_id, "category_id": category_id},
"group_id": group_id,
"category_id": category_id,
},
retcol="group_id", retcol="group_id",
allow_none=True, allow_none=True,
) )
if not cat_exists: if not cat_exists:
# If not, add it with an order larger than all others # If not, add it with an order larger than all others
txn.execute(""" txn.execute(
"""
INSERT INTO group_summary_room_categories INSERT INTO group_summary_room_categories
(group_id, category_id, cat_order) (group_id, category_id, cat_order)
SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
FROM group_summary_room_categories FROM group_summary_room_categories
WHERE group_id = ? AND category_id = ? WHERE group_id = ? AND category_id = ?
""", (group_id, category_id, group_id, category_id)) """,
(group_id, category_id, group_id, category_id),
)
existing = self._simple_select_one_txn( existing = self._simple_select_one_txn(
txn, txn,
@ -243,7 +233,7 @@ class GroupServerStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
"category_id": category_id, "category_id": category_id,
}, },
retcols=("room_order", "is_public",), retcols=("room_order", "is_public"),
allow_none=True, allow_none=True,
) )
@ -253,13 +243,13 @@ class GroupServerStore(SQLBaseStore):
UPDATE group_summary_rooms SET room_order = room_order + 1 UPDATE group_summary_rooms SET room_order = room_order + 1
WHERE group_id = ? AND category_id = ? AND room_order >= ? WHERE group_id = ? AND category_id = ? AND room_order >= ?
""" """
txn.execute(sql, (group_id, category_id, order,)) txn.execute(sql, (group_id, category_id, order))
elif not existing: elif not existing:
sql = """ sql = """
SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
WHERE group_id = ? AND category_id = ? WHERE group_id = ? AND category_id = ?
""" """
txn.execute(sql, (group_id, category_id,)) txn.execute(sql, (group_id, category_id))
order, = txn.fetchone() order, = txn.fetchone()
if existing: if existing:
@ -312,29 +302,26 @@ class GroupServerStore(SQLBaseStore):
def get_group_categories(self, group_id): def get_group_categories(self, group_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="group_room_categories", table="group_room_categories",
keyvalues={ keyvalues={"group_id": group_id},
"group_id": group_id,
},
retcols=("category_id", "is_public", "profile"), retcols=("category_id", "is_public", "profile"),
desc="get_group_categories", desc="get_group_categories",
) )
defer.returnValue({ defer.returnValue(
{
row["category_id"]: { row["category_id"]: {
"is_public": row["is_public"], "is_public": row["is_public"],
"profile": json.loads(row["profile"]), "profile": json.loads(row["profile"]),
} }
for row in rows for row in rows
}) }
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_category(self, group_id, category_id): def get_group_category(self, group_id, category_id):
category = yield self._simple_select_one( category = yield self._simple_select_one(
table="group_room_categories", table="group_room_categories",
keyvalues={ keyvalues={"group_id": group_id, "category_id": category_id},
"group_id": group_id,
"category_id": category_id,
},
retcols=("is_public", "profile"), retcols=("is_public", "profile"),
desc="get_group_category", desc="get_group_category",
) )
@ -361,10 +348,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_upsert( return self._simple_upsert(
table="group_room_categories", table="group_room_categories",
keyvalues={ keyvalues={"group_id": group_id, "category_id": category_id},
"group_id": group_id,
"category_id": category_id,
},
values=update_values, values=update_values,
insertion_values=insertion_values, insertion_values=insertion_values,
desc="upsert_group_category", desc="upsert_group_category",
@ -373,10 +357,7 @@ class GroupServerStore(SQLBaseStore):
def remove_group_category(self, group_id, category_id): def remove_group_category(self, group_id, category_id):
return self._simple_delete( return self._simple_delete(
table="group_room_categories", table="group_room_categories",
keyvalues={ keyvalues={"group_id": group_id, "category_id": category_id},
"group_id": group_id,
"category_id": category_id,
},
desc="remove_group_category", desc="remove_group_category",
) )
@ -384,29 +365,26 @@ class GroupServerStore(SQLBaseStore):
def get_group_roles(self, group_id): def get_group_roles(self, group_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="group_roles", table="group_roles",
keyvalues={ keyvalues={"group_id": group_id},
"group_id": group_id,
},
retcols=("role_id", "is_public", "profile"), retcols=("role_id", "is_public", "profile"),
desc="get_group_roles", desc="get_group_roles",
) )
defer.returnValue({ defer.returnValue(
{
row["role_id"]: { row["role_id"]: {
"is_public": row["is_public"], "is_public": row["is_public"],
"profile": json.loads(row["profile"]), "profile": json.loads(row["profile"]),
} }
for row in rows for row in rows
}) }
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_role(self, group_id, role_id): def get_group_role(self, group_id, role_id):
role = yield self._simple_select_one( role = yield self._simple_select_one(
table="group_roles", table="group_roles",
keyvalues={ keyvalues={"group_id": group_id, "role_id": role_id},
"group_id": group_id,
"role_id": role_id,
},
retcols=("is_public", "profile"), retcols=("is_public", "profile"),
desc="get_group_role", desc="get_group_role",
) )
@ -433,10 +411,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_upsert( return self._simple_upsert(
table="group_roles", table="group_roles",
keyvalues={ keyvalues={"group_id": group_id, "role_id": role_id},
"group_id": group_id,
"role_id": role_id,
},
values=update_values, values=update_values,
insertion_values=insertion_values, insertion_values=insertion_values,
desc="upsert_group_role", desc="upsert_group_role",
@ -445,21 +420,24 @@ class GroupServerStore(SQLBaseStore):
def remove_group_role(self, group_id, role_id): def remove_group_role(self, group_id, role_id):
return self._simple_delete( return self._simple_delete(
table="group_roles", table="group_roles",
keyvalues={ keyvalues={"group_id": group_id, "role_id": role_id},
"group_id": group_id,
"role_id": role_id,
},
desc="remove_group_role", desc="remove_group_role",
) )
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
return self.runInteraction( return self.runInteraction(
"add_user_to_summary", self._add_user_to_summary_txn, "add_user_to_summary",
group_id, user_id, role_id, order, is_public, self._add_user_to_summary_txn,
group_id,
user_id,
role_id,
order,
is_public,
) )
def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order, def _add_user_to_summary_txn(
is_public): self, txn, group_id, user_id, role_id, order, is_public
):
"""Add (or update) user's entry in summary. """Add (or update) user's entry in summary.
Args: Args:
@ -474,10 +452,7 @@ class GroupServerStore(SQLBaseStore):
user_in_group = self._simple_select_one_onecol_txn( user_in_group = self._simple_select_one_onecol_txn(
txn, txn,
table="group_users", table="group_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcol="user_id", retcol="user_id",
allow_none=True, allow_none=True,
) )
@ -490,10 +465,7 @@ class GroupServerStore(SQLBaseStore):
role_exists = self._simple_select_one_onecol_txn( role_exists = self._simple_select_one_onecol_txn(
txn, txn,
table="group_roles", table="group_roles",
keyvalues={ keyvalues={"group_id": group_id, "role_id": role_id},
"group_id": group_id,
"role_id": role_id,
},
retcol="group_id", retcol="group_id",
allow_none=True, allow_none=True,
) )
@ -504,32 +476,28 @@ class GroupServerStore(SQLBaseStore):
role_exists = self._simple_select_one_onecol_txn( role_exists = self._simple_select_one_onecol_txn(
txn, txn,
table="group_summary_roles", table="group_summary_roles",
keyvalues={ keyvalues={"group_id": group_id, "role_id": role_id},
"group_id": group_id,
"role_id": role_id,
},
retcol="group_id", retcol="group_id",
allow_none=True, allow_none=True,
) )
if not role_exists: if not role_exists:
# If not, add it with an order larger than all others # If not, add it with an order larger than all others
txn.execute(""" txn.execute(
"""
INSERT INTO group_summary_roles INSERT INTO group_summary_roles
(group_id, role_id, role_order) (group_id, role_id, role_order)
SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
FROM group_summary_roles FROM group_summary_roles
WHERE group_id = ? AND role_id = ? WHERE group_id = ? AND role_id = ?
""", (group_id, role_id, group_id, role_id)) """,
(group_id, role_id, group_id, role_id),
)
existing = self._simple_select_one_txn( existing = self._simple_select_one_txn(
txn, txn,
table="group_summary_users", table="group_summary_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
"group_id": group_id, retcols=("user_order", "is_public"),
"user_id": user_id,
"role_id": role_id,
},
retcols=("user_order", "is_public",),
allow_none=True, allow_none=True,
) )
@ -539,13 +507,13 @@ class GroupServerStore(SQLBaseStore):
UPDATE group_summary_users SET user_order = user_order + 1 UPDATE group_summary_users SET user_order = user_order + 1
WHERE group_id = ? AND role_id = ? AND user_order >= ? WHERE group_id = ? AND role_id = ? AND user_order >= ?
""" """
txn.execute(sql, (group_id, role_id, order,)) txn.execute(sql, (group_id, role_id, order))
elif not existing: elif not existing:
sql = """ sql = """
SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
WHERE group_id = ? AND role_id = ? WHERE group_id = ? AND role_id = ?
""" """
txn.execute(sql, (group_id, role_id,)) txn.execute(sql, (group_id, role_id))
order, = txn.fetchone() order, = txn.fetchone()
if existing: if existing:
@ -586,11 +554,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_delete( return self._simple_delete(
table="group_summary_users", table="group_summary_users",
keyvalues={ keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
"group_id": group_id,
"role_id": role_id,
"user_id": user_id,
},
desc="remove_user_from_summary", desc="remove_user_from_summary",
) )
@ -599,10 +563,9 @@ class GroupServerStore(SQLBaseStore):
Returns ([users], [roles]) Returns ([users], [roles])
""" """
def _get_users_for_summary_txn(txn): def _get_users_for_summary_txn(txn):
keyvalues = { keyvalues = {"group_id": group_id}
"group_id": group_id,
}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -651,6 +614,7 @@ class GroupServerStore(SQLBaseStore):
} }
return users, roles return users, roles
return self.runInteraction( return self.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn "get_users_for_summary_by_role", _get_users_for_summary_txn
) )
@ -658,10 +622,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_in_group(self, user_id, group_id): def is_user_in_group(self, user_id, group_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcol="user_id", retcol="user_id",
allow_none=True, allow_none=True,
desc="is_user_in_group", desc="is_user_in_group",
@ -670,10 +631,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_admin_in_group(self, group_id, user_id): def is_user_admin_in_group(self, group_id, user_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcol="is_admin", retcol="is_admin",
allow_none=True, allow_none=True,
desc="is_user_admin_in_group", desc="is_user_admin_in_group",
@ -684,10 +642,7 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_insert( return self._simple_insert(
table="group_invites", table="group_invites",
values={ values={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
desc="add_group_invite", desc="add_group_invite",
) )
@ -696,10 +651,7 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="group_invites", table="group_invites",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcol="user_id", retcol="user_id",
desc="is_user_invited_to_local_group", desc="is_user_invited_to_local_group",
allow_none=True, allow_none=True,
@ -718,14 +670,12 @@ class GroupServerStore(SQLBaseStore):
Returns an empty dict if the user is not join/invite/etc Returns an empty dict if the user is not join/invite/etc
""" """
def _get_users_membership_in_group_txn(txn): def _get_users_membership_in_group_txn(txn):
row = self._simple_select_one_txn( row = self._simple_select_one_txn(
txn, txn,
table="group_users", table="group_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcols=("is_admin", "is_public"), retcols=("is_admin", "is_public"),
allow_none=True, allow_none=True,
) )
@ -740,27 +690,29 @@ class GroupServerStore(SQLBaseStore):
row = self._simple_select_one_onecol_txn( row = self._simple_select_one_onecol_txn(
txn, txn,
table="group_invites", table="group_invites",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcol="user_id", retcol="user_id",
allow_none=True, allow_none=True,
) )
if row: if row:
return { return {"membership": "invite"}
"membership": "invite",
}
return {} return {}
return self.runInteraction( return self.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn, "get_users_membership_info_in_group", _get_users_membership_in_group_txn
) )
def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True, def add_user_to_group(
local_attestation=None, remote_attestation=None): self,
group_id,
user_id,
is_admin=False,
is_public=True,
local_attestation=None,
remote_attestation=None,
):
"""Add a user to the group server. """Add a user to the group server.
Args: Args:
@ -774,6 +726,7 @@ class GroupServerStore(SQLBaseStore):
remote_attestation (dict): The attestation given to GS by remote remote_attestation (dict): The attestation given to GS by remote
server. Optional if the user and group are on the same server server. Optional if the user and group are on the same server
""" """
def _add_user_to_group_txn(txn): def _add_user_to_group_txn(txn):
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -789,10 +742,7 @@ class GroupServerStore(SQLBaseStore):
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_invites", table="group_invites",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
if local_attestation: if local_attestation:
@ -817,75 +767,52 @@ class GroupServerStore(SQLBaseStore):
}, },
) )
return self.runInteraction( return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
"add_user_to_group", _add_user_to_group_txn
)
def remove_user_from_group(self, group_id, user_id): def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn): def _remove_user_from_group_txn(txn):
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_users", table="group_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_invites", table="group_invites",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_summary_users", table="group_summary_users",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id, )
"user_id": user_id,
}, return self.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
) )
return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
def add_room_to_group(self, group_id, room_id, is_public): def add_room_to_group(self, group_id, room_id, is_public):
return self._simple_insert( return self._simple_insert(
table="group_rooms", table="group_rooms",
values={ values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
"group_id": group_id,
"room_id": room_id,
"is_public": is_public,
},
desc="add_room_to_group", desc="add_room_to_group",
) )
def update_room_in_group_visibility(self, group_id, room_id, is_public): def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self._simple_update( return self._simple_update(
table="group_rooms", table="group_rooms",
keyvalues={ keyvalues={"group_id": group_id, "room_id": room_id},
"group_id": group_id, updatevalues={"is_public": is_public},
"room_id": room_id,
},
updatevalues={
"is_public": is_public,
},
desc="update_room_in_group_visibility", desc="update_room_in_group_visibility",
) )
@ -894,22 +821,17 @@ class GroupServerStore(SQLBaseStore):
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_rooms", table="group_rooms",
keyvalues={ keyvalues={"group_id": group_id, "room_id": room_id},
"group_id": group_id,
"room_id": room_id,
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_summary_rooms", table="group_summary_rooms",
keyvalues={ keyvalues={"group_id": group_id, "room_id": room_id},
"group_id": group_id,
"room_id": room_id,
},
) )
return self.runInteraction( return self.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn, "remove_room_from_group", _remove_room_from_group_txn
) )
def get_publicised_groups_for_user(self, user_id): def get_publicised_groups_for_user(self, user_id):
@ -917,11 +839,7 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_select_onecol( return self._simple_select_onecol(
table="local_group_membership", table="local_group_membership",
keyvalues={ keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
"user_id": user_id,
"membership": "join",
"is_publicised": True,
},
retcol="group_id", retcol="group_id",
desc="get_publicised_groups_for_user", desc="get_publicised_groups_for_user",
) )
@ -931,19 +849,19 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_update_one( return self._simple_update_one(
table="local_group_membership", table="local_group_membership",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id, updatevalues={"is_publicised": publicise},
"user_id": user_id, desc="update_group_publicity",
},
updatevalues={
"is_publicised": publicise,
},
desc="update_group_publicity"
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def register_user_group_membership(self, group_id, user_id, membership, def register_user_group_membership(
is_admin=False, content={}, self,
group_id,
user_id,
membership,
is_admin=False,
content={},
local_attestation=None, local_attestation=None,
remote_attestation=None, remote_attestation=None,
is_publicised=False, is_publicised=False,
@ -962,15 +880,13 @@ class GroupServerStore(SQLBaseStore):
remote_attestation (dict): If remote group then store the remote remote_attestation (dict): If remote group then store the remote
attestation from the group, else None. attestation from the group, else None.
""" """
def _register_user_group_membership_txn(txn, next_id): def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert? # TODO: Upsert?
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="local_group_membership", table="local_group_membership",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -993,8 +909,10 @@ class GroupServerStore(SQLBaseStore):
"group_id": group_id, "group_id": group_id,
"user_id": user_id, "user_id": user_id,
"type": "membership", "type": "membership",
"content": json.dumps({"membership": membership, "content": content}), "content": json.dumps(
} {"membership": membership, "content": content}
),
},
) )
self._group_updates_stream_cache.entity_has_changed(user_id, next_id) self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
@ -1009,7 +927,7 @@ class GroupServerStore(SQLBaseStore):
"group_id": group_id, "group_id": group_id,
"user_id": user_id, "user_id": user_id,
"valid_until_ms": local_attestation["valid_until_ms"], "valid_until_ms": local_attestation["valid_until_ms"],
} },
) )
if remote_attestation: if remote_attestation:
self._simple_insert_txn( self._simple_insert_txn(
@ -1020,24 +938,18 @@ class GroupServerStore(SQLBaseStore):
"user_id": user_id, "user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"], "valid_until_ms": remote_attestation["valid_until_ms"],
"attestation_json": json.dumps(remote_attestation), "attestation_json": json.dumps(remote_attestation),
} },
) )
else: else:
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
) )
return next_id return next_id
@ -1045,13 +957,15 @@ class GroupServerStore(SQLBaseStore):
with self._group_updates_id_gen.get_next() as next_id: with self._group_updates_id_gen.get_next() as next_id:
res = yield self.runInteraction( res = yield self.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, next_id, _register_user_group_membership_txn,
next_id,
) )
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_group(self, group_id, user_id, name, avatar_url, short_description, def create_group(
long_description,): self, group_id, user_id, name, avatar_url, short_description, long_description
):
yield self._simple_insert( yield self._simple_insert(
table="groups", table="groups",
values={ values={
@ -1066,12 +980,10 @@ class GroupServerStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_group_profile(self, group_id, profile,): def update_group_profile(self, group_id, profile):
yield self._simple_update_one( yield self._simple_update_one(
table="groups", table="groups",
keyvalues={ keyvalues={"group_id": group_id},
"group_id": group_id,
},
updatevalues=profile, updatevalues=profile,
desc="update_group_profile", desc="update_group_profile",
) )
@ -1079,6 +991,7 @@ class GroupServerStore(SQLBaseStore):
def get_attestations_need_renewals(self, valid_until_ms): def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time """Get all attestations that need to be renewed until givent time
""" """
def _get_attestations_need_renewals_txn(txn): def _get_attestations_need_renewals_txn(txn):
sql = """ sql = """
SELECT group_id, user_id FROM group_attestations_renewals SELECT group_id, user_id FROM group_attestations_renewals
@ -1086,6 +999,7 @@ class GroupServerStore(SQLBaseStore):
""" """
txn.execute(sql, (valid_until_ms,)) txn.execute(sql, (valid_until_ms,))
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
return self.runInteraction( return self.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn "get_attestations_need_renewals", _get_attestations_need_renewals_txn
) )
@ -1095,13 +1009,8 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_update_one( return self._simple_update_one(
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id, updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
"user_id": user_id,
},
updatevalues={
"valid_until_ms": attestation["valid_until_ms"],
},
desc="update_attestation_renewal", desc="update_attestation_renewal",
) )
@ -1110,13 +1019,10 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_update_one( return self._simple_update_one(
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
updatevalues={ updatevalues={
"valid_until_ms": attestation["valid_until_ms"], "valid_until_ms": attestation["valid_until_ms"],
"attestation_json": json.dumps(attestation) "attestation_json": json.dumps(attestation),
}, },
desc="update_remote_attestion", desc="update_remote_attestion",
) )
@ -1132,10 +1038,7 @@ class GroupServerStore(SQLBaseStore):
""" """
return self._simple_delete( return self._simple_delete(
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
desc="remove_attestation_renewal", desc="remove_attestation_renewal",
) )
@ -1146,10 +1049,7 @@ class GroupServerStore(SQLBaseStore):
""" """
row = yield self._simple_select_one( row = yield self._simple_select_one(
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={ keyvalues={"group_id": group_id, "user_id": user_id},
"group_id": group_id,
"user_id": user_id,
},
retcols=("valid_until_ms", "attestation_json"), retcols=("valid_until_ms", "attestation_json"),
desc="get_remote_attestation", desc="get_remote_attestation",
allow_none=True, allow_none=True,
@ -1164,10 +1064,7 @@ class GroupServerStore(SQLBaseStore):
def get_joined_groups(self, user_id): def get_joined_groups(self, user_id):
return self._simple_select_onecol( return self._simple_select_onecol(
table="local_group_membership", table="local_group_membership",
keyvalues={ keyvalues={"user_id": user_id, "membership": "join"},
"user_id": user_id,
"membership": "join",
},
retcol="group_id", retcol="group_id",
desc="get_joined_groups", desc="get_joined_groups",
) )
@ -1181,7 +1078,7 @@ class GroupServerStore(SQLBaseStore):
WHERE user_id = ? AND membership != 'leave' WHERE user_id = ? AND membership != 'leave'
AND stream_id <= ? AND stream_id <= ?
""" """
txn.execute(sql, (user_id, now_token,)) txn.execute(sql, (user_id, now_token))
return [ return [
{ {
"group_id": row[0], "group_id": row[0],
@ -1191,14 +1088,15 @@ class GroupServerStore(SQLBaseStore):
} }
for row in txn for row in txn
] ]
return self.runInteraction( return self.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn, "get_all_groups_for_user", _get_all_groups_for_user_txn
) )
def get_groups_changes_for_user(self, user_id, from_token, to_token): def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token) from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed( has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token, user_id, from_token
) )
if not has_changed: if not has_changed:
return [] return []
@ -1210,21 +1108,25 @@ class GroupServerStore(SQLBaseStore):
INNER JOIN local_group_membership USING (group_id, user_id) INNER JOIN local_group_membership USING (group_id, user_id)
WHERE user_id = ? AND ? < stream_id AND stream_id <= ? WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
""" """
txn.execute(sql, (user_id, from_token, to_token,)) txn.execute(sql, (user_id, from_token, to_token))
return [{ return [
{
"group_id": group_id, "group_id": group_id,
"membership": membership, "membership": membership,
"type": gtype, "type": gtype,
"content": json.loads(content_json), "content": json.loads(content_json),
} for group_id, membership, gtype, content_json in txn] }
for group_id, membership, gtype, content_json in txn
]
return self.runInteraction( return self.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn, "get_groups_changes_for_user", _get_groups_changes_for_user_txn
) )
def get_all_groups_changes(self, from_token, to_token, limit): def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token) from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed( has_changed = self._group_updates_stream_cache.has_any_entity_changed(
from_token, from_token
) )
if not has_changed: if not has_changed:
return [] return []
@ -1236,16 +1138,14 @@ class GroupServerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (from_token, to_token, limit,)) txn.execute(sql, (from_token, to_token, limit))
return [( return [
stream_id, (stream_id, group_id, user_id, gtype, json.loads(content_json))
group_id, for stream_id, group_id, user_id, gtype, content_json in txn
user_id, ]
gtype,
json.loads(content_json),
) for stream_id, group_id, user_id, gtype, content_json in txn]
return self.runInteraction( return self.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn, "get_all_groups_changes", _get_all_groups_changes_txn
) )
def get_group_stream_token(self): def get_group_stream_token(self):

View File

@ -56,12 +56,13 @@ class KeyStore(SQLBaseStore):
desc="get_server_certificate", desc="get_server_certificate",
) )
tls_certificate = OpenSSL.crypto.load_certificate( tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes, OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes
) )
defer.returnValue(tls_certificate) defer.returnValue(tls_certificate)
def store_server_certificate(self, server_name, from_server, time_now_ms, def store_server_certificate(
tls_certificate): self, server_name, from_server, time_now_ms, tls_certificate
):
"""Stores the TLS X.509 certificate for the given server """Stores the TLS X.509 certificate for the given server
Args: Args:
server_name (str): The name of the server. server_name (str): The name of the server.
@ -75,10 +76,7 @@ class KeyStore(SQLBaseStore):
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest() fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
return self._simple_upsert( return self._simple_upsert(
table="server_tls_certificates", table="server_tls_certificates",
keyvalues={ keyvalues={"server_name": server_name, "fingerprint": fingerprint},
"server_name": server_name,
"fingerprint": fingerprint,
},
values={ values={
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
@ -91,19 +89,14 @@ class KeyStore(SQLBaseStore):
def _get_server_verify_key(self, server_name, key_id): def _get_server_verify_key(self, server_name, key_id):
verify_key_bytes = yield self._simple_select_one_onecol( verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={"server_name": server_name, "key_id": key_id},
"server_name": server_name,
"key_id": key_id,
},
retcol="verify_key", retcol="verify_key",
desc="_get_server_verify_key", desc="_get_server_verify_key",
allow_none=True, allow_none=True,
) )
if verify_key_bytes: if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes( defer.returnValue(decode_verify_key_bytes(key_id, bytes(verify_key_bytes)))
key_id, bytes(verify_key_bytes)
))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids): def get_server_verify_keys(self, server_name, key_ids):
@ -123,8 +116,9 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key keys[key_id] = key
defer.returnValue(keys) defer.returnValue(keys)
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(
verify_key): self, server_name, from_server, time_now_ms, verify_key
):
"""Stores a NACL verification key for the given server. """Stores a NACL verification key for the given server.
Args: Args:
server_name (str): The name of the server. server_name (str): The name of the server.
@ -139,10 +133,7 @@ class KeyStore(SQLBaseStore):
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={"server_name": server_name, "key_id": key_id},
"server_name": server_name,
"key_id": key_id,
},
values={ values={
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
@ -150,14 +141,14 @@ class KeyStore(SQLBaseStore):
}, },
) )
txn.call_after( txn.call_after(
self._get_server_verify_key.invalidate, self._get_server_verify_key.invalidate, (server_name, key_id)
(server_name, key_id)
) )
return self.runInteraction("store_server_verify_key", _txn) return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(
ts_now_ms, ts_expires_ms, key_json_bytes): self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
):
"""Stores the JSON bytes for a set of keys from a server """Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the server, and by this server. Updates the value for the
@ -200,6 +191,7 @@ class KeyStore(SQLBaseStore):
Dict mapping (server_name, key_id, source) triplets to dicts with Dict mapping (server_name, key_id, source) triplets to dicts with
"ts_valid_until_ms" and "key_json" keys. "ts_valid_until_ms" and "key_json" keys.
""" """
def _get_server_keys_json_txn(txn): def _get_server_keys_json_txn(txn):
results = {} results = {}
for server_name, key_id, from_server in server_keys: for server_name, key_id, from_server in server_keys:
@ -222,6 +214,5 @@ class KeyStore(SQLBaseStore):
) )
results[(server_name, key_id, from_server)] = rows results[(server_name, key_id, from_server)] = rows
return results return results
return self.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
)

View File

@ -38,15 +38,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"local_media_repository", "local_media_repository",
{"media_id": media_id}, {"media_id": media_id},
( (
"media_type", "media_length", "upload_name", "created_ts", "media_type",
"quarantined_by", "url_cache", "media_length",
"upload_name",
"created_ts",
"quarantined_by",
"url_cache",
), ),
allow_none=True, allow_none=True,
desc="get_local_media", desc="get_local_media",
) )
def store_local_media(self, media_id, media_type, time_now_ms, upload_name, def store_local_media(
media_length, user_id, url_cache=None): self,
media_id,
media_type,
time_now_ms,
upload_name,
media_length,
user_id,
url_cache=None,
):
return self._simple_insert( return self._simple_insert(
"local_media_repository", "local_media_repository",
{ {
@ -66,6 +78,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
Returns: Returns:
None if the URL isn't cached. None if the URL isn't cached.
""" """
def get_url_cache_txn(txn): def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts) # get the most recently cached result (relative to the given ts)
sql = ( sql = (
@ -92,16 +105,25 @@ class MediaRepositoryStore(BackgroundUpdateStore):
if not row: if not row:
return None return None
return dict(zip(( return dict(
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' zip(
), row)) (
'response_code',
return self.runInteraction( 'etag',
"get_url_cache", get_url_cache_txn 'expires_ts',
'og',
'media_id',
'download_ts',
),
row,
)
) )
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, return self.runInteraction("get_url_cache", get_url_cache_txn)
download_ts):
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
return self._simple_insert( return self._simple_insert(
"local_media_repository_url_cache", "local_media_repository_url_cache",
{ {
@ -121,15 +143,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{"media_id": media_id}, {"media_id": media_id},
( (
"thumbnail_width", "thumbnail_height", "thumbnail_method", "thumbnail_width",
"thumbnail_type", "thumbnail_length", "thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
), ),
desc="get_local_media_thumbnails", desc="get_local_media_thumbnails",
) )
def store_local_thumbnail(self, media_id, thumbnail_width, def store_local_thumbnail(
thumbnail_height, thumbnail_type, self,
thumbnail_method, thumbnail_length): media_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
return self._simple_insert( return self._simple_insert(
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{ {
@ -148,16 +179,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"remote_media_cache", "remote_media_cache",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
"media_type", "media_length", "upload_name", "created_ts", "media_type",
"filesystem_id", "quarantined_by", "media_length",
"upload_name",
"created_ts",
"filesystem_id",
"quarantined_by",
), ),
allow_none=True, allow_none=True,
desc="get_cached_remote_media", desc="get_cached_remote_media",
) )
def store_cached_remote_media(self, origin, media_id, media_type, def store_cached_remote_media(
media_length, time_now_ms, upload_name, self,
filesystem_id): origin,
media_id,
media_type,
media_length,
time_now_ms,
upload_name,
filesystem_id,
):
return self._simple_insert( return self._simple_insert(
"remote_media_cache", "remote_media_cache",
{ {
@ -181,26 +223,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
remote_media (iterable[(str, str)]): Set of (server_name, media_id) remote_media (iterable[(str, str)]): Set of (server_name, media_id)
time_ms: Current time in milliseconds time_ms: Current time in milliseconds
""" """
def update_cache_txn(txn): def update_cache_txn(txn):
sql = ( sql = (
"UPDATE remote_media_cache SET last_access_ts = ?" "UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?" " WHERE media_origin = ? AND media_id = ?"
) )
txn.executemany(sql, ( txn.executemany(
sql,
(
(time_ms, media_origin, media_id) (time_ms, media_origin, media_id)
for media_origin, media_id in remote_media for media_origin, media_id in remote_media
)) ),
)
sql = ( sql = (
"UPDATE local_media_repository SET last_access_ts = ?" "UPDATE local_media_repository SET last_access_ts = ?"
" WHERE media_id = ?" " WHERE media_id = ?"
) )
txn.executemany(sql, ( txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
(time_ms, media_id)
for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn) return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@ -209,16 +252,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
"thumbnail_width", "thumbnail_height", "thumbnail_method", "thumbnail_width",
"thumbnail_type", "thumbnail_length", "filesystem_id", "thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
"filesystem_id",
), ),
desc="get_remote_media_thumbnails", desc="get_remote_media_thumbnails",
) )
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id, def store_remote_media_thumbnail(
thumbnail_width, thumbnail_height, self,
thumbnail_type, thumbnail_method, origin,
thumbnail_length): media_id,
filesystem_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
return self._simple_insert( return self._simple_insert(
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{ {
@ -250,17 +304,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
"remote_media_cache", "remote_media_cache",
keyvalues={ keyvalues={"media_origin": media_origin, "media_id": media_id},
"media_origin": media_origin, "media_id": media_id
},
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
keyvalues={ keyvalues={"media_origin": media_origin, "media_id": media_id},
"media_origin": media_origin, "media_id": media_id
},
) )
return self.runInteraction("delete_remote_media", delete_remote_media_txn) return self.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts): def get_expired_url_cache(self, now_ts):
@ -281,10 +332,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
if len(media_ids) == 0: if len(media_ids) == 0:
return return
sql = ( sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
"DELETE FROM local_media_repository_url_cache"
" WHERE media_id = ?"
)
def _delete_url_cache_txn(txn): def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.executemany(sql, [(media_id,) for media_id in media_ids])
@ -304,7 +352,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return [row[0] for row in txn] return [row[0] for row in txn]
return self.runInteraction( return self.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn, "get_url_cache_media_before", _get_url_cache_media_before_txn
) )
def delete_url_cache_media(self, media_ids): def delete_url_cache_media(self, media_ids):
@ -312,20 +360,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return return
def _delete_url_cache_media_txn(txn): def _delete_url_cache_media_txn(txn):
sql = ( sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
"DELETE FROM local_media_repository"
" WHERE media_id = ?"
)
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.executemany(sql, [(media_id,) for media_id in media_ids])
sql = ( sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
"DELETE FROM local_media_repository_thumbnails"
" WHERE media_id = ?"
)
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction( return self.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn, "delete_url_cache_media", _delete_url_cache_media_txn
) )

View File

@ -35,7 +35,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self.reserved_users = () self.reserved_users = ()
# Do not add more reserved users than the total allowable number # Do not add more reserved users than the total allowable number
self._new_transaction( self._new_transaction(
dbconn, "initialise_mau_threepids", [], [], dbconn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users, self._initialise_reserved_users,
hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
) )
@ -51,10 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
reserved_user_list = [] reserved_user_list = []
for tp in threepids: for tp in threepids:
user_id = self.get_user_id_by_threepid_txn( user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
txn,
tp["medium"], tp["address"]
)
if user_id: if user_id:
is_support = self.is_support_user_txn(txn, user_id) is_support = self.is_support_user_txn(txn, user_id)
@ -62,9 +62,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self.upsert_monthly_active_user_txn(txn, user_id) self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id) reserved_user_list.append(user_id)
else: else:
logger.warning( logger.warning("mau limit reserved threepid %s not found in db" % tp)
"mau limit reserved threepid %s not found in db" % tp
)
self.reserved_users = tuple(reserved_user_list) self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -75,12 +73,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Returns: Returns:
Deferred[] Deferred[]
""" """
def _reap_users(txn): def _reap_users(txn):
# Purge stale users # Purge stale users
thirty_days_ago = ( thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
)
query_args = [thirty_days_ago] query_args = [thirty_days_ago]
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
@ -158,6 +155,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn.execute(sql) txn.execute(sql)
count, = txn.fetchone() count, = txn.fetchone()
return count return count
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -198,14 +196,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return return
yield self.runInteraction( yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
user_id
) )
user_in_mau = self.user_last_seen_monthly_active.cache.get( user_in_mau = self.user_last_seen_monthly_active.cache.get(
(user_id,), (user_id,), None, update_metrics=False
None,
update_metrics=False
) )
if user_in_mau is None: if user_in_mau is None:
self.get_monthly_active_count.invalidate(()) self.get_monthly_active_count.invalidate(())
@ -247,12 +242,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
is_insert = self._simple_upsert_txn( is_insert = self._simple_upsert_txn(
txn, txn,
table="monthly_active_users", table="monthly_active_users",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id, values={"timestamp": int(self._clock.time_msec())},
},
values={
"timestamp": int(self._clock.time_msec()),
},
) )
return is_insert return is_insert
@ -268,15 +259,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
""" """
return(self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="monthly_active_users", table="monthly_active_users",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id,
},
retcol="timestamp", retcol="timestamp",
allow_none=True, allow_none=True,
desc="user_last_seen_monthly_active", desc="user_last_seen_monthly_active",
)) )
@defer.inlineCallbacks @defer.inlineCallbacks
def populate_monthly_active_users(self, user_id): def populate_monthly_active_users(self, user_id):

View File

@ -10,7 +10,7 @@ class OpenIdStore(SQLBaseStore):
"ts_valid_until_ms": ts_valid_until_ms, "ts_valid_until_ms": ts_valid_until_ms,
"user_id": user_id, "user_id": user_id,
}, },
desc="insert_open_id_token" desc="insert_open_id_token",
) )
def get_user_id_for_open_id_token(self, token, ts_now_ms): def get_user_id_for_open_id_token(self, token, ts_now_ms):
@ -27,6 +27,5 @@ class OpenIdStore(SQLBaseStore):
return None return None
else: else:
return rows[0][0] return rows[0][0]
return self.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
)

View File

@ -143,10 +143,9 @@ def _setup_new_database(cur, database_engine):
cur.execute( cur.execute(
database_engine.convert_param_style( database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)" "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
" VALUES (?,?)"
), ),
(max_current_ver, False,) (max_current_ver, False),
) )
_upgrade_existing_database( _upgrade_existing_database(
@ -160,8 +159,15 @@ def _setup_new_database(cur, database_engine):
) )
def _upgrade_existing_database(cur, current_version, applied_delta_files, def _upgrade_existing_database(
upgraded, database_engine, config, is_empty=False): cur,
current_version,
applied_delta_files,
upgraded,
database_engine,
config,
is_empty=False,
):
"""Upgrades an existing database. """Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules Delta files can either be SQL stored in *.sql files, or python modules
@ -209,8 +215,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if current_version > SCHEMA_VERSION: if current_version > SCHEMA_VERSION:
raise ValueError( raise ValueError(
"Cannot use this database as it is too " + "Cannot use this database as it is too "
"new for the server to understand" + "new for the server to understand"
) )
start_ver = current_version start_ver = current_version
@ -239,20 +245,14 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if relative_path in applied_delta_files: if relative_path in applied_delta_files:
continue continue
absolute_path = os.path.join( absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
dir_path, "schema", "delta", relative_path,
)
root_name, ext = os.path.splitext(file_name) root_name, ext = os.path.splitext(file_name)
if ext == ".py": if ext == ".py":
# This is a python upgrade module. We need to import into some # This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function. # package and then execute its `run_upgrade` function.
module_name = "synapse.storage.v%d_%s" % ( module_name = "synapse.storage.v%d_%s" % (v, root_name)
v, root_name
)
with open(absolute_path) as python_file: with open(absolute_path) as python_file:
module = imp.load_source( module = imp.load_source(module_name, absolute_path, python_file)
module_name, absolute_path, python_file
)
logger.info("Running script %s", relative_path) logger.info("Running script %s", relative_path)
module.run_create(cur, database_engine) module.run_create(cur, database_engine)
if not is_empty: if not is_empty:
@ -269,8 +269,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
else: else:
# Not a valid delta file. # Not a valid delta file.
logger.warn( logger.warn(
"Found directory entry that did not end in .py or" "Found directory entry that did not end in .py or" " .sql: %s",
" .sql: %s",
relative_path, relative_path,
) )
continue continue
@ -278,19 +277,17 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done. # Mark as done.
cur.execute( cur.execute(
database_engine.convert_param_style( database_engine.convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)" "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
" VALUES (?,?)",
), ),
(v, relative_path) (v, relative_path),
) )
cur.execute("DELETE FROM schema_version") cur.execute("DELETE FROM schema_version")
cur.execute( cur.execute(
database_engine.convert_param_style( database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)" "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
" VALUES (?,?)",
), ),
(v, True) (v, True),
) )
@ -308,7 +305,7 @@ def _apply_module_schemas(txn, database_engine, config):
continue continue
modname = ".".join((mod.__module__, mod.__name__)) modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files( _apply_module_schema_files(
txn, database_engine, modname, mod.get_db_schema_files(), txn, database_engine, modname, mod.get_db_schema_files()
) )
@ -326,7 +323,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
database_engine.convert_param_style( database_engine.convert_param_style(
"SELECT file FROM applied_module_schemas WHERE module_name = ?" "SELECT file FROM applied_module_schemas WHERE module_name = ?"
), ),
(modname,) (modname,),
) )
applied_deltas = set(d for d, in cur) applied_deltas = set(d for d, in cur)
for (name, stream) in names_and_streams: for (name, stream) in names_and_streams:
@ -336,7 +333,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
root_name, ext = os.path.splitext(name) root_name, ext = os.path.splitext(name)
if ext != '.sql': if ext != '.sql':
raise PrepareDatabaseException( raise PrepareDatabaseException(
"only .sql files are currently supported for module schemas", "only .sql files are currently supported for module schemas"
) )
logger.info("applying schema %s for %s", name, modname) logger.info("applying schema %s for %s", name, modname)
@ -346,10 +343,9 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done. # Mark as done.
cur.execute( cur.execute(
database_engine.convert_param_style( database_engine.convert_param_style(
"INSERT INTO applied_module_schemas (module_name, file)" "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
" VALUES (?,?)",
), ),
(modname, name) (modname, name),
) )
@ -386,10 +382,7 @@ def get_statements(f):
statements = line.split(";") statements = line.split(";")
# We must prepend statement_buffer to the first statement # We must prepend statement_buffer to the first statement
first_statement = "%s %s" % ( first_statement = "%s %s" % (statement_buffer.strip(), statements[0].strip())
statement_buffer.strip(),
statements[0].strip()
)
statements[0] = first_statement statements[0] = first_statement
# Every entry, except the last, is a full statement # Every entry, except the last, is a full statement
@ -409,9 +402,7 @@ def executescript(txn, schema_path):
def _get_or_create_schema_state(txn, database_engine): def _get_or_create_schema_state(txn, database_engine):
# Bluntly try creating the schema_version tables. # Bluntly try creating the schema_version tables.
schema_path = os.path.join( schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
dir_path, "schema", "schema_version.sql",
)
executescript(txn, schema_path) executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version") txn.execute("SELECT version, upgraded FROM schema_version")
@ -424,7 +415,7 @@ def _get_or_create_schema_state(txn, database_engine):
database_engine.convert_param_style( database_engine.convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?" "SELECT file FROM applied_schema_deltas WHERE version >= ?"
), ),
(current_version,) (current_version,),
) )
applied_deltas = [d for d, in txn] applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded

View File

@ -24,10 +24,20 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cache
from ._base import SQLBaseStore from ._base import SQLBaseStore
class UserPresenceState(namedtuple("UserPresenceState", class UserPresenceState(
("user_id", "state", "last_active_ts", namedtuple(
"last_federation_update_ts", "last_user_sync_ts", "UserPresenceState",
"status_msg", "currently_active"))): (
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
):
"""Represents the current presence state of the user. """Represents the current presence state of the user.
user_id (str) user_id (str)
@ -75,22 +85,21 @@ class PresenceStore(SQLBaseStore):
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
yield self.runInteraction( yield self.runInteraction(
"update_presence", "update_presence",
self._update_presence_txn, stream_orderings, presence_states, self._update_presence_txn,
stream_orderings,
presence_states,
) )
defer.returnValue(( defer.returnValue(
stream_orderings[-1], self._presence_id_gen.get_current_token() (stream_orderings[-1], self._presence_id_gen.get_current_token())
)) )
def _update_presence_txn(self, txn, stream_orderings, presence_states): def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after( txn.call_after(
self.presence_stream_cache.entity_has_changed, self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
state.user_id, stream_id,
)
txn.call_after(
self._get_presence_for_user.invalidate, (state.user_id,)
) )
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows # Actually insert new rows
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -113,18 +122,13 @@ class PresenceStore(SQLBaseStore):
# Delete old rows to stop database from getting really big # Delete old rows to stop database from getting really big
sql = ( sql = (
"DELETE FROM presence_stream WHERE" "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
" stream_id < ?"
" AND user_id IN (%s)"
) )
for states in batch_iter(presence_states, 50): for states in batch_iter(presence_states, 50):
args = [stream_id] args = [stream_id]
args.extend(s.user_id for s in states) args.extend(s.user_id for s in states)
txn.execute( txn.execute(sql % (",".join("?" for _ in states),), args)
sql % (",".join("?" for _ in states),),
args
)
def get_all_presence_updates(self, last_id, current_id): def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id: if last_id == current_id:
@ -149,8 +153,12 @@ class PresenceStore(SQLBaseStore):
def _get_presence_for_user(self, user_id): def _get_presence_for_user(self, user_id):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids", @cachedList(
num_args=1, inlineCallbacks=True) cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids): def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="presence_stream", table="presence_stream",
@ -180,8 +188,10 @@ class PresenceStore(SQLBaseStore):
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(
table="presence_allow_inbound", table="presence_allow_inbound",
values={"observed_user_id": observed_localpart, values={
"observer_user_id": observer_userid}, "observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="allow_presence_visible", desc="allow_presence_visible",
or_ignore=True, or_ignore=True,
) )
@ -189,17 +199,21 @@ class PresenceStore(SQLBaseStore):
def disallow_presence_visible(self, observed_localpart, observer_userid): def disallow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_delete_one( return self._simple_delete_one(
table="presence_allow_inbound", table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart, keyvalues={
"observer_user_id": observer_userid}, "observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="disallow_presence_visible", desc="disallow_presence_visible",
) )
def add_presence_list_pending(self, observer_localpart, observed_userid): def add_presence_list_pending(self, observer_localpart, observed_userid):
return self._simple_insert( return self._simple_insert(
table="presence_list", table="presence_list",
values={"user_id": observer_localpart, values={
"user_id": observer_localpart,
"observed_user_id": observed_userid, "observed_user_id": observed_userid,
"accepted": False}, "accepted": False,
},
desc="add_presence_list_pending", desc="add_presence_list_pending",
) )
@ -210,7 +224,7 @@ class PresenceStore(SQLBaseStore):
table="presence_list", table="presence_list",
keyvalues={ keyvalues={
"user_id": observer_localpart, "user_id": observer_localpart,
"observed_user_id": observed_userid "observed_user_id": observed_userid,
}, },
updatevalues={"accepted": True}, updatevalues={"accepted": True},
) )
@ -225,7 +239,7 @@ class PresenceStore(SQLBaseStore):
return result return result
return self.runInteraction( return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn, "set_presence_list_accepted", update_presence_list_txn
) )
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -261,16 +275,16 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_list_accepted", desc="get_presence_list_accepted",
) )
defer.returnValue([ defer.returnValue(["@%s:%s" % (u, self.hs.hostname) for u in user_localparts])
"@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
])
@defer.inlineCallbacks @defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid): def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one( yield self._simple_delete_one(
table="presence_list", table="presence_list",
keyvalues={"user_id": observer_localpart, keyvalues={
"observed_user_id": observed_userid}, "user_id": observer_localpart,
"observed_user_id": observed_userid,
},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate((observer_localpart,)) self.get_presence_list_accepted.invalidate((observer_localpart,))

View File

@ -41,8 +41,7 @@ class ProfileWorkerStore(SQLBaseStore):
defer.returnValue( defer.returnValue(
ProfileInfo( ProfileInfo(
avatar_url=profile['avatar_url'], avatar_url=profile['avatar_url'], display_name=profile['displayname']
display_name=profile['displayname'],
) )
) )
@ -66,16 +65,14 @@ class ProfileWorkerStore(SQLBaseStore):
return self._simple_select_one( return self._simple_select_one(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url",), retcols=("displayname", "avatar_url"),
allow_none=True, allow_none=True,
desc="get_from_remote_profile_cache", desc="get_from_remote_profile_cache",
) )
def create_profile(self, user_localpart): def create_profile(self, user_localpart):
return self._simple_insert( return self._simple_insert(
table="profiles", table="profiles", values={"user_id": user_localpart}, desc="create_profile"
values={"user_id": user_localpart},
desc="create_profile",
) )
def set_profile_displayname(self, user_localpart, new_displayname): def set_profile_displayname(self, user_localpart, new_displayname):
@ -141,6 +138,7 @@ class ProfileStore(ProfileWorkerStore):
def get_remote_profile_cache_entries_that_expire(self, last_checked): def get_remote_profile_cache_entries_that_expire(self, last_checked):
"""Get all users who haven't been checked since `last_checked` """Get all users who haven't been checked since `last_checked`
""" """
def _get_remote_profile_cache_entries_that_expire_txn(txn): def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """ sql = """
SELECT user_id, displayname, avatar_url SELECT user_id, displayname, avatar_url

View File

@ -57,11 +57,13 @@ def _load_rules(rawrules, enabled_map):
return rules return rules
class PushRulesWorkerStore(ApplicationServiceWorkerStore, class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore, ReceiptsWorkerStore,
PusherWorkerStore, PusherWorkerStore,
RoomMemberWorkerStore, RoomMemberWorkerStore,
SQLBaseStore): SQLBaseStore,
):
"""This is an abstract base class where subclasses must implement """This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer. `get_max_push_rules_stream_id` which can be called in the initializer.
""" """
@ -74,14 +76,16 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
super(PushRulesWorkerStore, self).__init__(db_conn, hs) super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self._get_cache_dict( push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream", db_conn,
"push_rules_stream",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(), max_value=self.get_max_push_rules_stream_id(),
) )
self.push_rules_stream_cache = StreamChangeCache( self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id, "PushRulesStreamChangeCache",
push_rules_id,
prefilled_cache=push_rules_prefill, prefilled_cache=push_rules_prefill,
) )
@ -98,19 +102,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
keyvalues={ keyvalues={"user_name": user_id},
"user_name": user_id,
},
retcols=( retcols=(
"user_name", "rule_id", "priority_class", "priority", "user_name",
"conditions", "actions", "rule_id",
"priority_class",
"priority",
"conditions",
"actions",
), ),
desc="get_push_rules_enabled_for_user", desc="get_push_rules_enabled_for_user",
) )
rows.sort( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
enabled_map = yield self.get_push_rules_enabled_for_user(user_id) enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
@ -122,22 +126,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
keyvalues={ keyvalues={'user_name': user_id},
'user_name': user_id retcols=("user_name", "rule_id", "enabled"),
},
retcols=(
"user_name", "rule_id", "enabled",
),
desc="get_push_rules_enabled_for_user", desc="get_push_rules_enabled_for_user",
) )
defer.returnValue({ defer.returnValue(
r['rule_id']: False if r['enabled'] == 0 else True for r in results {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
}) )
def have_push_rules_changed_for_user(self, user_id, last_id): def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False) return defer.succeed(False)
else: else:
def have_push_rules_changed_txn(txn): def have_push_rules_changed_txn(txn):
sql = ( sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream" "SELECT COUNT(stream_id) FROM push_rules_stream"
@ -146,20 +147,22 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
txn.execute(sql, (user_id, last_id)) txn.execute(sql, (user_id, last_id))
count, = txn.fetchone() count, = txn.fetchone()
return bool(count) return bool(count)
return self.runInteraction( return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn "have_push_rules_changed", have_push_rules_changed_txn
) )
@cachedList(cached_method_name="get_push_rules_for_user", @cachedList(
list_name="user_ids", num_args=1, inlineCallbacks=True) cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules(self, user_ids): def bulk_get_push_rules(self, user_ids):
if not user_ids: if not user_ids:
defer.returnValue({}) defer.returnValue({})
results = { results = {user_id: [] for user_id in user_ids}
user_id: []
for user_id in user_ids
}
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="push_rules", table="push_rules",
@ -169,9 +172,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
desc="bulk_get_push_rules", desc="bulk_get_push_rules",
) )
rows.sort( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
for row in rows: for row in rows:
results.setdefault(row['user_name'], []).append(row) results.setdefault(row['user_name'], []).append(row)
@ -179,16 +180,12 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items(): for user_id, rules in results.items():
results[user_id] = _load_rules( results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
rules, enabled_map_by_user.get(user_id, {})
)
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def move_push_rule_from_room_to_room( def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
self, new_room_id, user_id, rule,
):
"""Move a single push rule from one room to another for a specific user. """Move a single push rule from one room to another for a specific user.
Args: Args:
@ -219,7 +216,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
@defer.inlineCallbacks @defer.inlineCallbacks
def move_push_rules_from_room_to_room_for_user( def move_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id, self, old_room_id, new_room_id, user_id
): ):
"""Move all of the push rules from one room to another for a specific """Move all of the push rules from one room to another for a specific
user. user.
@ -236,11 +233,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# delete them from the old room # delete them from the old room
for rule in user_push_rules: for rule in user_push_rules:
conditions = rule.get("conditions", []) conditions = rule.get("conditions", [])
if any((c.get("key") == "room_id" and if any(
c.get("pattern") == old_room_id) for c in conditions): (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
self.move_push_rule_from_room_to_room( for c in conditions
new_room_id, user_id, rule, ):
) self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context): def bulk_get_push_rules_for_room(self, event, context):
@ -259,8 +256,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(result) defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, def _bulk_get_push_rules_for_room(
cache_context, event=None): self, room_id, state_group, current_state_ids, cache_context, event=None
):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
@ -273,7 +271,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# sent a read receipt into the room. # sent a read receipt into the room.
users_in_room = yield self._get_joined_users_from_context( users_in_room = yield self._get_joined_users_from_context(
room_id, state_group, current_state_ids, room_id,
state_group,
current_state_ids,
on_invalidate=cache_context.invalidate, on_invalidate=cache_context.invalidate,
event=event, event=event,
) )
@ -282,7 +282,8 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# up the `get_if_users_have_pushers` cache with AS entries that we # up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts. # know don't have pushers, nor even read receipts.
local_users_in_room = set( local_users_in_room = set(
u for u in users_in_room u
for u in users_in_room
if self.hs.is_mine_id(u) if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u) and not self.get_if_app_services_interested_in_user(u)
) )
@ -290,15 +291,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# users in the room who have pushers need to get push rules run because # users in the room who have pushers need to get push rules run because
# that's how their pushers work # that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers( if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, local_users_in_room, on_invalidate=cache_context.invalidate
on_invalidate=cache_context.invalidate,
) )
user_ids = set( user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
) )
users_with_receipts = yield self.get_users_with_read_receipts_in_room( users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate, room_id, on_invalidate=cache_context.invalidate
) )
# any users with pushers must be ours: they have pushers # any users with pushers must be ours: they have pushers
@ -307,29 +307,30 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
user_ids.add(uid) user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules( rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate, user_ids, on_invalidate=cache_context.invalidate
) )
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user) defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user", @cachedList(
list_name="user_ids", num_args=1, inlineCallbacks=True) cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids): def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids: if not user_ids:
defer.returnValue({}) defer.returnValue({})
results = { results = {user_id: {} for user_id in user_ids}
user_id: {}
for user_id in user_ids
}
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="push_rules_enable", table="push_rules_enable",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
retcols=("user_name", "rule_id", "enabled",), retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled", desc="bulk_get_push_rules_enabled",
) )
for row in rows: for row in rows:
@ -341,8 +342,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
class PushRuleStore(PushRulesWorkerStore): class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_push_rule( def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions, self,
before=None, after=None user_id,
rule_id,
priority_class,
conditions,
actions,
before=None,
after=None,
): ):
conditions_json = json.dumps(conditions) conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions) actions_json = json.dumps(actions)
@ -352,20 +359,41 @@ class PushRuleStore(PushRulesWorkerStore):
yield self.runInteraction( yield self.runInteraction(
"_add_push_rule_relative_txn", "_add_push_rule_relative_txn",
self._add_push_rule_relative_txn, self._add_push_rule_relative_txn,
stream_id, event_stream_ordering, user_id, rule_id, priority_class, stream_id,
conditions_json, actions_json, before, after, event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
) )
else: else:
yield self.runInteraction( yield self.runInteraction(
"_add_push_rule_highest_priority_txn", "_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn, self._add_push_rule_highest_priority_txn,
stream_id, event_stream_ordering, user_id, rule_id, priority_class, stream_id,
conditions_json, actions_json, event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
) )
def _add_push_rule_relative_txn( def _add_push_rule_relative_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, self,
conditions_json, actions_json, before, after txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
): ):
# Lock the table since otherwise we'll have annoying races between the # Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below. # SELECT here and the UPSERT below.
@ -376,10 +404,7 @@ class PushRuleStore(PushRulesWorkerStore):
res = self._simple_select_one_txn( res = self._simple_select_one_txn(
txn, txn,
table="push_rules", table="push_rules",
keyvalues={ keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
"user_name": user_id,
"rule_id": relative_to_rule,
},
retcols=["priority_class", "priority"], retcols=["priority_class", "priority"],
allow_none=True, allow_none=True,
) )
@ -416,13 +441,27 @@ class PushRuleStore(PushRulesWorkerStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority)) txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, txn,
new_rule_priority, conditions_json, actions_json, stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
new_rule_priority,
conditions_json,
actions_json,
) )
def _add_push_rule_highest_priority_txn( def _add_push_rule_highest_priority_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, self,
conditions_json, actions_json txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
): ):
# Lock the table since otherwise we'll have annoying races between the # Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below. # SELECT here and the UPSERT below.
@ -443,13 +482,28 @@ class PushRuleStore(PushRulesWorkerStore):
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, txn,
stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio, stream_id,
conditions_json, actions_json, event_stream_ordering,
user_id,
rule_id,
priority_class,
new_prio,
conditions_json,
actions_json,
) )
def _upsert_push_rule_txn( def _upsert_push_rule_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, self,
priority, conditions_json, actions_json, update_stream=True txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
conditions_json,
actions_json,
update_stream=True,
): ):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id """Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes using the _push_rule_id_gen if it needs to insert the rule. It assumes
@ -461,10 +515,10 @@ class PushRuleStore(PushRulesWorkerStore):
" WHERE user_name = ? AND rule_id = ?" " WHERE user_name = ? AND rule_id = ?"
) )
txn.execute(sql, ( txn.execute(
priority_class, priority, conditions_json, actions_json, sql,
user_id, rule_id, (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
)) )
if txn.rowcount == 0: if txn.rowcount == 0:
# We didn't update a row with the given rule_id so insert one # We didn't update a row with the given rule_id so insert one
@ -486,14 +540,18 @@ class PushRuleStore(PushRulesWorkerStore):
if update_stream: if update_stream:
self._insert_push_rules_update_txn( self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ADD", op="ADD",
data={ data={
"priority_class": priority_class, "priority_class": priority_class,
"priority": priority, "priority": priority,
"conditions": conditions_json, "conditions": conditions_json,
"actions": actions_json, "actions": actions_json,
} },
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -507,22 +565,23 @@ class PushRuleStore(PushRulesWorkerStore):
user_id (str): The matrix ID of the push rule owner user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted rule_id (str): The rule_id of the rule to be deleted
""" """
def delete_push_rule_txn(txn, stream_id, event_stream_ordering): def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self._simple_delete_one_txn( self._simple_delete_one_txn(
txn, txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
"push_rules",
{'user_name': user_id, 'rule_id': rule_id},
) )
self._insert_push_rules_update_txn( self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
op="DELETE"
) )
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
yield self.runInteraction( yield self.runInteraction(
"delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering "delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -532,7 +591,11 @@ class PushRuleStore(PushRulesWorkerStore):
yield self.runInteraction( yield self.runInteraction(
"_set_push_rule_enabled_txn", "_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn, self._set_push_rule_enabled_txn,
stream_id, event_stream_ordering, user_id, rule_id, enabled stream_id,
event_stream_ordering,
user_id,
rule_id,
enabled,
) )
def _set_push_rule_enabled_txn( def _set_push_rule_enabled_txn(
@ -548,8 +611,12 @@ class PushRuleStore(PushRulesWorkerStore):
) )
self._insert_push_rules_update_txn( self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, txn,
op="ENABLE" if enabled else "DISABLE" stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ENABLE" if enabled else "DISABLE",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -563,9 +630,16 @@ class PushRuleStore(PushRulesWorkerStore):
priority_class = -1 priority_class = -1
priority = 1 priority = 1
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, txn,
priority_class, priority, "[]", actions_json, stream_id,
update_stream=False event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
"[]",
actions_json,
update_stream=False,
) )
else: else:
self._simple_update_one_txn( self._simple_update_one_txn(
@ -576,15 +650,22 @@ class PushRuleStore(PushRulesWorkerStore):
) )
self._insert_push_rules_update_txn( self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, txn,
op="ACTIONS", data={"actions": actions_json} stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ACTIONS",
data={"actions": actions_json},
) )
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
yield self.runInteraction( yield self.runInteraction(
"set_push_rule_actions", set_push_rule_actions_txn, "set_push_rule_actions",
stream_id, event_stream_ordering set_push_rule_actions_txn,
stream_id,
event_stream_ordering,
) )
def _insert_push_rules_update_txn( def _insert_push_rules_update_txn(
@ -602,12 +683,8 @@ class PushRuleStore(PushRulesWorkerStore):
self._simple_insert_txn(txn, "push_rules_stream", values=values) self._simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after( txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
self.get_push_rules_for_user.invalidate, (user_id,) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
txn.call_after( txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
) )
@ -627,6 +704,7 @@ class PushRuleStore(PushRulesWorkerStore):
) )
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return self.runInteraction( return self.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn "get_all_push_rule_updates", get_all_push_rule_updates_txn
) )

View File

@ -47,7 +47,9 @@ class PusherWorkerStore(SQLBaseStore):
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Invalid JSON in data for pusher %d: %s, %s", "Invalid JSON in data for pusher %d: %s, %s",
r['id'], dataJson, e.args[0], r['id'],
dataJson,
e.args[0],
) )
pass pass
@ -64,20 +66,16 @@ class PusherWorkerStore(SQLBaseStore):
defer.returnValue(ret is not None) defer.returnValue(ret is not None)
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({ return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
"app_id": app_id,
"pushkey": pushkey,
})
def get_pushers_by_user_id(self, user_id): def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({ return self.get_pushers_by({"user_name": user_id})
"user_name": user_id,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pushers_by(self, keyvalues): def get_pushers_by(self, keyvalues):
ret = yield self._simple_select_list( ret = yield self._simple_select_list(
"pushers", keyvalues, "pushers",
keyvalues,
[ [
"id", "id",
"user_name", "user_name",
@ -94,7 +92,8 @@ class PusherWorkerStore(SQLBaseStore):
"last_stream_ordering", "last_stream_ordering",
"last_success", "last_success",
"failing_since", "failing_since",
], desc="get_pushers_by" ],
desc="get_pushers_by",
) )
defer.returnValue(self._decode_pushers_rows(ret)) defer.returnValue(self._decode_pushers_rows(ret))
@ -135,6 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
deleted = txn.fetchall() deleted = txn.fetchall()
return (updated, deleted) return (updated, deleted)
return self.runInteraction( return self.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn "get_all_updated_pushers", get_all_updated_pushers_txn
) )
@ -177,6 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
results.sort() # Sort so that they're ordered by stream id results.sort() # Sort so that they're ordered by stream id
return results return results
return self.runInteraction( return self.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
) )
@ -186,15 +187,19 @@ class PusherWorkerStore(SQLBaseStore):
# This only exists for the cachedList decorator # This only exists for the cachedList decorator
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_if_user_has_pusher", @cachedList(
list_name="user_ids", num_args=1, inlineCallbacks=True) cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids): def get_if_users_have_pushers(self, user_ids):
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table='pushers', table='pushers',
column='user_name', column='user_name',
iterable=user_ids, iterable=user_ids,
retcols=['user_name'], retcols=['user_name'],
desc='get_if_users_have_pushers' desc='get_if_users_have_pushers',
) )
result = {user_id: False for user_id in user_ids} result = {user_id: False for user_id in user_ids}
@ -208,20 +213,27 @@ class PusherStore(PusherWorkerStore):
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id, def add_pusher(
app_display_name, device_display_name, self,
pushkey, pushkey_ts, lang, data, last_stream_ordering, user_id,
profile_tag=""): access_token,
kind,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
lang,
data,
last_stream_ordering,
profile_tag="",
):
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry # (app_id, pushkey, user_name) so _simple_upsert will retry
yield self._simple_upsert( yield self._simple_upsert(
table="pushers", table="pushers",
keyvalues={ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
"app_id": app_id,
"pushkey": pushkey,
"user_name": user_id,
},
values={ values={
"access_token": access_token, "access_token": access_token,
"kind": kind, "kind": kind,
@ -247,7 +259,8 @@ class PusherStore(PusherWorkerStore):
yield self.runInteraction( yield self.runInteraction(
"add_pusher", "add_pusher",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,
self.get_if_user_has_pusher, (user_id,) self.get_if_user_has_pusher,
(user_id,),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -260,7 +273,7 @@ class PusherStore(PusherWorkerStore):
self._simple_delete_one_txn( self._simple_delete_one_txn(
txn, txn,
"pushers", "pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id} {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
) )
# it's possible for us to end up with duplicate rows for # it's possible for us to end up with duplicate rows for
@ -278,13 +291,12 @@ class PusherStore(PusherWorkerStore):
) )
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
yield self.runInteraction( yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
"delete_pusher", delete_pusher_txn, stream_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id, def update_pusher_last_stream_ordering(
last_stream_ordering): self, app_id, pushkey, user_id, last_stream_ordering
):
yield self._simple_update_one( yield self._simple_update_one(
"pushers", "pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
@ -293,23 +305,21 @@ class PusherStore(PusherWorkerStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey, def update_pusher_last_stream_ordering_and_success(
user_id, self, app_id, pushkey, user_id, last_stream_ordering, last_success
last_stream_ordering, ):
last_success):
yield self._simple_update_one( yield self._simple_update_one(
"pushers", "pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{ {
'last_stream_ordering': last_stream_ordering, 'last_stream_ordering': last_stream_ordering,
'last_success': last_success 'last_success': last_success,
}, },
desc="update_pusher_last_stream_ordering_and_success", desc="update_pusher_last_stream_ordering_and_success",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
failing_since):
yield self._simple_update_one( yield self._simple_update_one(
"pushers", "pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
@ -323,14 +333,14 @@ class PusherStore(PusherWorkerStore):
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id}, {"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"], ["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room" desc="get_throttle_params_by_room",
) )
params_by_room = {} params_by_room = {}
for row in res: for row in res:
params_by_room[row["room_id"]] = { params_by_room[row["room_id"]] = {
"last_sent_ts": row["last_sent_ts"], "last_sent_ts": row["last_sent_ts"],
"throttle_ms": row["throttle_ms"] "throttle_ms": row["throttle_ms"],
} }
defer.returnValue(params_by_room) defer.returnValue(params_by_room)

View File

@ -64,10 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipts_for_room(self, room_id, receipt_type): def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list( return self._simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={ keyvalues={"room_id": room_id, "receipt_type": receipt_type},
"room_id": room_id,
"receipt_type": receipt_type,
},
retcols=("user_id", "event_id"), retcols=("user_id", "event_id"),
desc="get_receipts_for_room", desc="get_receipts_for_room",
) )
@ -79,7 +76,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues={ keyvalues={
"room_id": room_id, "room_id": room_id,
"receipt_type": receipt_type, "receipt_type": receipt_type,
"user_id": user_id "user_id": user_id,
}, },
retcol="event_id", retcol="event_id",
desc="get_own_receipt_for_user", desc="get_own_receipt_for_user",
@ -90,10 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipts_for_user(self, user_id, receipt_type): def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={ keyvalues={"user_id": user_id, "receipt_type": receipt_type},
"user_id": user_id,
"receipt_type": receipt_type,
},
retcols=("room_id", "event_id"), retcols=("room_id", "event_id"),
desc="get_receipts_for_user", desc="get_receipts_for_user",
) )
@ -114,16 +108,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return txn.fetchall() return txn.fetchall()
rows = yield self.runInteraction(
"get_receipts_for_user_with_orderings", f rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
) defer.returnValue(
defer.returnValue({ {
row[0]: { row[0]: {
"event_id": row[1], "event_id": row[1],
"topological_ordering": row[2], "topological_ordering": row[2],
"stream_ordering": row[3], "stream_ordering": row[3],
} for row in rows }
}) for row in rows
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@ -177,6 +173,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""See get_linearized_receipts_for_room """See get_linearized_receipts_for_room
""" """
def f(txn): def f(txn):
if from_key: if from_key:
sql = ( sql = (
@ -184,48 +181,40 @@ class ReceiptsWorkerStore(SQLBaseStore):
" room_id = ? AND stream_id > ? AND stream_id <= ?" " room_id = ? AND stream_id > ? AND stream_id <= ?"
) )
txn.execute( txn.execute(sql, (room_id, from_key, to_key))
sql,
(room_id, from_key, to_key)
)
else: else:
sql = ( sql = (
"SELECT * FROM receipts_linearized WHERE" "SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?" " room_id = ? AND stream_id <= ?"
) )
txn.execute( txn.execute(sql, (room_id, to_key))
sql,
(room_id, to_key)
)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
return rows return rows
rows = yield self.runInteraction( rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
"get_linearized_receipts_for_room", f
)
if not rows: if not rows:
defer.returnValue([]) defer.returnValue([])
content = {} content = {}
for row in rows: for row in rows:
content.setdefault( content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["event_id"], {} row["user_id"]
).setdefault( ] = json.loads(row["data"])
row["receipt_type"], {}
)[row["user_id"]] = json.loads(row["data"])
defer.returnValue([{ defer.returnValue(
"type": "m.receipt", [{"type": "m.receipt", "room_id": room_id, "content": content}]
"room_id": room_id, )
"content": content,
}])
@cachedList(cached_method_name="_get_linearized_receipts_for_room", @cachedList(
list_name="room_ids", num_args=3, inlineCallbacks=True) cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids: if not room_ids:
defer.returnValue({}) defer.returnValue({})
@ -235,9 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
sql = ( sql = (
"SELECT * FROM receipts_linearized WHERE" "SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?" " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % ( ) % (",".join(["?"] * len(room_ids)))
",".join(["?"] * len(room_ids))
)
args = list(room_ids) args = list(room_ids)
args.extend([from_key, to_key]) args.extend([from_key, to_key])
@ -246,9 +233,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
sql = ( sql = (
"SELECT * FROM receipts_linearized WHERE" "SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?" " room_id IN (%s) AND stream_id <= ?"
) % ( ) % (",".join(["?"] * len(room_ids)))
",".join(["?"] * len(room_ids))
)
args = list(room_ids) args = list(room_ids)
args.append(to_key) args.append(to_key)
@ -257,19 +242,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
txn_results = yield self.runInteraction( txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
"_get_linearized_receipts_for_rooms", f
)
results = {} results = {}
for row in txn_results: for row in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
room_event = results.setdefault(row["room_id"], { room_event = results.setdefault(
"type": "m.receipt", row["room_id"],
"room_id": row["room_id"], {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
"content": {}, )
})
# The content is of the form: # The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
@ -301,21 +283,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit) args.append(limit)
txn.execute(sql, args) txn.execute(sql, args)
return ( return (r[0:5] + (json.loads(r[5]),) for r in txn)
r[0:5] + (json.loads(r[5]), ) for r in txn
)
return self.runInteraction( return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn "get_all_updated_receipts", get_all_updated_receipts_txn
) )
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, def _invalidate_get_users_with_receipts_in_room(
user_id): self, room_id, receipt_type, user_id
):
if receipt_type != "m.read": if receipt_type != "m.read":
return return
# Returns either an ObservableDeferred or the raw result # Returns either an ObservableDeferred or the raw result
res = self.get_users_with_read_receipts_in_room.cache.get( res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False, room_id, None, update_metrics=False
) )
# first handle the Deferred case # first handle the Deferred case
@ -346,8 +328,9 @@ class ReceiptsStore(ReceiptsWorkerStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, def insert_linearized_receipt_txn(
user_id, event_id, data, stream_id): self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
):
"""Inserts a read-receipt into the database if it's newer than the current RR """Inserts a read-receipt into the database if it's newer than the current RR
Returns: int|None Returns: int|None
@ -360,7 +343,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
table="events", table="events",
retcols=["stream_ordering", "received_ts"], retcols=["stream_ordering", "received_ts"],
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
allow_none=True allow_none=True,
) )
stream_ordering = int(res["stream_ordering"]) if res else None stream_ordering = int(res["stream_ordering"]) if res else None
@ -381,31 +364,31 @@ class ReceiptsStore(ReceiptsWorkerStore):
logger.debug( logger.debug(
"Ignoring new receipt for %s in favour of existing " "Ignoring new receipt for %s in favour of existing "
"one for later event %s", "one for later event %s",
event_id, eid, event_id,
eid,
) )
return None return None
txn.call_after( txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after( txn.call_after(
self._invalidate_get_users_with_receipts_in_room, self._invalidate_get_users_with_receipts_in_room,
room_id, receipt_type, user_id, room_id,
) receipt_type,
txn.call_after( user_id,
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
) )
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
txn.call_after( txn.call_after(
self._receipts_stream_cache.entity_has_changed, self._receipts_stream_cache.entity_has_changed, room_id, stream_id
room_id, stream_id
) )
txn.call_after( txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate, self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type),
) )
self._simple_delete_txn( self._simple_delete_txn(
@ -415,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id, "room_id": room_id,
"receipt_type": receipt_type, "receipt_type": receipt_type,
"user_id": user_id, "user_id": user_id,
} },
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -428,15 +411,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id, "user_id": user_id,
"event_id": event_id, "event_id": event_id,
"data": json.dumps(data), "data": json.dumps(data),
} },
) )
if receipt_type == "m.read" and stream_ordering is not None: if receipt_type == "m.read" and stream_ordering is not None:
self._remove_old_push_actions_before_txn( self._remove_old_push_actions_before_txn(
txn, txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
room_id=room_id,
user_id=user_id,
stream_ordering=stream_ordering,
) )
return rx_ts return rx_ts
@ -479,7 +459,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
event_ts = yield self.runInteraction( event_ts = yield self.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id, room_id,
receipt_type,
user_id,
linearized_event_id,
data, data,
stream_id=stream_id, stream_id=stream_id,
) )
@ -490,39 +473,43 @@ class ReceiptsStore(ReceiptsWorkerStore):
now = self._clock.time_msec() now = self._clock.time_msec()
logger.debug( logger.debug(
"RR for event %s in %s (%i ms old)", "RR for event %s in %s (%i ms old)",
linearized_event_id, room_id, now - event_ts, linearized_event_id,
room_id,
now - event_ts,
) )
yield self.insert_graph_receipt( yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = self._receipts_id_gen.get_current_token() max_persisted_id = self._receipts_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id)) defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
data):
return self.runInteraction( return self.runInteraction(
"insert_graph_receipt", "insert_graph_receipt",
self.insert_graph_receipt_txn, self.insert_graph_receipt_txn,
room_id, receipt_type, user_id, event_ids, data room_id,
receipt_type,
user_id,
event_ids,
data,
) )
def insert_graph_receipt_txn(self, txn, room_id, receipt_type, def insert_graph_receipt_txn(
user_id, event_ids, data): self, txn, room_id, receipt_type, user_id, event_ids, data
txn.call_after( ):
self.get_receipts_for_room.invalidate, (room_id, receipt_type) txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
)
txn.call_after( txn.call_after(
self._invalidate_get_users_with_receipts_in_room, self._invalidate_get_users_with_receipts_in_room,
room_id, receipt_type, user_id, room_id,
) receipt_type,
txn.call_after( user_id,
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
) )
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
@ -531,7 +518,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id, "room_id": room_id,
"receipt_type": receipt_type, "receipt_type": receipt_type,
"user_id": user_id, "user_id": user_id,
} },
) )
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -542,5 +529,5 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id, "user_id": user_id,
"event_ids": json.dumps(event_ids), "event_ids": json.dumps(event_ids),
"data": json.dumps(data), "data": json.dumps(data),
} },
) )

View File

@ -37,13 +37,15 @@ class RegistrationWorkerStore(SQLBaseStore):
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self._simple_select_one( return self._simple_select_one(
table="users", table="users",
keyvalues={ keyvalues={"name": user_id},
"name": user_id,
},
retcols=[ retcols=[
"name", "password_hash", "is_guest", "name",
"consent_version", "consent_server_notice_sent", "password_hash",
"appservice_id", "creation_ts", "is_guest",
"consent_version",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
], ],
allow_none=True, allow_none=True,
desc="get_user_by_id", desc="get_user_by_id",
@ -81,9 +83,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`. including the keys `name`, `is_guest`, `device_id`, `token_id`.
""" """
return self.runInteraction( return self.runInteraction(
"get_user_by_access_token", "get_user_by_access_token", self._query_for_auth, token
self._query_for_auth,
token
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -143,10 +143,10 @@ class RegistrationWorkerStore(SQLBaseStore):
"""Gets users that match user_id case insensitively. """Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash. Returns a mapping of user_id -> password_hash.
""" """
def f(txn): def f(txn):
sql = ( sql = (
"SELECT name, password_hash FROM users" "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
" WHERE lower(name) = lower(?)"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn) return dict(txn)
@ -156,6 +156,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def count_all_users(self): def count_all_users(self):
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
def _count_users(txn): def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users") txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
@ -173,6 +174,7 @@ class RegistrationWorkerStore(SQLBaseStore):
3) bridged users 3) bridged users
who registered on the homeserver in the past 24 hours who registered on the homeserver in the past 24 hours
""" """
def _count_daily_user_type(txn): def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24) yesterday = int(self._clock.time()) - (60 * 60 * 24)
@ -193,15 +195,18 @@ class RegistrationWorkerStore(SQLBaseStore):
for row in txn: for row in txn:
results[row[0]] = row[1] results[row[0]] = row[1]
return results return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type) return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks @defer.inlineCallbacks
def count_nonbridged_users(self): def count_nonbridged_users(self):
def _count_users(txn): def _count_users(txn):
txn.execute(""" txn.execute(
"""
SELECT COALESCE(COUNT(*), 0) FROM users SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL WHERE appservice_id IS NULL
""") """
)
count, = txn.fetchone() count, = txn.fetchone()
return count return count
@ -220,6 +225,7 @@ class RegistrationWorkerStore(SQLBaseStore):
avoid the case of ID 10000000 being pre-allocated, so us wasting the avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs. first (and shortest) many generated user IDs.
""" """
def _find_next_generated_user_id(txn): def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users") txn.execute("SELECT name FROM users")
@ -227,7 +233,7 @@ class RegistrationWorkerStore(SQLBaseStore):
found = set() found = set()
for user_id, in txn: for (user_id,) in txn:
match = regex.search(user_id) match = regex.search(user_id)
if match: if match:
found.add(int(match.group(1))) found.add(int(match.group(1)))
@ -235,20 +241,22 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found: if i not in found:
return i return i
defer.returnValue((yield self.runInteraction( defer.returnValue(
"find_next_generated_user_id", (
_find_next_generated_user_id yield self.runInteraction(
))) "find_next_generated_user_id", _find_next_generated_user_id
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address): def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one( ret = yield self._simple_select_one(
"threepid_guest_access_tokens", "threepid_guest_access_tokens",
{ {"medium": medium, "address": address},
"medium": medium, ["guest_access_token"],
"address": address True,
}, 'get_3pid_guest_access_token',
["guest_access_token"], True, 'get_3pid_guest_access_token'
) )
if ret: if ret:
defer.returnValue(ret["guest_access_token"]) defer.returnValue(ret["guest_access_token"])
@ -266,8 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Deferred[str|None]: user id or None if no user id/threepid mapping exists Deferred[str|None]: user id or None if no user id/threepid mapping exists
""" """
user_id = yield self.runInteraction( user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
medium, address
) )
defer.returnValue(user_id) defer.returnValue(user_id)
@ -285,11 +292,9 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = self._simple_select_one_txn( ret = self._simple_select_one_txn(
txn, txn,
"user_threepids", "user_threepids",
{ {"medium": medium, "address": address},
"medium": medium, ['user_id'],
"address": address True,
},
['user_id'], True
) )
if ret: if ret:
return ret['user_id'] return ret['user_id']
@ -297,41 +302,33 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", { yield self._simple_upsert(
"medium": medium, "user_threepids",
"address": address, {"medium": medium, "address": address},
}, { {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
"user_id": user_id, )
"validated_at": validated_at,
"added_at": added_at,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def user_get_threepids(self, user_id): def user_get_threepids(self, user_id):
ret = yield self._simple_select_list( ret = yield self._simple_select_list(
"user_threepids", { "user_threepids",
"user_id": user_id {"user_id": user_id},
},
['medium', 'address', 'validated_at', 'added_at'], ['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids' 'user_get_threepids',
) )
defer.returnValue(ret) defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address): def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete( return self._simple_delete(
"user_threepids", "user_threepids",
keyvalues={ keyvalues={"user_id": user_id, "medium": medium, "address": address},
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids", desc="user_delete_threepids",
) )
class RegistrationStore(RegistrationWorkerStore, class RegistrationStore(
background_updates.BackgroundUpdateStore): RegistrationWorkerStore, background_updates.BackgroundUpdateStore
):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationStore, self).__init__(db_conn, hs)
@ -372,18 +369,22 @@ class RegistrationStore(RegistrationWorkerStore,
yield self._simple_insert( yield self._simple_insert(
"access_tokens", "access_tokens",
{ {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
"id": next_id,
"user_id": user_id,
"token": token,
"device_id": device_id,
},
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
def register(self, user_id, token=None, password_hash=None, def register(
was_guest=False, make_guest=False, appservice_id=None, self,
create_profile_with_displayname=None, admin=False, user_type=None): user_id,
token=None,
password_hash=None,
was_guest=False,
make_guest=False,
appservice_id=None,
create_profile_with_displayname=None,
admin=False,
user_type=None,
):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
@ -417,7 +418,7 @@ class RegistrationStore(RegistrationWorkerStore,
appservice_id, appservice_id,
create_profile_with_displayname, create_profile_with_displayname,
admin, admin,
user_type user_type,
) )
def _register( def _register(
@ -447,10 +448,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_select_one_txn( self._simple_select_one_txn(
txn, txn,
"users", "users",
keyvalues={ keyvalues={"name": user_id, "is_guest": 1},
"name": user_id,
"is_guest": 1,
},
retcols=("name",), retcols=("name",),
allow_none=False, allow_none=False,
) )
@ -458,10 +456,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_update_one_txn( self._simple_update_one_txn(
txn, txn,
"users", "users",
keyvalues={ keyvalues={"name": user_id, "is_guest": 1},
"name": user_id,
"is_guest": 1,
},
updatevalues={ updatevalues={
"password_hash": password_hash, "password_hash": password_hash,
"upgrade_ts": now, "upgrade_ts": now,
@ -469,7 +464,7 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id, "appservice_id": appservice_id,
"admin": 1 if admin else 0, "admin": 1 if admin else 0,
"user_type": user_type, "user_type": user_type,
} },
) )
else: else:
self._simple_insert_txn( self._simple_insert_txn(
@ -483,20 +478,17 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id, "appservice_id": appservice_id,
"admin": 1 if admin else 0, "admin": 1 if admin else 0,
"user_type": user_type, "user_type": user_type,
} },
) )
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise StoreError( raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
if token: if token:
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID
txn.execute( txn.execute(
"INSERT INTO access_tokens(id, user_id, token)" "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
" VALUES (?,?,?)", (next_id, user_id, token),
(next_id, user_id, token,)
) )
if create_profile_with_displayname: if create_profile_with_displayname:
@ -507,12 +499,10 @@ class RegistrationStore(RegistrationWorkerStore,
# while everything else uses the full mxid. # while everything else uses the full mxid.
txn.execute( txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)", "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(user_id_obj.localpart, create_profile_with_displayname) (user_id_obj.localpart, create_profile_with_displayname),
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
@ -521,22 +511,14 @@ class RegistrationStore(RegistrationWorkerStore,
removes most of the entries subsequently anyway so it would be removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately. pointless. Use flush_user separately.
""" """
def user_set_password_hash_txn(txn): def user_set_password_hash_txn(txn):
self._simple_update_one_txn( self._simple_update_one_txn(
txn, txn, 'users', {'name': user_id}, {'password_hash': password_hash}
'users', {
'name': user_id
},
{
'password_hash': password_hash
}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
return self.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
def user_set_consent_version(self, user_id, consent_version): def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent """Updates the user table to record privacy policy consent
@ -549,16 +531,16 @@ class RegistrationStore(RegistrationWorkerStore,
Raises: Raises:
StoreError(404) if user not found StoreError(404) if user not found
""" """
def f(txn): def f(txn):
self._simple_update_one_txn( self._simple_update_one_txn(
txn, txn,
table='users', table='users',
keyvalues={'name': user_id, }, keyvalues={'name': user_id},
updatevalues={'consent_version': consent_version, }, updatevalues={'consent_version': consent_version},
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_version", f) return self.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version): def user_set_consent_server_notice_sent(self, user_id, consent_version):
@ -573,20 +555,19 @@ class RegistrationStore(RegistrationWorkerStore,
Raises: Raises:
StoreError(404) if user not found StoreError(404) if user not found
""" """
def f(txn): def f(txn):
self._simple_update_one_txn( self._simple_update_one_txn(
txn, txn,
table='users', table='users',
keyvalues={'name': user_id, }, keyvalues={'name': user_id},
updatevalues={'consent_server_notice_sent': consent_version, }, updatevalues={'consent_server_notice_sent': consent_version},
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_server_notice_sent", f) return self.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
device_id=None):
""" """
Invalidate access tokens belonging to a user Invalidate access tokens belonging to a user
@ -601,10 +582,9 @@ class RegistrationStore(RegistrationWorkerStore,
defer.Deferred[list[str, int, str|None, int]]: a list of defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens (token, token id, device id) for each of the deleted tokens
""" """
def f(txn): def f(txn):
keyvalues = { keyvalues = {"user_id": user_id}
"user_id": user_id,
}
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
@ -616,8 +596,9 @@ class RegistrationStore(RegistrationWorkerStore,
values.append(except_token_id) values.append(except_token_id)
txn.execute( txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, "SELECT token, id, device_id FROM access_tokens WHERE %s"
values % where_clause,
values,
) )
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
@ -626,25 +607,16 @@ class RegistrationStore(RegistrationWorkerStore,
txn, self.get_user_by_access_token, (token,) txn, self.get_user_by_access_token, (token,)
) )
txn.execute( txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
)
return tokens_and_devices return tokens_and_devices
return self.runInteraction( return self.runInteraction("user_delete_access_tokens", f)
"user_delete_access_tokens", f,
)
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
def f(txn): def f(txn):
self._simple_delete_one_txn( self._simple_delete_one_txn(
txn, txn, table="access_tokens", keyvalues={"token": access_token}
table="access_tokens",
keyvalues={
"token": access_token
},
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
@ -683,12 +655,13 @@ class RegistrationStore(RegistrationWorkerStore,
deferred str: Whichever access token is persisted at the end deferred str: Whichever access token is persisted at the end
of this function call. of this function call.
""" """
def insert(txn): def insert(txn):
txn.execute( txn.execute(
"INSERT INTO threepid_guest_access_tokens " "INSERT INTO threepid_guest_access_tokens "
"(medium, address, guest_access_token, first_inviter) " "(medium, address, guest_access_token, first_inviter) "
"VALUES (?, ?, ?, ?)", "VALUES (?, ?, ?, ?)",
(medium, address, access_token, inviter_user_id) (medium, address, access_token, inviter_user_id),
) )
try: try:
@ -705,9 +678,7 @@ class RegistrationStore(RegistrationWorkerStore,
""" """
return self._simple_insert( return self._simple_insert(
"users_pending_deactivation", "users_pending_deactivation",
values={ values={"user_id": user_id},
"user_id": user_id,
},
desc="add_user_pending_deactivation", desc="add_user_pending_deactivation",
) )
@ -720,9 +691,7 @@ class RegistrationStore(RegistrationWorkerStore,
# the table, so somehow duplicate entries have ended up in it. # the table, so somehow duplicate entries have ended up in it.
return self._simple_delete( return self._simple_delete(
"users_pending_deactivation", "users_pending_deactivation",
keyvalues={ keyvalues={"user_id": user_id},
"user_id": user_id,
},
desc="del_user_pending_deactivation", desc="del_user_pending_deactivation",
) )

View File

@ -36,9 +36,7 @@ class RejectionsStore(SQLBaseStore):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="rejections", table="rejections",
retcol="reason", retcol="reason",
keyvalues={ keyvalues={"event_id": event_id},
"event_id": event_id,
},
allow_none=True, allow_none=True,
desc="get_rejection_reason", desc="get_rejection_reason",
) )

View File

@ -30,13 +30,11 @@ logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple( OpsLevel = collections.namedtuple(
"OpsLevel", "OpsLevel", ("ban_level", "kick_level", "redact_level")
("ban_level", "kick_level", "redact_level",)
) )
RatelimitOverride = collections.namedtuple( RatelimitOverride = collections.namedtuple(
"RatelimitOverride", "RatelimitOverride", ("messages_per_second", "burst_count")
("messages_per_second", "burst_count",)
) )
@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore):
def get_public_room_ids(self): def get_public_room_ids(self):
return self._simple_select_onecol( return self._simple_select_onecol(
table="rooms", table="rooms",
keyvalues={ keyvalues={"is_public": True},
"is_public": True,
},
retcol="room_id", retcol="room_id",
desc="get_public_room_ids", desc="get_public_room_ids",
) )
@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_public_room_ids_at_stream_id", "get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn, self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple stream_id,
network_tuple=network_tuple,
) )
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
network_tuple):
return { return {
rm rm
for rm, vis in self.get_published_at_stream_id_txn( for rm, vis in self.get_published_at_stream_id_txn(
@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore):
if network_tuple: if network_tuple:
# We want to get from a particular list. No aggregation required. # We want to get from a particular list. No aggregation required.
sql = (""" sql = """
SELECT room_id, visibility FROM public_room_list_stream SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN ( INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id SELECT room_id, max(stream_id) AS stream_id
@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ? %s WHERE stream_id <= ? %s
GROUP BY room_id GROUP BY room_id
) grouped USING (room_id, stream_id) ) grouped USING (room_id, stream_id)
""") """
if network_tuple.appservice_id is not None: if network_tuple.appservice_id is not None:
txn.execute( txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",), sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,) (stream_id, network_tuple.appservice_id, network_tuple.network_id),
) )
else: else:
txn.execute( txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
return dict(txn) return dict(txn)
else: else:
# We want to get from all lists, so we need to aggregate the results # We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list") logger.info("Executing full list")
sql = (""" sql = """
SELECT room_id, visibility SELECT room_id, visibility
FROM public_room_list_stream FROM public_room_list_stream
INNER JOIN ( INNER JOIN (
@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ? WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id) ) grouped USING (room_id, stream_id)
""") """
txn.execute( txn.execute(sql, (stream_id,))
sql,
(stream_id,)
)
results = {} results = {}
# A room is visible if its visible on any list. # A room is visible if its visible on any list.
@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore):
return results return results
def get_public_room_changes(self, prev_stream_id, new_stream_id, def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
network_tuple):
def get_public_room_changes_txn(txn): def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn( then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple txn, prev_stream_id, network_tuple
@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, new_stream_id, network_tuple txn, new_stream_id, network_tuple
) )
now_rooms_visible = set( now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
rm for rm, vis in now_rooms_dict.items() if vis
)
now_rooms_not_visible = set( now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis rm for rm, vis in now_rooms_dict.items() if not vis
) )
@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore):
def is_room_blocked(self, room_id): def is_room_blocked(self, room_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="blocked_rooms", table="blocked_rooms",
keyvalues={ keyvalues={"room_id": room_id},
"room_id": room_id,
},
retcol="1", retcol="1",
allow_none=True, allow_none=True,
desc="is_room_blocked", desc="is_room_blocked",
@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore):
) )
if row: if row:
defer.returnValue(RatelimitOverride( defer.returnValue(
RatelimitOverride(
messages_per_second=row["messages_per_second"], messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"], burst_count=row["burst_count"],
)) )
)
else: else:
defer.returnValue(None) defer.returnValue(None)
class RoomStore(RoomWorkerStore, SearchStore): class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public): def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room. """Stores a room.
@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
StoreError if the room could not be stored. StoreError if the room could not be stored.
""" """
try: try:
def store_room_txn(txn, next_id): def store_room_txn(txn, next_id):
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
"stream_id": next_id, "stream_id": next_id,
"room_id": room_id, "room_id": room_id,
"visibility": is_public, "visibility": is_public,
} },
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction("store_room_txn", store_room_txn, next_id)
"store_room_txn",
store_room_txn, next_id,
)
except Exception as e: except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public, "visibility": is_public,
"appservice_id": None, "appservice_id": None,
"network_id": None, "network_id": None,
} },
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"set_room_is_public", "set_room_is_public", set_room_is_public_txn, next_id
set_room_is_public_txn, next_id,
) )
self.hs.get_notifier().on_new_replication_data() self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks @defer.inlineCallbacks
def set_room_is_public_appservice(self, room_id, appservice_id, network_id, def set_room_is_public_appservice(
is_public): self, room_id, appservice_id, network_id, is_public
):
"""Edit the appservice/network specific public room list. """Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated Each appservice can have a number of published room lists associated
@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
is_public (bool): Whether to publish or unpublish the room from the is_public (bool): Whether to publish or unpublish the room from the
list. list.
""" """
def set_room_is_public_appservice_txn(txn, next_id): def set_room_is_public_appservice_txn(txn, next_id):
if is_public: if is_public:
try: try:
@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
values={ values={
"appservice_id": appservice_id, "appservice_id": appservice_id,
"network_id": network_id, "network_id": network_id,
"room_id": room_id "room_id": room_id,
}, },
) )
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
keyvalues={ keyvalues={
"appservice_id": appservice_id, "appservice_id": appservice_id,
"network_id": network_id, "network_id": network_id,
"room_id": room_id "room_id": room_id,
}, },
) )
@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public, "visibility": is_public,
"appservice_id": appservice_id, "appservice_id": appservice_id,
"network_id": network_id, "network_id": network_id,
} },
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"set_room_is_public_appservice", "set_room_is_public_appservice",
set_room_is_public_appservice_txn, next_id, set_room_is_public_appservice_txn,
next_id,
) )
self.hs.get_notifier().on_new_replication_data() self.hs.get_notifier().on_new_replication_data()
@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
row = txn.fetchone() row = txn.fetchone()
return row[0] or 0 return row[0] or 0
return self.runInteraction( return self.runInteraction("get_rooms", f)
"get_rooms", f
)
def _store_room_topic_txn(self, txn, event): def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content: if hasattr(event, "content") and "topic" in event.content:
@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
) )
self.store_event_search_txn( self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"], txn, event, "content.topic", event.content["topic"]
) )
def _store_room_name_txn(self, txn, event): def _store_room_name_txn(self, txn, event):
@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore):
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"name": event.content["name"], "name": event.content["name"],
} },
) )
self.store_event_search_txn( self.store_event_search_txn(
txn, event, "content.name", event.content["name"], txn, event, "content.name", event.content["name"]
) )
def _store_room_message_txn(self, txn, event): def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content: if hasattr(event, "content") and "body" in event.content:
self.store_event_search_txn( self.store_event_search_txn(
txn, event, "content.body", event.content["body"], txn, event, "content.body", event.content["body"]
) )
def _store_history_visibility_txn(self, txn, event): def _store_history_visibility_txn(self, txn, event):
@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
" (event_id, room_id, %(key)s)" " (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" % {"key": key} " VALUES (?, ?, ?)" % {"key": key}
) )
txn.execute(sql, ( txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
event.event_id,
event.room_id,
event.content[key]
))
def add_event_report(self, room_id, event_id, user_id, reason, content, def add_event_report(
received_ts): self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next() next_id = self._event_reports_id_gen.get_next()
return self._simple_insert( return self._simple_insert(
table="event_reports", table="event_reports",
@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
"reason": reason, "reason": reason,
"content": json.dumps(content), "content": json.dumps(content),
}, },
desc="add_event_report" desc="add_event_report",
) )
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn): def get_all_new_public_rooms(txn):
sql = (""" sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ? WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT ? LIMIT ?
""") """
txn.execute(sql, (prev_id, current_id, limit,)) txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
if prev_id == current_id: if prev_id == current_id:
return defer.succeed([]) return defer.succeed([])
return self.runInteraction( return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
"get_all_new_public_rooms", get_all_new_public_rooms
)
@defer.inlineCallbacks @defer.inlineCallbacks
def block_room(self, room_id, user_id): def block_room(self, room_id, user_id):
@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore):
""" """
yield self._simple_upsert( yield self._simple_upsert(
table="blocked_rooms", table="blocked_rooms",
keyvalues={ keyvalues={"room_id": room_id},
"room_id": room_id,
},
values={}, values={},
insertion_values={ insertion_values={"user_id": user_id},
"user_id": user_id,
},
desc="block_room", desc="block_room",
) )
yield self.runInteraction( yield self.runInteraction(
"block_room_invalidation", "block_room_invalidation",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,
self.is_room_blocked, (room_id,), self.is_room_blocked,
(room_id,),
) )
def get_media_mxcs_in_room(self, room_id): def get_media_mxcs_in_room(self, room_id):
@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
The local and remote media as a lists of tuples where the key is The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID. the hostname and the value is the media ID.
""" """
def _get_media_mxcs_in_room_txn(txn): def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = [] local_media_mxcs = []
@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by): def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines """For a room loops through all events with media and quarantines
the associated media the associated media
""" """
def _quarantine_media_in_room_txn(txn): def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0 total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag # Now update all the tables to set the quarantined_by flag
txn.executemany(""" txn.executemany(
"""
UPDATE local_media_repository UPDATE local_media_repository
SET quarantined_by = ? SET quarantined_by = ?
WHERE media_id = ? WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_mxcs)) """,
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany( txn.executemany(
""" """
@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
( (
(quarantined_by, origin, media_id) (quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs for origin, media_id in remote_mxcs
) ),
) )
total_media_quarantined += len(local_mxcs) total_media_quarantined += len(local_mxcs)
@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
return total_media_quarantined return total_media_quarantined
return self.runInteraction( return self.runInteraction(
"quarantine_media_in_room", "quarantine_media_in_room", _quarantine_media_in_room_txn
_quarantine_media_in_room_txn,
) )
def _get_media_mxcs_in_room_txn(self, txn, room_id): def _get_media_mxcs_in_room_txn(self, txn, room_id):

View File

@ -30,10 +30,10 @@ from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SearchEntry = namedtuple('SearchEntry', [ SearchEntry = namedtuple(
'key', 'value', 'event_id', 'room_id', 'stream_ordering', 'SearchEntry',
'origin_server_ts', ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
]) )
class SearchStore(BackgroundUpdateStore): class SearchStore(BackgroundUpdateStore):
@ -53,8 +53,7 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
) )
self.register_background_update_handler( self.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
self._background_reindex_search_order
) )
# we used to have a background update to turn the GIN index into a # we used to have a background update to turn the GIN index into a
@ -62,13 +61,10 @@ class SearchStore(BackgroundUpdateStore):
# a GIN index. However, it's possible that some people might still have # a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the # the background update queued, so we register a handler to clear the
# background update. # background update.
self.register_noop_background_update( self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
)
self.register_background_update_handler( self.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
self._background_reindex_gin_search
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -138,21 +134,23 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it # then skip over it
continue continue
event_search_rows.append(SearchEntry( event_search_rows.append(
SearchEntry(
key=key, key=key,
value=value, value=value,
event_id=event_id, event_id=event_id,
room_id=room_id, room_id=room_id,
stream_ordering=stream_ordering, stream_ordering=stream_ordering,
origin_server_ts=origin_server_ts, origin_server_ts=origin_server_ts,
)) )
)
self.store_search_entries_txn(txn, event_search_rows) self.store_search_entries_txn(txn, event_search_rows)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id, "max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(event_search_rows) "rows_inserted": rows_inserted + len(event_search_rows),
} }
self._background_update_progress_txn( self._background_update_progress_txn(
@ -191,6 +189,7 @@ class SearchStore(BackgroundUpdateStore):
# doesn't support CREATE INDEX IF EXISTS so we just catch the # doesn't support CREATE INDEX IF EXISTS so we just catch the
# exception and ignore it. # exception and ignore it.
import psycopg2 import psycopg2
try: try:
c.execute( c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx" "CREATE INDEX CONCURRENTLY event_search_fts_idx"
@ -198,14 +197,11 @@ class SearchStore(BackgroundUpdateStore):
) )
except psycopg2.ProgrammingError as e: except psycopg2.ProgrammingError as e:
logger.warn( logger.warn(
"Ignoring error %r when trying to switch from GIST to GIN", "Ignoring error %r when trying to switch from GIST to GIN", e
e
) )
# we should now be able to delete the GIST index. # we should now be able to delete the GIST index.
c.execute( c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
"DROP INDEX IF EXISTS event_search_fts_idx_gist"
)
finally: finally:
conn.set_session(autocommit=False) conn.set_session(autocommit=False)
@ -223,6 +219,7 @@ class SearchStore(BackgroundUpdateStore):
have_added_index = progress['have_added_indexes'] have_added_index = progress['have_added_indexes']
if not have_added_index: if not have_added_index:
def create_index(conn): def create_index(conn):
conn.rollback() conn.rollback()
conn.set_session(autocommit=True) conn.set_session(autocommit=True)
@ -248,7 +245,8 @@ class SearchStore(BackgroundUpdateStore):
yield self.runInteraction( yield self.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_update_progress_txn, self._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
) )
def reindex_search_txn(txn): def reindex_search_txn(txn):
@ -302,14 +300,16 @@ class SearchStore(BackgroundUpdateStore):
""" """
self.store_search_entries_txn( self.store_search_entries_txn(
txn, txn,
(SearchEntry( (
SearchEntry(
key=key, key=key,
value=value, value=value,
event_id=event.event_id, event_id=event.event_id,
room_id=event.room_id, room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering, stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts, origin_server_ts=event.origin_server_ts,
),), ),
),
) )
def store_search_entries_txn(self, txn, entries): def store_search_entries_txn(self, txn, entries):
@ -329,10 +329,17 @@ class SearchStore(BackgroundUpdateStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)" " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
) )
args = (( args = (
entry.event_id, entry.room_id, entry.key, entry.value, (
entry.stream_ordering, entry.origin_server_ts, entry.event_id,
) for entry in entries) entry.room_id,
entry.key,
entry.value,
entry.stream_ordering,
entry.origin_server_ts,
)
for entry in entries
)
# inserts to a GIN index are normally batched up into a pending # inserts to a GIN index are normally batched up into a pending
# list, and then all committed together once the list gets to a # list, and then all committed together once the list gets to a
@ -363,9 +370,10 @@ class SearchStore(BackgroundUpdateStore):
"INSERT INTO event_search (event_id, room_id, key, value)" "INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)" " VALUES (?,?,?,?)"
) )
args = (( args = (
entry.event_id, entry.room_id, entry.key, entry.value, (entry.event_id, entry.room_id, entry.key, entry.value)
) for entry in entries) for entry in entries
)
txn.executemany(sql, args) txn.executemany(sql, args)
else: else:
@ -394,9 +402,7 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
if len(room_ids) < 500: if len(room_ids) < 500:
clauses.append( clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids) args.extend(room_ids)
local_clauses = [] local_clauses = []
@ -404,9 +410,7 @@ class SearchStore(BackgroundUpdateStore):
local_clauses.append("key = ?") local_clauses.append("key = ?")
args.append(key) args.append(key)
clauses.append( clauses.append("(%s)" % (" OR ".join(local_clauses),))
"(%s)" % (" OR ".join(local_clauses),)
)
count_args = args count_args = args
count_clauses = clauses count_clauses = clauses
@ -452,18 +456,13 @@ class SearchStore(BackgroundUpdateStore):
# entire table from the database. # entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500" sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute( results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
"search_msgs", self.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results)) results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results]) events = yield self._get_events([r["event_id"] for r in results])
event_map = { event_map = {ev.event_id: ev for ev in events}
ev.event_id: ev
for ev in events
}
highlights = None highlights = None
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
@ -477,18 +476,17 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
defer.returnValue({ defer.returnValue(
"results": [
{ {
"event": event_map[r["event_id"]], "results": [
"rank": r["rank"], {"event": event_map[r["event_id"]], "rank": r["rank"]}
}
for r in results for r in results
if r["event_id"] in event_map if r["event_id"] in event_map
], ],
"highlights": highlights, "highlights": highlights,
"count": count, "count": count,
}) }
)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@ -513,9 +511,7 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
if len(room_ids) < 500: if len(room_ids) < 500:
clauses.append( clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids) args.extend(room_ids)
local_clauses = [] local_clauses = []
@ -523,9 +519,7 @@ class SearchStore(BackgroundUpdateStore):
local_clauses.append("key = ?") local_clauses.append("key = ?")
args.append(key) args.append(key)
clauses.append( clauses.append("(%s)" % (" OR ".join(local_clauses),))
"(%s)" % (" OR ".join(local_clauses),)
)
# take copies of the current args and clauses lists, before adding # take copies of the current args and clauses lists, before adding
# pagination clauses to main query. # pagination clauses to main query.
@ -607,18 +601,13 @@ class SearchStore(BackgroundUpdateStore):
args.append(limit) args.append(limit)
results = yield self._execute( results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
"search_rooms", self.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results)) results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results]) events = yield self._get_events([r["event_id"] for r in results])
event_map = { event_map = {ev.event_id: ev for ev in events}
ev.event_id: ev
for ev in events
}
highlights = None highlights = None
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
@ -632,21 +621,22 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
defer.returnValue({ defer.returnValue(
{
"results": [ "results": [
{ {
"event": event_map[r["event_id"]], "event": event_map[r["event_id"]],
"rank": r["rank"], "rank": r["rank"],
"pagination_token": "%s,%s" % ( "pagination_token": "%s,%s"
r["origin_server_ts"], r["stream_ordering"] % (r["origin_server_ts"], r["stream_ordering"]),
),
} }
for r in results for r in results
if r["event_id"] in event_map if r["event_id"] in event_map
], ],
"highlights": highlights, "highlights": highlights,
"count": count, "count": count,
}) }
)
def _find_highlights_in_postgres(self, search_query, events): def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words """Given a list of events and a search term, return a list of words
@ -662,6 +652,7 @@ class SearchStore(BackgroundUpdateStore):
Returns: Returns:
deferred : A set of strings. deferred : A set of strings.
""" """
def f(txn): def f(txn):
highlight_words = set() highlight_words = set()
for event in events: for event in events:
@ -689,13 +680,15 @@ class SearchStore(BackgroundUpdateStore):
stop_sel += ">" stop_sel += ">"
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
_to_postgres_options({ _to_postgres_options(
{
"StartSel": start_sel, "StartSel": start_sel,
"StopSel": stop_sel, "StopSel": stop_sel,
"MaxFragments": "50", "MaxFragments": "50",
}) }
) )
txn.execute(query, (value, search_query,)) )
txn.execute(query, (value, search_query))
headline, = txn.fetchall()[0] headline, = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline # Now we need to pick the possible highlights out of the haedline
@ -714,9 +707,7 @@ class SearchStore(BackgroundUpdateStore):
def _to_postgres_options(options_dict): def _to_postgres_options(options_dict):
return "'%s'" % ( return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
)
def _parse_query(database_engine, search_term): def _parse_query(database_engine, search_term):

View File

@ -39,8 +39,9 @@ class SignatureWorkerStore(SQLBaseStore):
# to use its cache # to use its cache
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_event_reference_hash", @cachedList(
list_name="event_ids", num_args=1) cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
def get_event_reference_hashes(self, event_ids): def get_event_reference_hashes(self, event_ids):
def f(txn): def f(txn):
return { return {
@ -48,21 +49,13 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids for event_id in event_ids
} }
return self.runInteraction( return self.runInteraction("get_event_reference_hashes", f)
"get_event_reference_hashes",
f
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_event_hashes(self, event_ids): def add_event_hashes(self, event_ids):
hashes = yield self.get_event_reference_hashes( hashes = yield self.get_event_reference_hashes(event_ids)
event_ids
)
hashes = { hashes = {
e_id: { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for e_id, h in hashes.items() for e_id, h in hashes.items()
} }
@ -98,14 +91,12 @@ class SignatureStore(SignatureWorkerStore):
vals = [] vals = []
for event in events: for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event) ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
vals.append({ vals.append(
{
"event_id": event.event_id, "event_id": event.event_id,
"algorithm": ref_alg, "algorithm": ref_alg,
"hash": db_binary_type(ref_hash_bytes), "hash": db_binary_type(ref_hash_bytes),
}) }
self._simple_insert_many_txn(
txn,
table="event_reference_hashes",
values=vals,
) )
self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)

View File

@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100 MAX_STATE_DELTA_HOPS = 100
class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): class _GetStateGroupDelta(
namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
):
"""Return type of get_state_group_delta that implements __len__, which lets """Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching us use the itrable flag when caching
""" """
__slots__ = [] __slots__ = []
def __len__(self): def __len__(self):
@ -70,10 +73,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary # wildcards from the types dictionary
if self.include_others: if self.include_others:
self.types = { self.types = {k: v for k, v in iteritems(self.types) if v is not None}
k: v for k, v in iteritems(self.types)
if v is not None
}
@staticmethod @staticmethod
def all(): def all():
@ -130,10 +130,7 @@ class StateFilter(object):
Returns: Returns:
StateFilter StateFilter
""" """
return StateFilter( return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
types={EventTypes.Member: set(members)},
include_others=True,
)
def return_expanded(self): def return_expanded(self):
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
@ -243,9 +240,7 @@ class StateFilter(object):
if where_clause: if where_clause:
where_clause += " OR " where_clause += " OR "
where_clause += "type NOT IN (%s)" % ( where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
",".join(["?"] * len(self.types)),
)
where_args.extend(self.types) where_args.extend(self.types)
return where_clause, where_args return where_clause, where_args
@ -305,12 +300,8 @@ class StateFilter(object):
bool bool
""" """
return ( return self.include_others or any(
self.include_others state_keys is None for state_keys in itervalues(self.types)
or any(
state_keys is None
for state_keys in itervalues(self.types)
)
) )
def concrete_types(self): def concrete_types(self):
@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", "*stateGroupCache*",
# TODO: this hasn't been tuned yet # TODO: this hasn't been tuned yet
50000 * get_cache_factor_for("stateGroupCache") 50000 * get_cache_factor_for("stateGroupCache"),
) )
self._state_group_members_cache = DictionaryCache( self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*", "*stateGroupMembersCache*",
500000 * get_cache_factor_for("stateGroupMembersCache") 500000 * get_cache_factor_for("stateGroupMembersCache"),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
deferred: dict of (type, state_key) -> event_id deferred: dict of (type, state_key) -> event_id
""" """
def _get_current_state_ids_txn(txn): def _get_current_state_ids_txn(txn):
txn.execute( txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events """SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ? WHERE room_id = ?
""", """,
(room_id,) (room_id,),
) )
return { return {
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
} }
return self.runInteraction( return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
"get_current_state_ids",
_get_current_state_ids_txn,
)
# FIXME: how should this be cached? # FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results return results
return self.runInteraction( return self.runInteraction(
"get_filtered_current_state_ids", "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
_get_filtered_current_state_ids_txn,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[str|None]: The canonical alias, if any Deferred[str|None]: The canonical alias, if any
""" """
state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types( state = yield self.get_filtered_current_state_ids(
[(EventTypes.CanonicalAlias, "")] room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)) )
event_id = state.get((EventTypes.CanonicalAlias, "")) event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id: if not event_id:
@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
(prev_group, delta_ids), where both may be None. (prev_group, delta_ids), where both may be None.
""" """
def _get_state_group_delta_txn(txn): def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn( prev_group = self._simple_select_one_onecol_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
keyvalues={ keyvalues={"state_group": state_group},
"state_group": state_group,
},
retcol="prev_state_group", retcol="prev_state_group",
allow_none=True, allow_none=True,
) )
@ -598,21 +585,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
delta_ids = self._simple_select_list_txn( delta_ids = self._simple_select_list_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",
keyvalues={ keyvalues={"state_group": state_group},
"state_group": state_group, retcols=("type", "state_key", "event_id"),
},
retcols=("type", "state_key", "event_id",)
) )
return _GetStateGroupDelta(prev_group, { return _GetStateGroupDelta(
(row["type"], row["state_key"]): row["event_id"] prev_group,
for row in delta_ids {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
})
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
) )
return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids): def get_state_groups_ids(self, _room_id, event_ids):
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_ids: if not event_ids:
defer.returnValue({}) defer.returnValue({})
event_to_groups = yield self._get_state_group_for_events( event_to_groups = yield self._get_state_group_for_events(event_ids)
event_ids,
)
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups) group_to_state = yield self._get_state_for_groups(groups)
@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ [
ev_id for group_ids in itervalues(group_to_ids) ev_id
for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids) for ev_id in itervalues(group_ids)
], ],
get_prev_content=False get_prev_content=False,
) )
defer.returnValue({ defer.returnValue(
{
group: [ group: [
state_event_map[v] for v in itervalues(event_id_map) state_event_map[v]
for v in itervalues(event_id_map)
if v in state_event_map if v in state_event_map
] ]
for group, event_id_map in iteritems(group_to_ids) for group, event_id_map in iteritems(group_to_ids)
}) }
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter): def _get_state_groups_from_groups(self, groups, state_filter):
@ -699,14 +684,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for chunk in chunks: for chunk in chunks:
res = yield self.runInteraction( res = yield self.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, state_filter, self._get_state_groups_from_groups_txn,
chunk,
state_filter,
) )
results.update(res) results.update(res)
defer.returnValue(results) defer.returnValue(results)
def _get_state_groups_from_groups_txn( def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all(), self, txn, groups, state_filter=StateFilter.all()
): ):
results = {group: {} for group in groups} results = {group: {} for group in groups}
@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute( txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state" "SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause, " WHERE state_group = ? " + where_clause,
args args,
) )
results[group].update( results[group].update(
((typ, state_key), event_id) ((typ, state_key), event_id)
@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive # wildcards (i.e. Nones) in which case we have to do an exhaustive
# search # search
if ( if (
max_entries_returned is not None and max_entries_returned is not None
len(results[group]) == max_entries_returned and len(results[group]) == max_entries_returned
): ):
break break
@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events] deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
""" """
event_to_groups = yield self._get_state_group_for_events( event_to_groups = yield self._get_state_group_for_events(event_ids)
event_ids,
)
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter) group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False get_prev_content=False,
) )
event_to_state = { event_to_state = {
@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
A deferred dict from event_id -> (type, state_key) -> event_id A deferred dict from event_id -> (type, state_key) -> event_id
""" """
event_to_groups = yield self._get_state_group_for_events( event_to_groups = yield self._get_state_group_for_events(event_ids)
event_ids,
)
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter) group_to_state = yield self._get_state_for_groups(groups, state_filter)
@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_state_group_for_event(self, event_id): def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",
keyvalues={ keyvalues={"event_id": event_id},
"event_id": event_id,
},
retcol="state_group", retcol="state_group",
allow_none=True, allow_none=True,
desc="_get_state_group_for_event", desc="_get_state_group_for_event",
) )
@cachedList(cached_method_name="_get_state_group_for_event", @cachedList(
list_name="event_ids", num_args=1, inlineCallbacks=True) cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids): def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group """Returns mapping event_id -> state_group
""" """
@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
keyvalues={}, keyvalues={},
retcols=("event_id", "state_group",), retcols=("event_id", "state_group"),
desc="_get_state_group_for_events", desc="_get_state_group_for_events",
) )
@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Now we look them up in the member and non-member caches # Now we look them up in the member and non-member caches
non_member_state, incomplete_groups_nm, = ( non_member_state, incomplete_groups_nm, = (
yield self._get_state_for_groups_using_cache( yield self._get_state_for_groups_using_cache(
groups, self._state_group_cache, groups, self._state_group_cache, state_filter=non_member_filter
state_filter=non_member_filter,
) )
) )
member_state, incomplete_groups_m, = ( member_state, incomplete_groups_m, = (
yield self._get_state_for_groups_using_cache( yield self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, groups, self._state_group_members_cache, state_filter=member_filter
state_filter=member_filter,
) )
) )
@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
db_state_filter = state_filter.return_expanded() db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
list(incomplete_groups), list(incomplete_groups), state_filter=db_state_filter
state_filter=db_state_filter,
) )
# Now lets update the caches # Now lets update the caches
@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(state) defer.returnValue(state)
def _get_state_for_groups_using_cache( def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
self, groups, cache, state_filter,
):
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache. filtering by type/state_key, querying from a specific cache.
@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results, incomplete_groups return results, incomplete_groups
def _insert_into_cache(self, group_to_state_dict, state_filter, def _insert_into_cache(
cache_seq_num_members, cache_seq_num_non_members): self,
group_to_state_dict,
state_filter,
cache_seq_num_members,
cache_seq_num_non_members,
):
"""Inserts results from querying the database into the relevant cache. """Inserts results from querying the database into the relevant cache.
Args: Args:
@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
fetched_keys=non_member_types, fetched_keys=non_member_types,
) )
def store_state_group(self, event_id, room_id, prev_group, delta_ids, def store_state_group(
current_state_ids): self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
"""Store a new set of state, returning a newly assigned state group. """Store a new set of state, returning a newly assigned state group.
Args: Args:
@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
Deferred[int]: The state group ID Deferred[int]: The state group ID
""" """
def _store_state_group_txn(txn): def _store_state_group_txn(txn):
if current_state_ids is None: if current_state_ids is None:
# AFAIK, this can never happen # AFAIK, this can never happen
@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={"id": state_group, "room_id": room_id, "event_id": event_id},
"id": state_group,
"room_id": room_id,
"event_id": event_id,
},
) )
# We persist as a delta if we can, while also ensuring the chain # We persist as a delta if we can, while also ensuring the chain
@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
% (prev_group,) % (prev_group,)
) )
potential_hops = self._count_state_group_hops_txn( potential_hops = self._count_state_group_hops_txn(txn, prev_group)
txn, prev_group
)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
values={ values={"state_group": state_group, "prev_state_group": prev_group},
"state_group": state_group,
"prev_state_group": prev_group,
},
) )
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
This is used to ensure the delta chains don't get too long. This is used to ensure the delta chains don't get too long.
""" """
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = (""" sql = """
WITH RECURSIVE state(state_group) AS ( WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint) VALUES(?::bigint)
UNION ALL UNION ALL
@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE s.state_group = e.state_group WHERE s.state_group = e.state_group
) )
SELECT count(*) FROM state; SELECT count(*) FROM state;
""") """
txn.execute(sql, (state_group,)) txn.execute(sql, (state_group,))
row = txn.fetchone() row = txn.fetchone()
@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self._background_deduplicate_state, self._background_deduplicate_state,
) )
self.register_background_update_handler( self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME, self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
self._background_index_state,
) )
self.register_background_index_update( self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME, self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
values=[ values=[
{ {"state_group": state_group_id, "event_id": event_id}
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in iteritems(state_groups) for event_id, state_group_id in iteritems(state_groups)
], ],
) )
for event_id, state_group_id in iteritems(state_groups): for event_id, state_group_id in iteritems(state_groups):
txn.call_after( txn.call_after(
self._get_state_group_for_event.prefill, self._get_state_group_for_event.prefill, (event_id,), state_group_id
(event_id,), state_group_id
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
if max_group is None: if max_group is None:
rows = yield self._execute( rows = yield self._execute(
"_background_deduplicate_state", None, "_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups", "SELECT coalesce(max(id), 0) FROM state_groups",
) )
max_group = rows[0][0] max_group = rows[0][0]
@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
" WHERE ? < id AND id <= ?" " WHERE ? < id AND id <= ?"
" ORDER BY id ASC" " ORDER BY id ASC"
" LIMIT 1", " LIMIT 1",
(new_last_state_group, max_group,) (new_last_state_group, max_group),
) )
row = txn.fetchone() row = txn.fetchone()
if row: if row:
@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute( txn.execute(
"SELECT state_group FROM state_group_edges" "SELECT state_group FROM state_group_edges"
" WHERE state_group = ?", " WHERE state_group = ?",
(state_group,) (state_group,),
) )
# If we reach a point where we've already started inserting # If we reach a point where we've already started inserting
@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute( txn.execute(
"SELECT coalesce(max(id), 0) FROM state_groups" "SELECT coalesce(max(id), 0) FROM state_groups"
" WHERE id < ? AND room_id = ?", " WHERE id < ? AND room_id = ?",
(state_group, room_id,) (state_group, room_id),
) )
prev_group, = txn.fetchone() prev_group, = txn.fetchone()
new_last_state_group = state_group new_last_state_group = state_group
if prev_group: if prev_group:
potential_hops = self._count_state_group_hops_txn( potential_hops = self._count_state_group_hops_txn(txn, prev_group)
txn, prev_group
)
if potential_hops >= MAX_STATE_DELTA_HOPS: if potential_hops >= MAX_STATE_DELTA_HOPS:
# We want to ensure chains are at most this long,# # We want to ensure chains are at most this long,#
# otherwise read performance degrades. # otherwise read performance degrades.
continue continue
prev_state = self._get_state_groups_from_groups_txn( prev_state = self._get_state_groups_from_groups_txn(
txn, [prev_group], txn, [prev_group]
) )
prev_state = prev_state[prev_group] prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn( curr_state = self._get_state_groups_from_groups_txn(
txn, [state_group], txn, [state_group]
) )
curr_state = curr_state[state_group] curr_state = curr_state[state_group]
@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
# of keys # of keys
delta_state = { delta_state = {
key: value for key, value in iteritems(curr_state) key: value
for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value if prev_state.get(key, None) != value
} }
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
keyvalues={ keyvalues={"state_group": state_group},
"state_group": state_group,
}
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
values={ values={
"state_group": state_group, "state_group": state_group,
"prev_state_group": prev_group, "prev_state_group": prev_group,
} },
) )
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",
keyvalues={ keyvalues={"state_group": state_group},
"state_group": state_group,
}
) )
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
) )
if finished: if finished:
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME) yield self._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx" "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)" " ON state_groups_state(state_group, type, state_key)"
) )
txn.execute( txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
"DROP INDEX IF EXISTS state_groups_state_id"
)
finally: finally:
conn.set_session(autocommit=False) conn.set_session(autocommit=False)
else: else:
@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX state_groups_state_type_idx" "CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)" " ON state_groups_state(state_group, type, state_key)"
) )
txn.execute( txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
"DROP INDEX IF EXISTS state_groups_state_id"
)
yield self.runWithConnection(reindex_txn) yield self.runWithConnection(reindex_txn)

View File

@ -21,10 +21,11 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore): class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id): def get_current_state_deltas(self, prev_stream_id):
prev_stream_id = int(prev_stream_id) prev_stream_id = int(prev_stream_id)
if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
return [] return []
def get_current_state_deltas_txn(txn): def get_current_state_deltas_txn(txn):
@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
""" """
txn.execute(sql, (prev_stream_id, max_stream_id,)) txn.execute(sql, (prev_stream_id, max_stream_id))
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
return self.runInteraction( return self.runInteraction(

View File

@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs # Used as return values for pagination APIs
_EventDictReturn = namedtuple("_EventDictReturn", ( _EventDictReturn = namedtuple(
"event_id", "topological_ordering", "stream_ordering", "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
)) )
def lower_bound(token, engine, inclusive=False): def lower_bound(token, engine, inclusive=False):
@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False):
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres. # use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % ( return "((%d,%d) <%s (%s,%s))" % (
token.topological, token.stream, inclusive, token.topological,
"topological_ordering", "stream_ordering", token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
) )
return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering", token.topological,
token.topological, "topological_ordering", "topological_ordering",
token.stream, inclusive, "stream_ordering", token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
) )
@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True):
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres. # use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % ( return "((%d,%d) >%s (%s,%s))" % (
token.topological, token.stream, inclusive, token.topological,
"topological_ordering", "stream_ordering", token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
) )
return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering", token.topological,
token.topological, "topological_ordering", "topological_ordering",
token.stream, inclusive, "stream_ordering", token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
) )
@ -116,9 +130,7 @@ def filter_to_clause(event_filter):
args = [] args = []
if event_filter.types: if event_filter.types:
clauses.append( clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
"(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
)
args.extend(event_filter.types) args.extend(event_filter.types)
for typ in event_filter.not_types: for typ in event_filter.not_types:
@ -126,9 +138,7 @@ def filter_to_clause(event_filter):
args.append(typ) args.append(typ)
if event_filter.senders: if event_filter.senders:
clauses.append( clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
"(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
args.extend(event_filter.senders) args.extend(event_filter.senders)
for sender in event_filter.not_senders: for sender in event_filter.not_senders:
@ -136,9 +146,7 @@ def filter_to_clause(event_filter):
args.append(sender) args.append(sender)
if event_filter.rooms: if event_filter.rooms:
clauses.append( clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
"(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
)
args.extend(event_filter.rooms) args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms: for room_id in event_filter.not_rooms:
@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
events_max = self.get_room_max_stream_ordering() events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn,
"events",
entity_column="room_id", entity_column="room_id",
stream_column="stream_ordering", stream_column="stream_ordering",
max_value=events_max, max_value=events_max,
) )
self._events_stream_cache = StreamChangeCache( self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val, "EventsRoomStreamChangeCache",
min_event_val,
prefilled_cache=event_cache_prefill, prefilled_cache=event_cache_prefill,
) )
self._membership_stream_cache = StreamChangeCache( self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max, "MembershipStreamChangeCache", events_max
) )
self._stream_order_on_start = self.get_room_max_stream_ordering() self._stream_order_on_start = self.get_room_max_stream_ordering()
@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, def get_room_events_stream_for_rooms(
order='DESC'): self, room_ids, from_key, to_key, limit=0, order='DESC'
):
"""Get new room events in stream ordering since `from_key`. """Get new room events in stream ordering since `from_key`.
Args: Args:
@ -222,13 +233,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(defer.gatherResults([ res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background( run_in_background(
self.get_room_events_stream_for_room, self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order, room_id,
from_key,
to_key,
limit,
order=order,
) )
for room_id in rm_ids for room_id in rm_ids
], consumeErrors=True)) ],
consumeErrors=True,
)
)
results.update(dict(zip(rm_ids, res))) results.update(dict(zip(rm_ids, res)))
defer.returnValue(results) defer.returnValue(results)
@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
""" """
from_key = RoomStreamToken.parse_stream_token(from_key).stream from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set( return set(
room_id for room_id in room_ids room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key) if self._events_stream_cache.has_entity_changed(room_id, from_key)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, def get_room_events_stream_for_room(
order='DESC'): self, room_id, from_key, to_key, limit=0, order='DESC'
):
"""Get new room events in stream ordering since `from_key`. """Get new room events in stream ordering since `from_key`.
@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f) rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events( ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
[r.event_id for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=from_id is None) self._set_before_and_after(ret, rows, topo_order=from_id is None)
@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" AND e.stream_ordering > ? AND e.stream_ordering <= ?" " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC" " ORDER BY e.stream_ordering ASC"
) )
txn.execute(sql, (user_id, from_id, to_id,)) txn.execute(sql, (user_id, from_id, to_id))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f) rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events( ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
[r.event_id for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False) self._set_before_and_after(ret, rows, topo_order=False)
@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
""" """
rows, token = yield self.get_recent_event_ids_for_room( rows, token = yield self.get_recent_event_ids_for_room(
room_id, limit, end_token, room_id, limit, end_token
) )
logger.debug("stream before") logger.debug("stream before")
events = yield self._get_events( events = yield self._get_events(
[r.event_id for r in rows], [r.event_id for r in rows], get_prev_content=True
get_prev_content=True
) )
logger.debug("stream after") logger.debug("stream after")
@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token) end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.runInteraction( rows, token = yield self.runInteraction(
"get_recent_event_ids_for_room", self._paginate_room_events_txn, "get_recent_event_ids_for_room",
room_id, from_token=end_token, limit=limit, self._paginate_room_events_txn,
room_id,
from_token=end_token,
limit=limit,
) )
# We want to return the results in ascending order. # We want to return the results in ascending order.
@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[(int, int, str)]: Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id) (stream ordering, topological ordering, event_id)
""" """
def _f(txn): def _f(txn):
sql = ( sql = (
"SELECT stream_ordering, topological_ordering, event_id" "SELECT stream_ordering, topological_ordering, event_id"
@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" ORDER BY stream_ordering" " ORDER BY stream_ordering"
" LIMIT 1" " LIMIT 1"
) )
txn.execute(sql, (room_id, stream_ordering, )) txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone() return txn.fetchone()
return self.runInteraction( return self.runInteraction("get_room_event_after_stream_ordering", _f)
"get_room_event_after_stream_ordering", _f,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None): def get_room_events_max_id(self, room_id=None):
@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue("s%d" % (token,)) defer.returnValue("s%d" % (token,))
else: else:
topo = yield self.runInteraction( topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, "_get_max_topological_txn", self._get_max_topological_txn, room_id
room_id,
) )
defer.returnValue("t%d-%d" % (topo, token)) defer.returnValue("t%d-%d" % (topo, token))
@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred "s%d" stream token. A deferred "s%d" stream token.
""" """
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="events", table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
keyvalues={"event_id": event_id},
retcol="stream_ordering",
).addCallback(lambda row: "s%d" % (row,)) ).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id): def get_topological_token_for_event(self, event_id):
@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"), retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event", desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % ( ).addCallback(
row["topological_ordering"], row["stream_ordering"],) lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
) )
def get_max_topological_token(self, room_id, stream_key): def get_max_topological_token(self, room_id, stream_key):
@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE room_id = ? AND stream_ordering < ?" " WHERE room_id = ? AND stream_ordering < ?"
) )
return self._execute( return self._execute(
"get_max_topological_token", None, "get_max_topological_token", None, sql, room_id, stream_key
sql, room_id, stream_key, ).addCallback(lambda r: r[0][0] if r else 0)
).addCallback(
lambda r: r[0][0] if r else 0
)
def _get_max_topological_txn(self, txn, room_id): def _get_max_topological_txn(self, txn, room_id):
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events" "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
" WHERE room_id = ?", (room_id,),
(room_id,)
) )
rows = txn.fetchall() rows = txn.fetchall()
@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))
internal.order = ( internal.order = (int(topo) if topo else 0, int(stream))
int(topo) if topo else 0,
int(stream),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_around( def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None, self, room_id, event_id, before_limit, after_limit, event_filter=None
): ):
"""Retrieve events and pagination tokens around a given event in a """Retrieve events and pagination tokens around a given event in a
room. room.
@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
""" """
results = yield self.runInteraction( results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn, "get_events_around",
room_id, event_id, before_limit, after_limit, event_filter, self._get_events_around_txn,
room_id,
event_id,
before_limit,
after_limit,
event_filter,
) )
events_before = yield self._get_events( events_before = yield self._get_events(
[e for e in results["before"]["event_ids"]], [e for e in results["before"]["event_ids"]], get_prev_content=True
get_prev_content=True
) )
events_after = yield self._get_events( events_after = yield self._get_events(
[e for e in results["after"]["event_ids"]], [e for e in results["after"]["event_ids"]], get_prev_content=True
get_prev_content=True
) )
defer.returnValue({ defer.returnValue(
{
"events_before": events_before, "events_before": events_before,
"events_after": events_after, "events_after": events_after,
"start": results["before"]["token"], "start": results["before"]["token"],
"end": results["after"]["token"], "end": results["after"]["token"],
}) }
)
def _get_events_around_txn( def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter, self, txn, room_id, event_id, before_limit, after_limit, event_filter
): ):
"""Retrieves event_ids and pagination tokens around a given event in a """Retrieves event_ids and pagination tokens around a given event in a
room. room.
@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = self._simple_select_one_txn( results = self._simple_select_one_txn(
txn, txn,
"events", "events",
keyvalues={ keyvalues={"event_id": event_id, "room_id": room_id},
"event_id": event_id,
"room_id": room_id,
},
retcols=["stream_ordering", "topological_ordering"], retcols=["stream_ordering", "topological_ordering"],
) )
# Paginating backwards includes the event at the token, but paginating # Paginating backwards includes the event at the token, but paginating
# forward doesn't. # forward doesn't.
before_token = RoomStreamToken( before_token = RoomStreamToken(
results["topological_ordering"] - 1, results["topological_ordering"] - 1, results["stream_ordering"]
results["stream_ordering"],
) )
after_token = RoomStreamToken( after_token = RoomStreamToken(
results["topological_ordering"], results["topological_ordering"], results["stream_ordering"]
results["stream_ordering"],
) )
rows, start_token = self._paginate_room_events_txn( rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit, txn,
room_id,
before_token,
direction='b',
limit=before_limit,
event_filter=event_filter, event_filter=event_filter,
) )
events_before = [r.event_id for r in rows] events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn( rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit, txn,
room_id,
after_token,
direction='f',
limit=after_limit,
event_filter=event_filter, event_filter=event_filter,
) )
events_after = [r.event_id for r in rows] events_after = [r.event_id for r in rows]
return { return {
"before": { "before": {"event_ids": events_before, "token": start_token},
"event_ids": events_before, "after": {"event_ids": events_after, "token": end_token},
"token": start_token,
},
"after": {
"event_ids": events_after,
"token": end_token,
},
} }
@defer.inlineCallbacks @defer.inlineCallbacks
@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows] return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction( upper_bound, event_ids = yield self.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn, "get_all_new_events_stream", get_all_new_events_stream_txn
) )
events = yield self._get_events(event_ids) events = yield self._get_events(event_ids)
@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
table="federation_stream_position", table="federation_stream_position",
retcol="stream_id", retcol="stream_id",
keyvalues={"type": typ}, keyvalues={"type": typ},
desc="get_federation_out_pos" desc="get_federation_out_pos",
) )
def update_federation_out_pos(self, typ, stream_id): def update_federation_out_pos(self, typ, stream_id):
@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id): def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id) return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None, def _paginate_room_events_txn(
direction='b', limit=-1, event_filter=None): self,
txn,
room_id,
from_token,
to_token=None,
direction='b',
limit=-1,
event_filter=None,
):
"""Returns list of events before or after a given token. """Returns list of events before or after a given token.
Args: Args:
@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id] args = [False, room_id]
if direction == 'b': if direction == 'b':
order = "DESC" order = "DESC"
bounds = upper_bound( bounds = upper_bound(from_token, self.database_engine)
from_token, self.database_engine
)
if to_token: if to_token:
bounds = "%s AND %s" % (bounds, lower_bound( bounds = "%s AND %s" % (
to_token, self.database_engine bounds,
)) lower_bound(to_token, self.database_engine),
)
else: else:
order = "ASC" order = "ASC"
bounds = lower_bound( bounds = lower_bound(from_token, self.database_engine)
from_token, self.database_engine
)
if to_token: if to_token:
bounds = "%s AND %s" % (bounds, upper_bound( bounds = "%s AND %s" % (
to_token, self.database_engine bounds,
)) upper_bound(to_token, self.database_engine),
)
filter_clause, filter_args = filter_to_clause(event_filter) filter_clause, filter_args = filter_to_clause(event_filter)
@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE outlier = ? AND room_id = ? AND %(bounds)s" " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s," " ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?" " stream_ordering %(order)s LIMIT ?"
) % { ) % {"bounds": bounds, "order": order}
"bounds": bounds,
"order": order,
}
txn.execute(sql, args) txn.execute(sql, args)
@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead. # TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token next_token = to_token if to_token else from_token
return rows, str(next_token), return rows, str(next_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None, def paginate_room_events(
direction='b', limit=-1, event_filter=None): self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
):
"""Returns list of events before or after a given token. """Returns list of events before or after a given token.
Args: Args:
@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_key = RoomStreamToken.parse(to_key) to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.runInteraction( rows, token = yield self.runInteraction(
"paginate_room_events", self._paginate_room_events_txn, "paginate_room_events",
room_id, from_key, to_key, direction, limit, event_filter, self._paginate_room_events_txn,
room_id,
from_key,
to_key,
direction,
limit,
event_filter,
) )
events = yield self._get_events( events = yield self._get_events(
[r.event_id for r in rows], [r.event_id for r in rows], get_prev_content=True
get_prev_content=True
) )
self._set_before_and_after(events, rows) self._set_before_and_after(events, rows)

View File

@ -84,9 +84,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
def get_tag_content(txn, tag_ids): def get_tag_content(txn, tag_ids):
sql = ( sql = (
"SELECT tag, content" "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
" FROM room_tags"
" WHERE user_id=? AND room_id=?"
) )
results = [] results = []
for stream_id, user_id, room_id in tag_ids: for stream_id, user_id, room_id in tag_ids:
@ -123,6 +121,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
A deferred dict mapping from room_id strings to lists of tag A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token. strings for all the rooms that changed since the stream_id token.
""" """
def get_updated_tags_txn(txn): def get_updated_tags_txn(txn):
sql = ( sql = (
"SELECT room_id from room_tags_revisions" "SELECT room_id from room_tags_revisions"
@ -138,9 +137,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed: if not changed:
defer.returnValue({}) defer.returnValue({})
room_ids = yield self.runInteraction( room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
"get_updated_tags", get_updated_tags_txn
)
results = {} results = {}
if room_ids: if room_ids:
@ -163,9 +160,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
keyvalues={"user_id": user_id, "room_id": room_id}, keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"), retcols=("tag", "content"),
desc="get_tags_for_room", desc="get_tags_for_room",
).addCallback(lambda rows: { ).addCallback(
row["tag"]: json.loads(row["content"]) for row in rows lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
}) )
class TagsStore(TagsWorkerStore): class TagsStore(TagsWorkerStore):
@ -186,14 +183,8 @@ class TagsStore(TagsWorkerStore):
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="room_tags", table="room_tags",
keyvalues={ keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
"user_id": user_id, values={"content": content_json},
"room_id": room_id,
"tag": tag,
},
values={
"content": content_json,
}
) )
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
@ -211,6 +202,7 @@ class TagsStore(TagsWorkerStore):
Returns: Returns:
A deferred that completes once the tag has been removed A deferred that completes once the tag has been removed
""" """
def remove_tag_txn(txn, next_id): def remove_tag_txn(txn, next_id):
sql = ( sql = (
"DELETE FROM room_tags " "DELETE FROM room_tags "
@ -238,8 +230,7 @@ class TagsStore(TagsWorkerStore):
""" """
txn.call_after( txn.call_after(
self._account_data_stream_cache.entity_has_changed, self._account_data_stream_cache.entity_has_changed, user_id, next_id
user_id, next_id
) )
update_max_id_sql = ( update_max_id_sql = (

View File

@ -38,16 +38,12 @@ logger = logging.getLogger(__name__)
_TransactionRow = namedtuple( _TransactionRow = namedtuple(
"_TransactionRow", ( "_TransactionRow",
"id", "transaction_id", "destination", "ts", "response_code", ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
"response_json",
)
) )
_UpdateTransactionRow = namedtuple( _UpdateTransactionRow = namedtuple(
"_TransactionRow", ( "_TransactionRow", ("response_code", "response_json")
"response_code", "response_json",
)
) )
SENTINEL = object() SENTINEL = object()
@ -84,19 +80,22 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_received_txn_response", "get_received_txn_response",
self._get_received_txn_response, transaction_id, origin self._get_received_txn_response,
transaction_id,
origin,
) )
def _get_received_txn_response(self, txn, transaction_id, origin): def _get_received_txn_response(self, txn, transaction_id, origin):
result = self._simple_select_one_txn( result = self._simple_select_one_txn(
txn, txn,
table="received_transactions", table="received_transactions",
keyvalues={ keyvalues={"transaction_id": transaction_id, "origin": origin},
"transaction_id": transaction_id,
"origin": origin,
},
retcols=( retcols=(
"transaction_id", "origin", "ts", "response_code", "response_json", "transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced", "has_been_referenced",
), ),
allow_none=True, allow_none=True,
@ -108,8 +107,7 @@ class TransactionStore(SQLBaseStore):
else: else:
return None return None
def set_received_txn_response(self, transaction_id, origin, code, def set_received_txn_response(self, transaction_id, origin, code, response_dict):
response_dict):
"""Persist the response we returened for an incoming transaction, and """Persist the response we returened for an incoming transaction, and
should return for subsequent transactions with the same transaction_id should return for subsequent transactions with the same transaction_id
and origin. and origin.
@ -135,8 +133,7 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response", desc="set_received_txn_response",
) )
def prep_send_transaction(self, transaction_id, destination, def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the """Persists an outgoing transaction and calculates the values for the
previous transaction id list. previous transaction id list.
@ -182,7 +179,9 @@ class TransactionStore(SQLBaseStore):
result = yield self.runInteraction( result = yield self.runInteraction(
"get_destination_retry_timings", "get_destination_retry_timings",
self._get_destination_retry_timings, destination) self._get_destination_retry_timings,
destination,
)
# We don't hugely care about race conditions between getting and # We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway. # invalidating the cache, since we time out fairly quickly anyway.
@ -193,9 +192,7 @@ class TransactionStore(SQLBaseStore):
result = self._simple_select_one_txn( result = self._simple_select_one_txn(
txn, txn,
table="destinations", table="destinations",
keyvalues={ keyvalues={"destination": destination},
"destination": destination,
},
retcols=("destination", "retry_last_ts", "retry_interval"), retcols=("destination", "retry_last_ts", "retry_interval"),
allow_none=True, allow_none=True,
) )
@ -205,8 +202,7 @@ class TransactionStore(SQLBaseStore):
else: else:
return None return None
def set_destination_retry_timings(self, destination, def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
retry_last_ts, retry_interval):
"""Sets the current retry timings for a given destination. """Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring. Both timings should be zero if retrying is no longer occuring.
@ -225,8 +221,9 @@ class TransactionStore(SQLBaseStore):
retry_interval, retry_interval,
) )
def _set_destination_retry_timings(self, txn, destination, def _set_destination_retry_timings(
retry_last_ts, retry_interval): self, txn, destination, retry_last_ts, retry_interval
):
self.database_engine.lock_table(txn, "destinations") self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us # We need to be careful here as the data may have changed from under us
@ -235,9 +232,7 @@ class TransactionStore(SQLBaseStore):
prev_row = self._simple_select_one_txn( prev_row = self._simple_select_one_txn(
txn, txn,
table="destinations", table="destinations",
keyvalues={ keyvalues={"destination": destination},
"destination": destination,
},
retcols=("retry_last_ts", "retry_interval"), retcols=("retry_last_ts", "retry_interval"),
allow_none=True, allow_none=True,
) )
@ -250,15 +245,13 @@ class TransactionStore(SQLBaseStore):
"destination": destination, "destination": destination,
"retry_last_ts": retry_last_ts, "retry_last_ts": retry_last_ts,
"retry_interval": retry_interval, "retry_interval": retry_interval,
} },
) )
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
self._simple_update_one_txn( self._simple_update_one_txn(
txn, txn,
"destinations", "destinations",
keyvalues={ keyvalues={"destination": destination},
"destination": destination,
},
updatevalues={ updatevalues={
"retry_last_ts": retry_last_ts, "retry_last_ts": retry_last_ts,
"retry_interval": retry_interval, "retry_interval": retry_interval,
@ -273,8 +266,7 @@ class TransactionStore(SQLBaseStore):
""" """
return self.runInteraction( return self.runInteraction(
"get_destinations_needing_retry", "get_destinations_needing_retry", self._get_destinations_needing_retry
self._get_destinations_needing_retry
) )
def _get_destinations_needing_retry(self, txn): def _get_destinations_needing_retry(self, txn):
@ -288,7 +280,7 @@ class TransactionStore(SQLBaseStore):
def _start_cleanup_transactions(self): def _start_cleanup_transactions(self):
return run_as_background_process( return run_as_background_process(
"cleanup_transactions", self._cleanup_transactions, "cleanup_transactions", self._cleanup_transactions
) )
def _cleanup_transactions(self): def _cleanup_transactions(self):

View File

@ -40,9 +40,7 @@ class UserErasureWorkerStore(SQLBaseStore):
).addCallback(operator.truth) ).addCallback(operator.truth)
@cachedList( @cachedList(
cached_method_name="is_user_erased", cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
list_name="user_ids",
inlineCallbacks=True,
) )
def are_users_erased(self, user_ids): def are_users_erased(self, user_ids):
""" """
@ -61,16 +59,13 @@ class UserErasureWorkerStore(SQLBaseStore):
def _get_erased_users(txn): def _get_erased_users(txn):
txn.execute( txn.execute(
"SELECT user_id FROM erased_users WHERE user_id IN (%s)" % ( "SELECT user_id FROM erased_users WHERE user_id IN (%s)"
",".join("?" * len(user_ids)) % (",".join("?" * len(user_ids))),
),
user_ids, user_ids,
) )
return set(r[0] for r in txn) return set(r[0] for r in txn)
erased_users = yield self.runInteraction( erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
"are_users_erased", _get_erased_users,
)
res = dict((u, u in erased_users) for u in user_ids) res = dict((u, u in erased_users) for u in user_ids)
defer.returnValue(res) defer.returnValue(res)
@ -82,22 +77,16 @@ class UserErasureStore(UserErasureWorkerStore):
Args: Args:
user_id (str): full user_id to be erased user_id (str): full user_id to be erased
""" """
def f(txn): def f(txn):
# first check if they are already in the list # first check if they are already in the list
txn.execute( txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
"SELECT 1 FROM erased_users WHERE user_id = ?",
(user_id, )
)
if txn.fetchone(): if txn.fetchone():
return return
# they are not already there: do the insert. # they are not already there: do the insert.
txn.execute( txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,))
"INSERT INTO erased_users (user_id) VALUES (?)",
(user_id, ) self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
)
self._invalidate_cache_and_stream(
txn, self.is_user_erased, (user_id,)
)
return self.runInteraction("mark_user_erased", f) return self.runInteraction("mark_user_erased", f)

View File

@ -43,9 +43,9 @@ def _load_current_id(db_conn, table, column, step=1):
""" """
cur = db_conn.cursor() cur = db_conn.cursor()
if step == 1: if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else: else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
val, = cur.fetchone() val, = cur.fetchone()
cur.close() cur.close()
current_id = int(val) if val else step current_id = int(val) if val else step
@ -77,6 +77,7 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next() as stream_id: with stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column, extra_tables=[], step=1): def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0 assert step != 0
self._lock = threading.Lock() self._lock = threading.Lock()
@ -84,8 +85,7 @@ class StreamIdGenerator(object):
self._current = _load_current_id(db_conn, table, column, step) self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables: for table, column in extra_tables:
self._current = (max if step > 0 else min)( self._current = (max if step > 0 else min)(
self._current, self._current, _load_current_id(db_conn, table, column, step)
_load_current_id(db_conn, table, column, step)
) )
self._unfinished_ids = deque() self._unfinished_ids = deque()
@ -121,7 +121,7 @@ class StreamIdGenerator(object):
next_ids = range( next_ids = range(
self._current + self._step, self._current + self._step,
self._current + self._step * (n + 1), self._current + self._step * (n + 1),
self._step self._step,
) )
self._current += n * self._step self._current += n * self._step