Merge branch 'develop' of github.com:matrix-org/synapse into develop

This commit is contained in:
David Baker 2015-05-07 09:33:42 +01:00
commit 97a64f3ebe
31 changed files with 267 additions and 140 deletions

1
scripts/port_from_sqlite_to_postgres.py Normal file → Executable file
View File

@ -1,3 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015 OpenMarket Ltd
# #

2
scripts/upgrade_db_to_v0.6.0.py Normal file → Executable file
View File

@ -1,4 +1,4 @@
#!/usr/bin/env python
from synapse.storage import SCHEMA_VERSION, read_schema from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore from synapse.storage.signatures import SignatureStore

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import glob
import os import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
@ -55,5 +56,5 @@ setup(
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
long_description=long_description, long_description=long_description,
scripts=["synctl", "register_new_matrix_user"], scripts=["synctl"] + glob.glob("scripts/*"),
) )

View File

@ -496,11 +496,31 @@ class SynapseSite(Site):
def run(hs): def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
from cProfile import Profile
from threading import current_thread
def profiled(*args, **kargs):
profile = Profile()
profile.enable()
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
hs.hostname, func.__name__, ident
))
return profiled
from twisted.python.threadpool import ThreadPool
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
def in_thread(): def in_thread():
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
reactor.run() reactor.run()
if hs.config.daemonize: if hs.config.daemonize:

View File

@ -27,20 +27,21 @@ CONFIGFILE = "homeserver.yaml"
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
CONFIG = yaml.load(open(CONFIGFILE)) CONFIG = yaml.load(open(CONFIGFILE))
PIDFILE = CONFIG["pid_file"] PIDFILE = CONFIG["pid_file"]
def start(): def start():
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
print "Starting ...", print "Starting ...",
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE]) args.extend(["--daemonize", "-c", CONFIGFILE])

View File

@ -144,16 +144,17 @@ class Config(object):
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
if config_args.generate_config: if config_args.generate_config:
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name server_name = config_args.server_name
if not server_name: if not server_name:
print "Must specify a server_name to a generate config for." print "Must specify a server_name to a generate config for."
@ -196,6 +197,25 @@ class Config(object):
) )
sys.exit(0) sys.exit(0)
parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {} specified_config = {}
for config_path in config_args.config_path: for config_path in config_args.config_path:
yaml_config = cls.read_config_file(config_path) yaml_config = cls.read_config_file(config_path)
@ -208,15 +228,6 @@ class Config(object):
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)
parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
obj.invoke_all("read_arguments", args) obj.invoke_all("read_arguments", args)
return obj return obj

View File

@ -491,7 +491,7 @@ class FederationClient(FederationBase):
] ]
signed_events = yield self._check_sigs_and_hash_and_fetch( signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=True destination, events, outlier=False
) )
have_gotten_all_from_destination = True have_gotten_all_from_destination = True

View File

@ -23,8 +23,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from syutil.jsonutil import encode_canonical_json
import logging import logging
@ -71,7 +69,7 @@ class TransactionActions(object):
transaction.transaction_id, transaction.transaction_id,
transaction.origin, transaction.origin,
code, code,
encode_canonical_json(response) response,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -101,5 +99,5 @@ class TransactionActions(object):
transaction.transaction_id, transaction.transaction_id,
transaction.destination, transaction.destination,
response_code, response_code,
encode_canonical_json(response_dict) response_dict,
) )

View File

@ -104,7 +104,6 @@ class TransactionQueue(object):
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus

View File

@ -31,7 +31,9 @@ import functools
import simplejson as json import simplejson as json
import sys import sys
import time import time
import threading
DEBUG_CACHES = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,9 +70,20 @@ class Cache(object):
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, *keyargs): def get(self, *keyargs):
if len(keyargs) != self.keylen: if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen) raise ValueError("Expected a key to have %d items", self.keylen)
@ -82,6 +95,13 @@ class Cache(object):
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
raise KeyError() raise KeyError()
def update(self, sequence, *args):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1] keyargs = args[:-1]
value = args[-1] value = args[-1]
@ -96,9 +116,12 @@ class Cache(object):
self.cache[keyargs] = value self.cache[keyargs] = value
def invalidate(self, *keyargs): def invalidate(self, *keyargs):
self.check_thread()
if len(keyargs) != self.keylen: if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen) raise ValueError("Expected a key to have %d items", self.keylen)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(keyargs, None) self.cache.pop(keyargs, None)
@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(self, *keyargs): def wrapped(self, *keyargs):
try: try:
defer.returnValue(cache.get(*keyargs)) cached_result = cache.get(*keyargs)
if DEBUG_CACHES:
actual_result = yield orig(self, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield orig(self, *keyargs) ret = yield orig(self, *keyargs)
cache.prefill(*keyargs + (ret,)) cache.update(sequence, *keyargs + (ret,))
defer.returnValue(ret) defer.returnValue(ret)
@ -147,12 +185,20 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = ["txn", "name", "database_engine"] __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
def __init__(self, txn, name, database_engine): def __init__(self, txn, name, database_engine, after_callbacks):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
def call_after(self, callback, *args):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
self.after_callbacks.append((callback, args))
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -160,22 +206,23 @@ class LoggingTransaction(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self.txn, name, value) setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs): def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql, *args):
self._do_execute(self.txn.executemany, sql, *args)
def _do_execute(self, func, sql, *args):
# TODO(paul): Maybe use 'info' and 'debug' for values? # TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql_logger.debug("[SQL] {%s} %s", self.name, sql)
sql = self.database_engine.convert_param_style(sql) sql = self.database_engine.convert_param_style(sql)
if args and args[0]: if args:
args = list(args)
args[0] = [
self.database_engine.encode_parameter(a) for a in args[0]
]
try: try:
sql_logger.debug( sql_logger.debug(
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])), "[SQL values] {%s} %r",
self.name, self.name, args[0]
*args[0]
) )
except: except:
# Don't let logging failures stop SQL from working # Don't let logging failures stop SQL from working
@ -184,8 +231,8 @@ class LoggingTransaction(object):
start = time.time() * 1000 start = time.time() * 1000
try: try:
return self.txn.execute( return func(
sql, *args, **kwargs sql, *args
) )
except Exception as e: except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e) logger.debug("[SQL FAIL] {%s} %s", self.name, e)
@ -298,6 +345,8 @@ class SQLBaseStore(object):
start_time = time.time() * 1000 start_time = time.time() * 1000
after_callbacks = []
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
if self.database_engine.is_connection_closed(conn): if self.database_engine.is_connection_closed(conn):
@ -322,10 +371,10 @@ class SQLBaseStore(object):
while True: while True:
try: try:
txn = conn.cursor() txn = conn.cursor()
return func( txn = LoggingTransaction(
LoggingTransaction(txn, name, self.database_engine), txn, name, self.database_engine, after_callbacks
*args, **kwargs
) )
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e: except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid # This can happen if the database disappears mid
# transaction. # transaction.
@ -374,6 +423,8 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
@ -438,18 +489,49 @@ class SQLBaseStore(object):
@log_function @log_function
def _simple_insert_txn(self, txn, table, values): def _simple_insert_txn(self, txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % ( sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table, table,
", ".join(k for k in values), ", ".join(k for k in keys),
", ".join("?" for k in values) ", ".join("?" for _ in keys)
) )
logger.debug( txn.execute(sql, vals)
"[SQL] %s Args=%s",
sql, values.values(), def _simple_insert_many_txn(self, txn, table, values):
if not values:
return
# This is a *slight* abomination to get a list of tuples of key names
# and a list of tuples of value names.
#
# i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
# => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
#
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(*[
zip(
*(sorted(i.items(), key=lambda kv: kv[0]))
)
for i in values
if i
])
for k in keys:
if k != keys[0]:
raise RuntimeError(
"All items must have the same keys"
)
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys[0]),
", ".join("?" for _ in keys[0])
) )
txn.execute(sql, values.values()) txn.executemany(sql, vals)
def _simple_upsert(self, table, keyvalues, values, def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True): insertion_values={}, desc="_simple_upsert", lock=True):

View File

@ -36,9 +36,6 @@ class PostgresEngine(object):
def convert_param_style(self, sql): def convert_param_style(self, sql):
return sql.replace("?", "%s") return sql.replace("?", "%s")
def encode_parameter(self, param):
return param
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
db_conn.set_isolation_level( db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ

View File

@ -26,9 +26,6 @@ class Sqlite3Engine(object):
def convert_param_style(self, sql): def convert_param_style(self, sql):
return sql return sql
def encode_parameter(self, param):
return param
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
self.prepare_database(db_conn) self.prepare_database(db_conn)

View File

@ -104,7 +104,7 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
}, },
retcol="event_id", retcol="event_id",
desc="get_latest_events_in_room", desc="get_latest_event_ids_in_room",
) )
def _get_latest_events_in_room(self, txn, room_id): def _get_latest_events_in_room(self, txn, room_id):
@ -262,18 +262,19 @@ class EventFederationStore(SQLBaseStore):
For the given event, update the event edges table and forward and For the given event, update the event edges table and forward and
backward extremities tables. backward extremities tables.
""" """
for e_id, _ in prev_events: self._simple_insert_many_txn(
# TODO (erikj): This could be done as a bulk insert txn,
self._simple_insert_txn( table="event_edges",
txn, values=[
table="event_edges", {
values={
"event_id": event_id, "event_id": event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": room_id, "room_id": room_id,
"is_state": False, "is_state": False,
}, }
) for e_id, _ in prev_events
],
)
# Update the extremities table if this is not an outlier. # Update the extremities table if this is not an outlier.
if not outlier: if not outlier:
@ -307,16 +308,17 @@ class EventFederationStore(SQLBaseStore):
# Insert all the prev_events as a backwards thing, they'll get # Insert all the prev_events as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway. # deleted in a second if they're incorrect anyway.
for e_id, _ in prev_events: self._simple_insert_many_txn(
# TODO (erikj): This could be done as a bulk insert txn,
self._simple_insert_txn( table="event_backward_extremities",
txn, values=[
table="event_backward_extremities", {
values={
"event_id": e_id, "event_id": e_id,
"room_id": room_id, "room_id": room_id,
}, }
) for e_id, _ in prev_events
],
)
# Also delete from the backwards extremities table all ones that # Also delete from the backwards extremities table all ones that
# reference events that we have already seen # reference events that we have already seen
@ -330,7 +332,9 @@ class EventFederationStore(SQLBaseStore):
) )
txn.execute(query) txn.execute(query)
self.get_latest_event_ids_in_room.invalidate(room_id) txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id
)
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and """Get a list of Events for a given topic that occurred before (and

View File

@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore):
current_state=None): current_state=None):
# Remove the any existing cache entries for the event_id # Remove the any existing cache entries for the event_id
self._invalidate_get_event_cache(event.event_id) txn.call_after(self._invalidate_get_event_cache, event.event_id)
if stream_ordering is None: if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering: with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@ -114,6 +114,13 @@ class EventsStore(SQLBaseStore):
) )
for s in current_state: for s in current_state:
if s.type == EventTypes.Member:
txn.call_after(
self.get_rooms_for_user.invalidate, s.state_key
)
txn.call_after(
self.get_joined_hosts_for_room.invalidate, s.room_id
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"current_state_events", "current_state_events",
@ -122,31 +129,9 @@ class EventsStore(SQLBaseStore):
"room_id": s.room_id, "room_id": s.room_id,
"type": s.type, "type": s.type,
"state_key": s.state_key, "state_key": s.state_key,
}, }
) )
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
if not outlier: if not outlier:
@ -281,7 +266,9 @@ class EventsStore(SQLBaseStore):
) )
if context.rejected: if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected) self._store_rejections_txn(
txn, event.event_id, context.rejected
)
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
@ -293,19 +280,22 @@ class EventsStore(SQLBaseStore):
for alg, hash_base64 in prev_hashes.items(): for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn( self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg, hash_bytes txn, event.event_id, prev_event_id, alg,
hash_bytes
) )
for auth_id, _ in event.auth_events: self._simple_insert_many_txn(
self._simple_insert_txn( txn,
txn, table="event_auth",
table="event_auth", values=[
values={ {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"auth_id": auth_id, "auth_id": auth_id,
}, }
) for auth_id, _ in event.auth_events
],
)
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
self._store_event_reference_hash_txn( self._store_event_reference_hash_txn(
@ -330,17 +320,19 @@ class EventsStore(SQLBaseStore):
vals, vals,
) )
for e_id, h in event.prev_state: self._simple_insert_many_txn(
self._simple_insert_txn( txn,
txn, table="event_edges",
table="event_edges", values=[
values={ {
"event_id": event.event_id, "event_id": event.event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": event.room_id, "room_id": event.room_id,
"is_state": True, "is_state": True,
}, }
) for e_id, h in event.prev_state
],
)
if is_new_state and not context.rejected: if is_new_state and not context.rejected:
self._simple_upsert_txn( self._simple_upsert_txn(
@ -356,9 +348,11 @@ class EventsStore(SQLBaseStore):
} }
) )
return
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
self._invalidate_get_event_cache(event.redacts) txn.call_after(self._invalidate_get_event_cache, event.redacts)
txn.execute( txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)", "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts) (event.event_id, event.redacts)

View File

@ -64,8 +64,8 @@ class RoomMemberStore(SQLBaseStore):
} }
) )
self.get_rooms_for_user.invalidate(target_user_id) txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
self.get_joined_hosts_for_room.invalidate(event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.

View File

@ -0,0 +1,18 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DROP INDEX IF EXISTS sent_transaction_dest;
DROP INDEX IF EXISTS sent_transaction_sent;
DROP INDEX IF EXISTS user_ips_user;

View File

@ -104,18 +104,20 @@ class StateStore(SQLBaseStore):
}, },
) )
for state in state_events.values(): self._simple_insert_many_txn(
self._simple_insert_txn( txn,
txn, table="state_groups_state",
table="state_groups_state", values=[
values={ {
"state_group": state_group, "state_group": state_group,
"room_id": state.room_id, "room_id": state.room_id,
"type": state.type, "type": state.type,
"state_key": state.state_key, "state_key": state.state_key,
"event_id": state.event_id, "event_id": state.event_id,
}, }
) for state in state_events.values()
],
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,

View File

@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached
from collections import namedtuple from collections import namedtuple
from syutil.jsonutil import encode_canonical_json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -82,7 +83,7 @@ class TransactionStore(SQLBaseStore):
"transaction_id": transaction_id, "transaction_id": transaction_id,
"origin": origin, "origin": origin,
"response_code": code, "response_code": code,
"response_json": response_dict, "response_json": buffer(encode_canonical_json(response_dict)),
}, },
or_ignore=True, or_ignore=True,
desc="set_received_txn_response", desc="set_received_txn_response",
@ -161,7 +162,8 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"delivered_txn", "delivered_txn",
self._delivered_txn, self._delivered_txn,
transaction_id, destination, code, response_dict transaction_id, destination, code,
buffer(encode_canonical_json(response_dict)),
) )
def _delivered_txn(self, txn, transaction_id, destination, def _delivered_txn(self, txn, transaction_id, destination,

View File

@ -67,7 +67,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"INSERT INTO tablename (columname) VALUES(?)", "INSERT INTO tablename (columname) VALUES(?)",
["Value"] ("Value",)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -82,7 +82,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)",
[1, 2, 3] (1, 2, 3,)
) )
@defer.inlineCallbacks @defer.inlineCallbacks