mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge branch 'erikj/executemany' of github.com:matrix-org/synapse into erikj/SYN-371
This commit is contained in:
commit
0c4ac271ca
@ -144,16 +144,17 @@ class Config(object):
|
||||
)
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
|
||||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
if config_args.generate_config:
|
||||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
|
||||
server_name = config_args.server_name
|
||||
if not server_name:
|
||||
print "Most specify a server_name to a generate config for."
|
||||
@ -196,6 +197,25 @@ class Config(object):
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[config_parser],
|
||||
description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
obj.invoke_all("add_arguments", parser)
|
||||
args = parser.parse_args(remaining_args)
|
||||
|
||||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
|
||||
specified_config = {}
|
||||
for config_path in config_args.config_path:
|
||||
yaml_config = cls.read_config_file(config_path)
|
||||
@ -208,15 +228,6 @@ class Config(object):
|
||||
|
||||
obj.invoke_all("read_config", config)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[config_parser],
|
||||
description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
obj.invoke_all("add_arguments", parser)
|
||||
args = parser.parse_args(remaining_args)
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
|
@ -491,7 +491,7 @@ class FederationClient(FederationBase):
|
||||
]
|
||||
|
||||
signed_events = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, events, outlier=True
|
||||
destination, events, outlier=False
|
||||
)
|
||||
|
||||
have_gotten_all_from_destination = True
|
||||
|
@ -23,8 +23,6 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@ -71,7 +69,7 @@ class TransactionActions(object):
|
||||
transaction.transaction_id,
|
||||
transaction.origin,
|
||||
code,
|
||||
encode_canonical_json(response)
|
||||
response,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -101,5 +99,5 @@ class TransactionActions(object):
|
||||
transaction.transaction_id,
|
||||
transaction.destination,
|
||||
response_code,
|
||||
encode_canonical_json(response_dict)
|
||||
response_dict,
|
||||
)
|
||||
|
@ -31,7 +31,9 @@ import functools
|
||||
import simplejson as json
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
|
||||
DEBUG_CACHES = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -68,9 +70,20 @@ class Cache(object):
|
||||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
|
||||
self.sequence = 0
|
||||
self.thread = None
|
||||
caches_by_name[name] = self.cache
|
||||
|
||||
def check_thread(self):
|
||||
expected_thread = self.thread
|
||||
if expected_thread is None:
|
||||
self.thread = threading.current_thread()
|
||||
else:
|
||||
if expected_thread is not threading.current_thread():
|
||||
raise ValueError(
|
||||
"Cache objects can only be accessed from the main thread"
|
||||
)
|
||||
|
||||
def get(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
@ -82,6 +95,13 @@ class Cache(object):
|
||||
cache_counter.inc_misses(self.name)
|
||||
raise KeyError()
|
||||
|
||||
def update(self, sequence, *args):
|
||||
self.check_thread()
|
||||
if self.sequence == sequence:
|
||||
# Only update the cache if the caches sequence number matches the
|
||||
# number that the cache had before the SELECT was started (SYN-369)
|
||||
self.prefill(*args)
|
||||
|
||||
def prefill(self, *args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
@ -96,9 +116,12 @@ class Cache(object):
|
||||
self.cache[keyargs] = value
|
||||
|
||||
def invalidate(self, *keyargs):
|
||||
self.check_thread()
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
# Increment the sequence number so that any SELECT statements that
|
||||
# raced with the INSERT don't update the cache (SYN-369)
|
||||
self.sequence += 1
|
||||
self.cache.pop(keyargs, None)
|
||||
|
||||
|
||||
@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, *keyargs):
|
||||
try:
|
||||
defer.returnValue(cache.get(*keyargs))
|
||||
cached_result = cache.get(*keyargs)
|
||||
if DEBUG_CACHES:
|
||||
actual_result = yield orig(self, *keyargs)
|
||||
if actual_result != cached_result:
|
||||
logger.error(
|
||||
"Stale cache entry %s%r: cached: %r, actual %r",
|
||||
orig.__name__, keyargs,
|
||||
cached_result, actual_result,
|
||||
)
|
||||
raise ValueError("Stale cache entry")
|
||||
defer.returnValue(cached_result)
|
||||
except KeyError:
|
||||
# Get the sequence number of the cache before reading from the
|
||||
# database so that we can tell if the cache is invalidated
|
||||
# while the SELECT is executing (SYN-369)
|
||||
sequence = cache.sequence
|
||||
|
||||
ret = yield orig(self, *keyargs)
|
||||
|
||||
cache.prefill(*keyargs + (ret,))
|
||||
cache.update(sequence, *keyargs + (ret,))
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@ -147,12 +185,20 @@ class LoggingTransaction(object):
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method."""
|
||||
__slots__ = ["txn", "name", "database_engine"]
|
||||
__slots__ = ["txn", "name", "database_engine", "after_callbacks"]
|
||||
|
||||
def __init__(self, txn, name, database_engine):
|
||||
def __init__(self, txn, name, database_engine, after_callbacks):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||
|
||||
def call_after(self, callback, *args):
|
||||
"""Call the given callback on the main twisted thread after the
|
||||
transaction has finished. Used to invalidate the caches on the
|
||||
correct thread.
|
||||
"""
|
||||
self.after_callbacks.append((callback, args))
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.txn, name)
|
||||
@ -299,6 +345,8 @@ class SQLBaseStore(object):
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
after_callbacks = []
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runInteraction") as context:
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
@ -323,10 +371,10 @@ class SQLBaseStore(object):
|
||||
while True:
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
return func(
|
||||
LoggingTransaction(txn, name, self.database_engine),
|
||||
*args, **kwargs
|
||||
txn = LoggingTransaction(
|
||||
txn, name, self.database_engine, after_callbacks
|
||||
)
|
||||
return func(txn, *args, **kwargs)
|
||||
except self.database_engine.module.OperationalError as e:
|
||||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
@ -375,6 +423,8 @@ class SQLBaseStore(object):
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
for after_callback, after_args in after_callbacks:
|
||||
after_callback(*after_args)
|
||||
defer.returnValue(result)
|
||||
|
||||
def cursor_to_dict(self, cursor):
|
||||
@ -453,6 +503,14 @@ class SQLBaseStore(object):
|
||||
if not values:
|
||||
return
|
||||
|
||||
# This is a *slight* abomination to get a list of tuples of key names
|
||||
# and a list of tuples of value names.
|
||||
#
|
||||
# i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
|
||||
# => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
|
||||
#
|
||||
# The sort is to ensure that we don't rely on dictionary iteration
|
||||
# order.
|
||||
keys, vals = zip(*[
|
||||
zip(
|
||||
*(sorted(i.items(), key=lambda kv: kv[0]))
|
||||
|
@ -332,7 +332,9 @@ class EventFederationStore(SQLBaseStore):
|
||||
)
|
||||
txn.execute(query)
|
||||
|
||||
self.get_latest_event_ids_in_room.invalidate(room_id)
|
||||
txn.call_after(
|
||||
self.get_latest_event_ids_in_room.invalidate, room_id
|
||||
)
|
||||
|
||||
def get_backfill_events(self, room_id, event_list, limit):
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
|
@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore):
|
||||
current_state=None):
|
||||
|
||||
# Remove the any existing cache entries for the event_id
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||
|
||||
if stream_ordering is None:
|
||||
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
|
||||
@ -113,19 +113,24 @@ class EventsStore(SQLBaseStore):
|
||||
keyvalues={"room_id": event.room_id},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
[
|
||||
for s in current_state:
|
||||
if s.type == EventTypes.Member:
|
||||
txn.call_after(
|
||||
self.get_rooms_for_user.invalidate, s.state_key
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_joined_hosts_for_room.invalidate, s.room_id
|
||||
)
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": s.event_id,
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
}
|
||||
for s in current_state
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
@ -261,7 +266,9 @@ class EventsStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
if context.rejected:
|
||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||
self._store_rejections_txn(
|
||||
txn, event.event_id, context.rejected
|
||||
)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
@ -273,7 +280,8 @@ class EventsStore(SQLBaseStore):
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||
txn, event.event_id, prev_event_id, alg,
|
||||
hash_bytes
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
@ -340,9 +348,11 @@ class EventsStore(SQLBaseStore):
|
||||
}
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
# invalidate the cache for the redacted event
|
||||
self._invalidate_get_event_cache(event.redacts)
|
||||
txn.call_after(self._invalidate_get_event_cache, event.redacts)
|
||||
txn.execute(
|
||||
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
||||
(event.event_id, event.redacts)
|
||||
|
@ -64,8 +64,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||
}
|
||||
)
|
||||
|
||||
self.get_rooms_for_user.invalidate(target_user_id)
|
||||
self.get_joined_hosts_for_room.invalidate(event.room_id)
|
||||
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
||||
|
||||
def get_room_member(self, user_id, room_id):
|
||||
"""Retrieve the current state of a room member.
|
||||
|
@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -82,7 +83,7 @@ class TransactionStore(SQLBaseStore):
|
||||
"transaction_id": transaction_id,
|
||||
"origin": origin,
|
||||
"response_code": code,
|
||||
"response_json": response_dict,
|
||||
"response_json": buffer(encode_canonical_json(response_dict)),
|
||||
},
|
||||
or_ignore=True,
|
||||
desc="set_received_txn_response",
|
||||
@ -161,7 +162,8 @@ class TransactionStore(SQLBaseStore):
|
||||
return self.runInteraction(
|
||||
"delivered_txn",
|
||||
self._delivered_txn,
|
||||
transaction_id, destination, code, response_dict
|
||||
transaction_id, destination, code,
|
||||
buffer(encode_canonical_json(response_dict)),
|
||||
)
|
||||
|
||||
def _delivered_txn(self, txn, transaction_id, destination,
|
||||
|
Loading…
Reference in New Issue
Block a user