Mirror of https://git.anonymousland.org/anonymousland/synapse-product.git (synced 2025-10-31 16:48:53 -04:00)
Merge branch 'develop' of github.com:matrix-org/synapse into develop

commit 97a64f3ebe
Author: David Baker

31 changed files with 267 additions and 140 deletions
		
							
								
								
									
scripts/port_from_sqlite_to_postgres.py (1 addition; Normal file → Executable file)

@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # Copyright 2015 OpenMarket Ltd
 #
							
								
								
									
scripts/upgrade_db_to_v0.6.0.py (2 changes; Normal file → Executable file)

@@ -1,4 +1,4 @@
-
+#!/usr/bin/env python
 from synapse.storage import SCHEMA_VERSION, read_schema
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.signatures import SignatureStore
							
								
								
									
setup.py (3 changes)

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import glob
 import os
 from setuptools import setup, find_packages

@@ -55,5 +56,5 @@ setup(
     include_package_data=True,
     zip_safe=False,
     long_description=long_description,
-    scripts=["synctl", "register_new_matrix_user"],
+    scripts=["synctl"] + glob.glob("scripts/*"),
 )
@@ -496,11 +496,31 @@ class SynapseSite(Site):


 def run(hs):
+    PROFILE_SYNAPSE = False
+    if PROFILE_SYNAPSE:
+        def profile(func):
+            from cProfile import Profile
+            from threading import current_thread
+
+            def profiled(*args, **kargs):
+                profile = Profile()
+                profile.enable()
+                func(*args, **kargs)
+                profile.disable()
+                ident = current_thread().ident
+                profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
+                    hs.hostname, func.__name__, ident
+                ))
+
+            return profiled
+
+        from twisted.python.threadpool import ThreadPool
+        ThreadPool._worker = profile(ThreadPool._worker)
+        reactor.run = profile(reactor.run)

     def in_thread():
         with LoggingContext("run"):
             change_resource_limit(hs.config.soft_file_limit)
-
             reactor.run()

     if hs.config.daemonize:
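A note on the profiling hunk above: when PROFILE_SYNAPSE is flipped to True, the monkeypatched ThreadPool._worker and reactor.run dump one cProfile stats file per thread into /tmp as each exits. The dumps can be read back with the standard-library pstats module; a minimal sketch (the file name is made up to match the dump_stats() pattern above):

    import pstats

    # One dump per thread, named "<hostname>.<function>.<thread ident>.pstat"
    # by the profiled() wrapper; this particular path is hypothetical.
    stats = pstats.Stats("/tmp/example.com._worker.140593041.pstat")

    # Show the 20 entries with the highest cumulative time.
    stats.strip_dirs().sort_stats("cumulative").print_stats(20)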
@@ -27,12 +27,7 @@ CONFIGFILE = "homeserver.yaml"
 GREEN = "\x1b[1;32m"
 NORMAL = "\x1b[m"

-CONFIG = yaml.load(open(CONFIGFILE))
-PIDFILE = CONFIG["pid_file"]
-
-
-def start():
-    if not os.path.exists(CONFIGFILE):
+if not os.path.exists(CONFIGFILE):
     sys.stderr.write(
         "No config file found\n"
         "To generate a config file, run '%s -c %s --generate-config"
@@ -41,6 +36,12 @@ def start():
         )
     )
     sys.exit(1)
+
+CONFIG = yaml.load(open(CONFIGFILE))
+PIDFILE = CONFIG["pid_file"]
+
+
+def start():
     print "Starting ...",
     args = SYNAPSE
     args.extend(["--daemonize", "-c", CONFIGFILE])
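The reorder above fixes a failure mode that is visible in the diff itself: previously CONFIG = yaml.load(open(CONFIGFILE)) ran at module import, so a missing homeserver.yaml died with a raw IOError before the friendly message inside start() could ever print. Moving the existence check above the load makes the hint reachable; the same pattern in miniature:

    import os
    import sys

    CONFIGFILE = "homeserver.yaml"

    # Check for the file *before* opening it, so the user sees a hint
    # rather than an IOError traceback.
    if not os.path.exists(CONFIGFILE):
        sys.stderr.write("No config file found\n")
        sys.exit(1)

    CONFIG = open(CONFIGFILE).read()  # stand-in for the yaml.load() call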
@@ -144,6 +144,7 @@ class Config(object):
         )
         config_args, remaining_args = config_parser.parse_known_args(argv)

+        if config_args.generate_config:
             if not config_args.config_path:
                 config_parser.error(
                     "Must supply a config file.\nA config file can be automatically"
@@ -153,7 +154,7 @@ class Config(object):

             config_dir_path = os.path.dirname(config_args.config_path[0])
             config_dir_path = os.path.abspath(config_dir_path)
-        if config_args.generate_config:
+
             server_name = config_args.server_name
             if not server_name:
                 print "Must specify a server_name to a generate config for."
@@ -196,6 +197,25 @@ class Config(object):
             )
             sys.exit(0)

+        parser = argparse.ArgumentParser(
+            parents=[config_parser],
+            description=description,
+            formatter_class=argparse.RawDescriptionHelpFormatter,
+        )
+
+        obj.invoke_all("add_arguments", parser)
+        args = parser.parse_args(remaining_args)
+
+        if not config_args.config_path:
+            config_parser.error(
+                "Must supply a config file.\nA config file can be automatically"
+                " generated using \"--generate-config -h SERVER_NAME"
+                " -c CONFIG-FILE\""
+            )
+
+        config_dir_path = os.path.dirname(config_args.config_path[0])
+        config_dir_path = os.path.abspath(config_dir_path)
+
         specified_config = {}
         for config_path in config_args.config_path:
             yaml_config = cls.read_config_file(config_path)
@@ -208,15 +228,6 @@ class Config(object):

         obj.invoke_all("read_config", config)

-        parser = argparse.ArgumentParser(
-            parents=[config_parser],
-            description=description,
-            formatter_class=argparse.RawDescriptionHelpFormatter,
-        )
-
-        obj.invoke_all("add_arguments", parser)
-        args = parser.parse_args(remaining_args)
-
         obj.invoke_all("read_arguments", args)

         return obj
@@ -491,7 +491,7 @@ class FederationClient(FederationBase):
             ]

             signed_events = yield self._check_sigs_and_hash_and_fetch(
-                destination, events, outlier=True
+                destination, events, outlier=False
             )

             have_gotten_all_from_destination = True
@@ -23,8 +23,6 @@ from twisted.internet import defer

 from synapse.util.logutils import log_function

-from syutil.jsonutil import encode_canonical_json
-
 import logging


@@ -71,7 +69,7 @@ class TransactionActions(object):
             transaction.transaction_id,
             transaction.origin,
             code,
-            encode_canonical_json(response)
+            response,
         )

     @defer.inlineCallbacks
@@ -101,5 +99,5 @@ class TransactionActions(object):
             transaction.transaction_id,
             transaction.destination,
             response_code,
-            encode_canonical_json(response_dict)
+            response_dict,
         )
@@ -104,7 +104,6 @@ class TransactionQueue(object):
             return not destination.startswith("localhost")

     @defer.inlineCallbacks
-    @log_function
     def enqueue_pdu(self, pdu, destinations, order):
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
@@ -31,7 +31,9 @@ import functools
 import simplejson as json
 import sys
 import time
+import threading

+DEBUG_CACHES = False

 logger = logging.getLogger(__name__)

@@ -68,9 +70,20 @@ class Cache(object):

         self.name = name
         self.keylen = keylen
-
+        self.sequence = 0
+        self.thread = None
         caches_by_name[name] = self.cache

+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
     def get(self, *keyargs):
         if len(keyargs) != self.keylen:
             raise ValueError("Expected a key to have %d items", self.keylen)
@@ -82,6 +95,13 @@ class Cache(object):
         cache_counter.inc_misses(self.name)
         raise KeyError()

+    def update(self, sequence, *args):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the caches sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            self.prefill(*args)
+
     def prefill(self, *args):  # because I can't  *keyargs, value
         keyargs = args[:-1]
         value = args[-1]
@@ -96,9 +116,12 @@ class Cache(object):
         self.cache[keyargs] = value

     def invalidate(self, *keyargs):
+        self.check_thread()
         if len(keyargs) != self.keylen:
             raise ValueError("Expected a key to have %d items", self.keylen)
-
+        # Increment the sequence number so that any SELECT statements that
+        # raced with the INSERT don't update the cache (SYN-369)
+        self.sequence += 1
         self.cache.pop(keyargs, None)


@@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
         @defer.inlineCallbacks
         def wrapped(self, *keyargs):
             try:
-                defer.returnValue(cache.get(*keyargs))
+                cached_result = cache.get(*keyargs)
+                if DEBUG_CACHES:
+                    actual_result = yield orig(self, *keyargs)
+                    if actual_result != cached_result:
+                        logger.error(
+                            "Stale cache entry %s%r: cached: %r, actual %r",
+                            orig.__name__, keyargs,
+                            cached_result, actual_result,
+                        )
+                        raise ValueError("Stale cache entry")
+                defer.returnValue(cached_result)
             except KeyError:
+                # Get the sequence number of the cache before reading from the
+                # database so that we can tell if the cache is invalidated
+                # while the SELECT is executing (SYN-369)
+                sequence = cache.sequence
+
                 ret = yield orig(self, *keyargs)

-                cache.prefill(*keyargs + (ret,))
+                cache.update(sequence, *keyargs + (ret,))

                 defer.returnValue(ret)
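The sequence number introduced above is the heart of the SYN-369 fix: a @cached read that loses a race with an invalidation must not repopulate the cache with the stale row it fetched. A stripped-down, synchronous sketch of the same guard (hypothetical names; no Twisted, metrics, or LRU machinery):

    class SequencedCache(object):
        """Toy model of the sequence guard in Cache above (illustrative)."""

        def __init__(self):
            self.cache = {}
            self.sequence = 0

        def invalidate(self, key):
            # Any read that snapshotted the sequence before this point
            # is now suspect and must not fill the cache.
            self.sequence += 1
            self.cache.pop(key, None)

        def update(self, sequence, key, value):
            # Drop the fill if an invalidation raced with the read.
            if self.sequence == sequence:
                self.cache[key] = value

    def get_or_load(cache, key, load_from_db):
        try:
            return cache.cache[key]
        except KeyError:
            seq = cache.sequence           # snapshot before the slow SELECT
            value = load_from_db(key)
            cache.update(seq, key, value)  # no-op if invalidate() ran meanwhile
            return value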
@@ -147,12 +185,20 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name", "database_engine"]
+    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]

-    def __init__(self, txn, name, database_engine):
+    def __init__(self, txn, name, database_engine, after_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
+        object.__setattr__(self, "after_callbacks", after_callbacks)
+
+    def call_after(self, callback, *args):
+        """Call the given callback on the main twisted thread after the
+        transaction has finished. Used to invalidate the caches on the
+        correct thread.
+        """
+        self.after_callbacks.append((callback, args))

     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -160,22 +206,23 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)

-    def execute(self, sql, *args, **kwargs):
+    def execute(self, sql, *args):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql, *args):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _do_execute(self, func, sql, *args):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)

         sql = self.database_engine.convert_param_style(sql)

-        if args and args[0]:
-            args = list(args)
-            args[0] = [
-                self.database_engine.encode_parameter(a) for a in args[0]
-            ]
+        if args:
             try:
                 sql_logger.debug(
-                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
-                    self.name,
-                    *args[0]
+                    "[SQL values] {%s} %r",
+                    self.name, args[0]
                 )
             except:
                 # Don't let logging failures stop SQL from working
@@ -184,8 +231,8 @@ class LoggingTransaction(object):
         start = time.time() * 1000

         try:
-            return self.txn.execute(
-                sql, *args, **kwargs
+            return func(
+                sql, *args
             )
         except Exception as e:
             logger.debug("[SQL FAIL] {%s} %s", self.name, e)
@@ -298,6 +345,8 @@ class SQLBaseStore(object):

         start_time = time.time() * 1000

+        after_callbacks = []
+
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
                 if self.database_engine.is_connection_closed(conn):
@@ -322,10 +371,10 @@ class SQLBaseStore(object):
                     while True:
                         try:
                             txn = conn.cursor()
-                            return func(
-                                LoggingTransaction(txn, name, self.database_engine),
-                                *args, **kwargs
-                            )
+                            txn = LoggingTransaction(
+                                txn, name, self.database_engine, after_callbacks
+                            )
+                            return func(txn, *args, **kwargs)
                         except self.database_engine.module.OperationalError as e:
                             # This can happen if the database disappears mid
                             # transaction.
@@ -374,6 +423,8 @@ class SQLBaseStore(object):
             result = yield self._db_pool.runWithConnection(
                 inner_func, *args, **kwargs
             )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
         defer.returnValue(result)

     def cursor_to_dict(self, cursor):
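Taken together, call_after and the after_callbacks loop in runInteraction let a storage function request cache invalidations from inside a transaction (which runs on a database thread) while having them fire on the main thread only after the transaction finishes, which is exactly what Cache.check_thread enforces. A simplified model of that flow, with the thread hop elided:

    class FakeTxn(object):
        """Minimal stand-in for LoggingTransaction (illustrative only)."""

        def __init__(self, after_callbacks):
            self.after_callbacks = after_callbacks

        def call_after(self, callback, *args):
            # Deferred until the whole transaction has finished.
            self.after_callbacks.append((callback, args))

    def run_interaction(func, *args):
        after_callbacks = []
        # In Synapse this runs func on a database thread via
        # runWithConnection; here it is called directly.
        result = func(FakeTxn(after_callbacks), *args)
        # Back "on the main thread": now safe to touch the caches.
        for callback, cb_args in after_callbacks:
            callback(*cb_args)
        return result

    # Usage: invalidate (here: merely record) a cache key only after commit.
    invalidated = []

    def store_member(txn):
        # ... the INSERT would go here ...
        txn.call_after(invalidated.append, "@user:example.com")

    run_interaction(store_member)
    assert invalidated == ["@user:example.com"]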
@@ -438,18 +489,49 @@ class SQLBaseStore(object):

     @log_function
     def _simple_insert_txn(self, txn, table, values):
+        keys, vals = zip(*values.items())
+
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
-            ", ".join(k for k in values),
-            ", ".join("?" for k in values)
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys)
         )

-        logger.debug(
-            "[SQL] %s Args=%s",
-            sql, values.values(),
-        )
+        txn.execute(sql, vals)
+
+    def _simple_insert_many_txn(self, txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of value names.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order.
+        keys, vals = zip(*[
+            zip(
+                *(sorted(i.items(), key=lambda kv: kv[0]))
+            )
+            for i in values
+            if i
+        ])
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError(
+                    "All items must have the same keys"
+                )

-        txn.execute(sql, values.values())
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0])
+        )
+
+        txn.executemany(sql, vals)

     def _simple_upsert(self, table, keyvalues, values,
                        insertion_values={}, desc="_simple_upsert", lock=True):
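The nested zip in _simple_insert_many_txn is dense; here is the comment's transformation worked through with concrete, made-up rows, showing the shapes that feed the generated SQL and the executemany call:

    rows = [{"b": 2, "a": 1}, {"b": 4, "a": 3}]

    # Sort each row's items for a deterministic key order, unzip each row
    # into (keys, values), then unzip across rows into two tuples.
    keys, vals = zip(*[
        zip(*sorted(row.items(), key=lambda kv: kv[0]))
        for row in rows
        if row
    ])

    print(keys)  # (('a', 'b'), ('a', 'b'))  -- every row must agree
    print(vals)  # ((1, 2), (3, 4))          -- one parameter tuple per row

keys[0] then builds the column list and the "?" placeholders, and vals goes straight to txn.executemany.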
@@ -36,9 +36,6 @@ class PostgresEngine(object):
     def convert_param_style(self, sql):
         return sql.replace("?", "%s")

-    def encode_parameter(self, param):
-        return param
-
     def on_new_connection(self, db_conn):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
@@ -26,9 +26,6 @@ class Sqlite3Engine(object):
     def convert_param_style(self, sql):
         return sql

-    def encode_parameter(self, param):
-        return param
-
     def on_new_connection(self, db_conn):
         self.prepare_database(db_conn)

@@ -104,7 +104,7 @@ class EventFederationStore(SQLBaseStore):
                 "room_id": room_id,
             },
             retcol="event_id",
-            desc="get_latest_events_in_room",
+            desc="get_latest_event_ids_in_room",
         )

     def _get_latest_events_in_room(self, txn, room_id):
@@ -262,17 +262,18 @@ class EventFederationStore(SQLBaseStore):
         For the given event, update the event edges table and forward and
         backward extremities tables.
         """
-        for e_id, _ in prev_events:
-            # TODO (erikj): This could be done as a bulk insert
-            self._simple_insert_txn(
-                txn,
-                table="event_edges",
-                values={
+        self._simple_insert_many_txn(
+            txn,
+            table="event_edges",
+            values=[
+                {
                     "event_id": event_id,
                     "prev_event_id": e_id,
                     "room_id": room_id,
                     "is_state": False,
-                },
-            )
+                }
+                for e_id, _ in prev_events
+            ],
+        )

         # Update the extremities table if this is not an outlier.
@@ -307,15 +308,16 @@ class EventFederationStore(SQLBaseStore):

             # Insert all the prev_events as a backwards thing, they'll get
             # deleted in a second if they're incorrect anyway.
-            for e_id, _ in prev_events:
-                # TODO (erikj): This could be done as a bulk insert
-                self._simple_insert_txn(
-                    txn,
-                    table="event_backward_extremities",
-                    values={
+            self._simple_insert_many_txn(
+                txn,
+                table="event_backward_extremities",
+                values=[
+                    {
                         "event_id": e_id,
                         "room_id": room_id,
-                    },
-                )
+                    }
+                    for e_id, _ in prev_events
+                ],
+            )

             # Also delete from the backwards extremities table all ones that
@@ -330,7 +332,9 @@ class EventFederationStore(SQLBaseStore):
             )
             txn.execute(query)

-            self.get_latest_event_ids_in_room.invalidate(room_id)
+            txn.call_after(
+                self.get_latest_event_ids_in_room.invalidate, room_id
+            )

     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
@@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore):
                            current_state=None):

         # Remove the any existing cache entries for the event_id
-        self._invalidate_get_event_cache(event.event_id)
+        txn.call_after(self._invalidate_get_event_cache, event.event_id)

         if stream_ordering is None:
             with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@@ -114,6 +114,13 @@ class EventsStore(SQLBaseStore):
             )

             for s in current_state:
+                if s.type == EventTypes.Member:
+                    txn.call_after(
+                        self.get_rooms_for_user.invalidate, s.state_key
+                    )
+                    txn.call_after(
+                        self.get_joined_hosts_for_room.invalidate, s.room_id
+                    )
                 self._simple_insert_txn(
                     txn,
                     "current_state_events",
@@ -122,28 +129,6 @@ class EventsStore(SQLBaseStore):
                         "room_id": s.room_id,
                         "type": s.type,
                         "state_key": s.state_key,
-                    },
-                )
-
-        if event.is_state() and is_new_state:
-            if not backfilled and not context.rejected:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_forward_extremities",
-                    values={
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "type": event.type,
-                        "state_key": event.state_key,
-                    },
-                )
-
-                for prev_state_id, _ in event.prev_state:
-                    self._simple_delete_txn(
-                        txn,
-                        table="state_forward_extremities",
-                        keyvalues={
-                            "event_id": prev_state_id,
-                        }
-                    )
-
+                    }
+                )

@@ -281,7 +266,9 @@ class EventsStore(SQLBaseStore):
         )

         if context.rejected:
-            self._store_rejections_txn(txn, event.event_id, context.rejected)
+            self._store_rejections_txn(
+                txn, event.event_id, context.rejected
+            )

         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
@@ -293,18 +280,21 @@ class EventsStore(SQLBaseStore):
             for alg, hash_base64 in prev_hashes.items():
                 hash_bytes = decode_base64(hash_base64)
                 self._store_prev_event_hash_txn(
-                    txn, event.event_id, prev_event_id, alg, hash_bytes
+                    txn, event.event_id, prev_event_id, alg,
+                    hash_bytes
                 )

-        for auth_id, _ in event.auth_events:
-            self._simple_insert_txn(
-                txn,
-                table="event_auth",
-                values={
+        self._simple_insert_many_txn(
+            txn,
+            table="event_auth",
+            values=[
+                {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
                     "auth_id": auth_id,
-                },
-            )
+                }
+                for auth_id, _ in event.auth_events
+            ],
+        )

         (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
@@ -330,16 +320,18 @@ class EventsStore(SQLBaseStore):
                 vals,
             )

-            for e_id, h in event.prev_state:
-                self._simple_insert_txn(
-                    txn,
-                    table="event_edges",
-                    values={
+            self._simple_insert_many_txn(
+                txn,
+                table="event_edges",
+                values=[
+                    {
                         "event_id": event.event_id,
                         "prev_event_id": e_id,
                         "room_id": event.room_id,
                         "is_state": True,
-                    },
-                )
+                    }
+                    for e_id, h in event.prev_state
+                ],
+            )

             if is_new_state and not context.rejected:
@@ -356,9 +348,11 @@ class EventsStore(SQLBaseStore):
                     }
                 )

+        return
+
     def _store_redaction(self, txn, event):
         # invalidate the cache for the redacted event
-        self._invalidate_get_event_cache(event.redacts)
+        txn.call_after(self._invalidate_get_event_cache, event.redacts)
         txn.execute(
             "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
             (event.event_id, event.redacts)
@@ -64,8 +64,8 @@ class RoomMemberStore(SQLBaseStore):
             }
         )

-        self.get_rooms_for_user.invalidate(target_user_id)
-        self.get_joined_hosts_for_room.invalidate(event.room_id)
+        txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
+        txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)

     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
							
								
								
									
synapse/storage/schema/delta/17/drop_indexes.sql (18 additions; new file)

@@ -0,0 +1,18 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS sent_transaction_dest;
+DROP INDEX IF EXISTS sent_transaction_sent;
+DROP INDEX IF EXISTS user_ips_user;
@@ -104,17 +104,19 @@ class StateStore(SQLBaseStore):
                 },
             )

-            for state in state_events.values():
-                self._simple_insert_txn(
-                    txn,
-                    table="state_groups_state",
-                    values={
+            self._simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                values=[
+                    {
                         "state_group": state_group,
                         "room_id": state.room_id,
                         "type": state.type,
                         "state_key": state.state_key,
                         "event_id": state.event_id,
-                    },
-                )
+                    }
+                    for state in state_events.values()
+                ],
+            )

         self._simple_insert_txn(
@@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached

 from collections import namedtuple

+from syutil.jsonutil import encode_canonical_json
 import logging

 logger = logging.getLogger(__name__)
@@ -82,7 +83,7 @@ class TransactionStore(SQLBaseStore):
                 "transaction_id": transaction_id,
                 "origin": origin,
                 "response_code": code,
-                "response_json": response_dict,
+                "response_json": buffer(encode_canonical_json(response_dict)),
             },
             or_ignore=True,
             desc="set_received_txn_response",
@@ -161,7 +162,8 @@ class TransactionStore(SQLBaseStore):
         return self.runInteraction(
             "delivered_txn",
             self._delivered_txn,
-            transaction_id, destination, code, response_dict
+            transaction_id, destination, code,
+            buffer(encode_canonical_json(response_dict)),
         )

     def _delivered_txn(self, txn, transaction_id, destination,
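With this change encode_canonical_json moves from the federation layer into the storage layer, so the response is serialised once, deterministically, just before being wrapped in buffer() (Python 2's way of handing the bytes to the driver as a blob). As rough intuition only, and not syutil's actual implementation, canonical JSON behaves like json.dumps with sorted keys and compact separators:

    import json

    def roughly_canonical_json(obj):
        # Approximation of syutil.jsonutil.encode_canonical_json: stable key
        # order, no insignificant whitespace. (The real encoder also pins
        # down string escaping and byte encoding.)
        return json.dumps(obj, sort_keys=True, separators=(",", ":"))

    print(roughly_canonical_json({"pdus": [], "origin": "example.com"}))
    # {"origin":"example.com","pdus":[]}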
@@ -67,7 +67,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):

         self.mock_txn.execute.assert_called_with(
                 "INSERT INTO tablename (columname) VALUES(?)",
-                ["Value"]
+                ("Value",)
         )

     @defer.inlineCallbacks
@@ -82,7 +82,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):

         self.mock_txn.execute.assert_called_with(
                 "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)",
-                [1, 2, 3]
+                (1, 2, 3,)
         )

     @defer.inlineCallbacks
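The expected arguments in these tests change from lists to tuples because _simple_insert_txn now derives them via zip(*values.items()), which yields tuples; a quick demonstration:

    keys, vals = zip(*{"columname": "Value"}.items())
    print(keys)  # ('columname',)
    print(vals)  # ('Value',)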