mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Merge pull request #534 from matrix-org/erikj/setup
Add a Homeserver.setup method
This commit is contained in:
commit
167d1df699
@ -254,6 +254,18 @@ class SynapseHomeServer(HomeServer):
|
|||||||
except IncorrectDatabaseSetup as e:
|
except IncorrectDatabaseSetup as e:
|
||||||
quit_with_error(e.message)
|
quit_with_error(e.message)
|
||||||
|
|
||||||
|
def get_db_conn(self):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
|
||||||
def quit_with_error(error_string):
|
def quit_with_error(error_string):
|
||||||
message_lines = error_string.split("\n")
|
message_lines = error_string.split("\n")
|
||||||
@ -390,13 +402,7 @@ def setup(config_options):
|
|||||||
logger.info("Preparing database: %s...", config.database_config['name'])
|
logger.info("Preparing database: %s...", config.database_config['name'])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_conn = database_engine.module.connect(
|
db_conn = hs.get_db_conn()
|
||||||
**{
|
|
||||||
k: v for k, v in config.database_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
database_engine.prepare_database(db_conn)
|
database_engine.prepare_database(db_conn)
|
||||||
hs.run_startup_checks(db_conn, database_engine)
|
hs.run_startup_checks(db_conn, database_engine)
|
||||||
|
|
||||||
@ -411,14 +417,18 @@ def setup(config_options):
|
|||||||
|
|
||||||
logger.info("Database prepared in %s.", config.database_config['name'])
|
logger.info("Database prepared in %s.", config.database_config['name'])
|
||||||
|
|
||||||
|
hs.setup()
|
||||||
hs.start_listening()
|
hs.start_listening()
|
||||||
|
|
||||||
|
def start():
|
||||||
hs.get_pusherpool().start()
|
hs.get_pusherpool().start()
|
||||||
hs.get_state_handler().start_caching()
|
hs.get_state_handler().start_caching()
|
||||||
hs.get_datastore().start_profiling()
|
hs.get_datastore().start_profiling()
|
||||||
hs.get_datastore().start_doing_background_updates()
|
hs.get_datastore().start_doing_background_updates()
|
||||||
hs.get_replication_layer().start_get_pdu_cache()
|
hs.get_replication_layer().start_get_pdu_cache()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,6 +40,11 @@ from synapse.api.filtering import Filtering
|
|||||||
|
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HomeServer(object):
|
class HomeServer(object):
|
||||||
"""A basic homeserver object without lazy component builders.
|
"""A basic homeserver object without lazy component builders.
|
||||||
@ -102,10 +107,19 @@ class HomeServer(object):
|
|||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
self._building = {}
|
self._building = {}
|
||||||
|
|
||||||
|
self.clock = Clock()
|
||||||
|
self.distributor = Distributor()
|
||||||
|
self.ratelimiter = Ratelimiter()
|
||||||
|
|
||||||
# Other kwargs are explicit dependencies
|
# Other kwargs are explicit dependencies
|
||||||
for depname in kwargs:
|
for depname in kwargs:
|
||||||
setattr(self, depname, kwargs[depname])
|
setattr(self, depname, kwargs[depname])
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = DataStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
def get_ip_from_request(self, request):
|
def get_ip_from_request(self, request):
|
||||||
# X-Forwarded-For is handled by our custom request type.
|
# X-Forwarded-For is handled by our custom request type.
|
||||||
return request.getClientIP()
|
return request.getClientIP()
|
||||||
@ -116,15 +130,9 @@ class HomeServer(object):
|
|||||||
def is_mine_id(self, string):
|
def is_mine_id(self, string):
|
||||||
return string.split(":", 1)[1] == self.hostname
|
return string.split(":", 1)[1] == self.hostname
|
||||||
|
|
||||||
def build_clock(self):
|
|
||||||
return Clock()
|
|
||||||
|
|
||||||
def build_replication_layer(self):
|
def build_replication_layer(self):
|
||||||
return initialize_http_replication(self)
|
return initialize_http_replication(self)
|
||||||
|
|
||||||
def build_datastore(self):
|
|
||||||
return DataStore(self)
|
|
||||||
|
|
||||||
def build_handlers(self):
|
def build_handlers(self):
|
||||||
return Handlers(self)
|
return Handlers(self)
|
||||||
|
|
||||||
@ -135,10 +143,9 @@ class HomeServer(object):
|
|||||||
return Auth(self)
|
return Auth(self)
|
||||||
|
|
||||||
def build_http_client_context_factory(self):
|
def build_http_client_context_factory(self):
|
||||||
config = self.get_config()
|
|
||||||
return (
|
return (
|
||||||
InsecureInterceptableContextFactory()
|
InsecureInterceptableContextFactory()
|
||||||
if config.use_insecure_ssl_client_just_for_testing_do_not_use
|
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||||
else BrowserLikePolicyForHTTPS()
|
else BrowserLikePolicyForHTTPS()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -157,15 +164,9 @@ class HomeServer(object):
|
|||||||
def build_state_handler(self):
|
def build_state_handler(self):
|
||||||
return StateHandler(self)
|
return StateHandler(self)
|
||||||
|
|
||||||
def build_distributor(self):
|
|
||||||
return Distributor()
|
|
||||||
|
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
def build_ratelimiter(self):
|
|
||||||
return Ratelimiter()
|
|
||||||
|
|
||||||
def build_keyring(self):
|
def build_keyring(self):
|
||||||
return Keyring(self)
|
return Keyring(self)
|
||||||
|
|
||||||
|
@ -46,6 +46,9 @@ from .tags import TagsStore
|
|||||||
from .account_data import AccountDataStore
|
from .account_data import AccountDataStore
|
||||||
|
|
||||||
|
|
||||||
|
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -79,18 +82,43 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
EventPushActionsStore
|
EventPushActionsStore
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(DataStore, self).__init__(hs)
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.min_token_deferred = self._get_min_token()
|
cur = db_conn.cursor()
|
||||||
self.min_token = None
|
try:
|
||||||
|
cur.execute("SELECT MIN(stream_ordering) FROM events",)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
|
||||||
|
self.min_stream_token = min(self.min_stream_token, -1)
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|
||||||
self.client_ip_last_seen = Cache(
|
self.client_ip_last_seen = Cache(
|
||||||
name="client_ip_last_seen",
|
name="client_ip_last_seen",
|
||||||
keylen=4,
|
keylen=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._stream_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "events", "stream_ordering"
|
||||||
|
)
|
||||||
|
self._receipts_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "receipts_linearized", "stream_id"
|
||||||
|
)
|
||||||
|
self._account_data_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "account_data_max_stream_id", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||||
|
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||||
|
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||||
|
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
||||||
|
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||||
|
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||||
|
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||||
|
|
||||||
|
super(DataStore, self).__init__(hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
def insert_client_ip(self, user, access_token, ip, user_agent):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
|
@ -15,13 +15,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -175,16 +173,6 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
|
|
||||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
|
||||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
|
||||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
|
||||||
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
|
||||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
|
||||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
|
||||||
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
|
|
||||||
|
|
||||||
def start_profiling(self):
|
def start_profiling(self):
|
||||||
self._previous_loop_ts = self._clock.time_msec()
|
self._previous_loop_ts = self._clock.time_msec()
|
||||||
|
|
||||||
@ -345,7 +333,8 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def cursor_to_dict(self, cursor):
|
@staticmethod
|
||||||
|
def cursor_to_dict(cursor):
|
||||||
"""Converts a SQL cursor into an list of dicts.
|
"""Converts a SQL cursor into an list of dicts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -402,8 +391,8 @@ class SQLBaseStore(object):
|
|||||||
if not or_ignore:
|
if not or_ignore:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@log_function
|
@staticmethod
|
||||||
def _simple_insert_txn(self, txn, table, values):
|
def _simple_insert_txn(txn, table, values):
|
||||||
keys, vals = zip(*values.items())
|
keys, vals = zip(*values.items())
|
||||||
|
|
||||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||||
@ -414,7 +403,8 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
txn.execute(sql, vals)
|
txn.execute(sql, vals)
|
||||||
|
|
||||||
def _simple_insert_many_txn(self, txn, table, values):
|
@staticmethod
|
||||||
|
def _simple_insert_many_txn(txn, table, values):
|
||||||
if not values:
|
if not values:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -537,9 +527,10 @@ class SQLBaseStore(object):
|
|||||||
table, keyvalues, retcol, allow_none=allow_none,
|
table, keyvalues, retcol, allow_none=allow_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
|
@classmethod
|
||||||
|
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
|
||||||
allow_none=False):
|
allow_none=False):
|
||||||
ret = self._simple_select_onecol_txn(
|
ret = cls._simple_select_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
table=table,
|
table=table,
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
@ -554,7 +545,8 @@ class SQLBaseStore(object):
|
|||||||
else:
|
else:
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No row found")
|
||||||
|
|
||||||
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
|
@staticmethod
|
||||||
|
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
||||||
) % {
|
) % {
|
||||||
@ -603,7 +595,8 @@ class SQLBaseStore(object):
|
|||||||
table, keyvalues, retcols
|
table, keyvalues, retcols
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
|
@classmethod
|
||||||
|
def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
@ -627,7 +620,7 @@ class SQLBaseStore(object):
|
|||||||
)
|
)
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
|
|
||||||
return self.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
||||||
@ -662,7 +655,8 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
|
@classmethod
|
||||||
|
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
|
||||||
@ -699,7 +693,7 @@ class SQLBaseStore(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return self.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||||
desc="_simple_update_one"):
|
desc="_simple_update_one"):
|
||||||
@ -726,7 +720,8 @@ class SQLBaseStore(object):
|
|||||||
table, keyvalues, updatevalues,
|
table, keyvalues, updatevalues,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
@staticmethod
|
||||||
|
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
@ -743,7 +738,8 @@ class SQLBaseStore(object):
|
|||||||
if txn.rowcount > 1:
|
if txn.rowcount > 1:
|
||||||
raise StoreError(500, "More than one row matched")
|
raise StoreError(500, "More than one row matched")
|
||||||
|
|
||||||
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
@staticmethod
|
||||||
|
def _simple_select_one_txn(txn, table, keyvalues, retcols,
|
||||||
allow_none=False):
|
allow_none=False):
|
||||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||||
", ".join(retcols),
|
", ".join(retcols),
|
||||||
@ -784,7 +780,8 @@ class SQLBaseStore(object):
|
|||||||
raise StoreError(500, "more than one row matched")
|
raise StoreError(500, "more than one row matched")
|
||||||
return self.runInteraction(desc, func)
|
return self.runInteraction(desc, func)
|
||||||
|
|
||||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
@staticmethod
|
||||||
|
def _simple_delete_txn(txn, table, keyvalues):
|
||||||
sql = "DELETE FROM %s WHERE %s" % (
|
sql = "DELETE FROM %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||||
|
@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if backfilled:
|
if backfilled:
|
||||||
if not self.min_token_deferred.called:
|
start = self.min_stream_token - 1
|
||||||
yield self.min_token_deferred
|
self.min_stream_token -= len(events_and_contexts) + 1
|
||||||
start = self.min_token - 1
|
stream_orderings = range(start, self.min_stream_token, -1)
|
||||||
self.min_token -= len(events_and_contexts) + 1
|
|
||||||
stream_orderings = range(start, self.min_token, -1)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def stream_ordering_manager():
|
def stream_ordering_manager():
|
||||||
@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
|
|||||||
is_new_state=True, current_state=None):
|
is_new_state=True, current_state=None):
|
||||||
stream_ordering = None
|
stream_ordering = None
|
||||||
if backfilled:
|
if backfilled:
|
||||||
if not self.min_token_deferred.called:
|
self.min_stream_token -= 1
|
||||||
yield self.min_token_deferred
|
stream_ordering = self.min_stream_token
|
||||||
self.min_token -= 1
|
|
||||||
stream_ordering = self.min_token
|
|
||||||
|
|
||||||
if stream_ordering is None:
|
if stream_ordering is None:
|
||||||
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
||||||
|
@ -31,7 +31,9 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ReceiptsStore, self).__init__(hs)
|
super(ReceiptsStore, self).__init__(hs)
|
||||||
|
|
||||||
self._receipts_stream_cache = _RoomStreamChangeCache()
|
self._receipts_stream_cache = _RoomStreamChangeCache(
|
||||||
|
self._receipts_id_gen.get_max_token(None)
|
||||||
|
)
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_receipts_for_room(self, room_id, receipt_type):
|
def get_receipts_for_room(self, room_id, receipt_type):
|
||||||
@ -377,11 +379,11 @@ class _RoomStreamChangeCache(object):
|
|||||||
may have changed since that key. If the key is too old then the cache
|
may have changed since that key. If the key is too old then the cache
|
||||||
will simply return all rooms.
|
will simply return all rooms.
|
||||||
"""
|
"""
|
||||||
def __init__(self, size_of_cache=10000):
|
def __init__(self, current_key, size_of_cache=10000):
|
||||||
self._size_of_cache = size_of_cache
|
self._size_of_cache = size_of_cache
|
||||||
self._room_to_key = {}
|
self._room_to_key = {}
|
||||||
self._cache = sorteddict()
|
self._cache = sorteddict()
|
||||||
self._earliest_key = None
|
self._earliest_key = current_key
|
||||||
self.name = "ReceiptsRoomChangeCache"
|
self.name = "ReceiptsRoomChangeCache"
|
||||||
caches_by_name[self.name] = self._cache
|
caches_by_name[self.name] = self._cache
|
||||||
|
|
||||||
|
@ -444,19 +444,6 @@ class StreamStore(SQLBaseStore):
|
|||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
return rows[0][0] if rows else 0
|
return rows[0][0] if rows else 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_min_token(self):
|
|
||||||
row = yield self._execute(
|
|
||||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
|
||||||
self.min_token = min(self.min_token, -1)
|
|
||||||
|
|
||||||
logger.debug("min_token is: %s", self.min_token)
|
|
||||||
|
|
||||||
defer.returnValue(self.min_token)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_before_and_after(events, rows):
|
def _set_before_and_after(events, rows):
|
||||||
for event, row in zip(events, rows):
|
for event, row in zip(events, rows):
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from .util.id_generators import StreamIdGenerator
|
|
||||||
|
|
||||||
import ujson as json
|
import ujson as json
|
||||||
import logging
|
import logging
|
||||||
@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class TagsStore(SQLBaseStore):
|
class TagsStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
|
||||||
super(TagsStore, self).__init__(hs)
|
|
||||||
|
|
||||||
self._account_data_id_gen = StreamIdGenerator(
|
|
||||||
"account_data_max_stream_id", "stream_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self):
|
||||||
"""Get the current max stream id for the private user data stream
|
"""Get the current max stream id for the private user data stream
|
||||||
|
@ -72,28 +72,24 @@ class StreamIdGenerator(object):
|
|||||||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
def __init__(self, table, column):
|
def __init__(self, db_conn, table, column):
|
||||||
self.table = table
|
self.table = table
|
||||||
self.column = column
|
self.column = column
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
self._current_max = None
|
cur = db_conn.cursor()
|
||||||
|
self._current_max = self._get_or_compute_current_max(cur)
|
||||||
|
cur.close()
|
||||||
|
|
||||||
self._unfinished_ids = deque()
|
self._unfinished_ids = deque()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_next(self, store):
|
def get_next(self, store):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with yield stream_id_gen.get_next as stream_id:
|
with yield stream_id_gen.get_next as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._get_or_compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._current_max += 1
|
self._current_max += 1
|
||||||
next_id = self._current_max
|
next_id = self._current_max
|
||||||
@ -108,21 +104,14 @@ class StreamIdGenerator(object):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
defer.returnValue(manager())
|
return manager()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_next_mult(self, store, n):
|
def get_next_mult(self, store, n):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
||||||
# ... persist events ...
|
# ... persist events ...
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._get_or_compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
||||||
self._current_max += n
|
self._current_max += n
|
||||||
@ -139,24 +128,17 @@ class StreamIdGenerator(object):
|
|||||||
for next_id in next_ids:
|
for next_id in next_ids:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
defer.returnValue(manager())
|
return manager()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_max_token(self, store):
|
def get_max_token(self, store):
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
equal to it have been successfully persisted.
|
equal to it have been successfully persisted.
|
||||||
"""
|
"""
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._get_or_compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._unfinished_ids:
|
if self._unfinished_ids:
|
||||||
defer.returnValue(self._unfinished_ids[0] - 1)
|
return self._unfinished_ids[0] - 1
|
||||||
|
|
||||||
defer.returnValue(self._current_max)
|
return self._current_max
|
||||||
|
|
||||||
def _get_or_compute_current_max(self, txn):
|
def _get_or_compute_current_max(self, txn):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
f2 = self._write_config(suffix="2")
|
f2 = self._write_config(suffix="2")
|
||||||
|
|
||||||
config = Mock(app_service_config_files=[f1, f2])
|
config = Mock(app_service_config_files=[f1, f2])
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||||
|
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
|
||||||
@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
f2 = self._write_config(id="id", suffix="2")
|
f2 = self._write_config(id="id", suffix="2")
|
||||||
|
|
||||||
config = Mock(app_service_config_files=[f1, f2])
|
config = Mock(app_service_config_files=[f1, f2])
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
f2 = self._write_config(as_token="as_token", suffix="2")
|
f2 = self._write_config(as_token="as_token", suffix="2")
|
||||||
|
|
||||||
config = Mock(app_service_config_files=[f1, f2])
|
config = Mock(app_service_config_files=[f1, f2])
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
@ -18,7 +18,6 @@ from tests import unittest
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage.registration import RegistrationStore
|
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||||||
hs = yield setup_test_homeserver()
|
hs = yield setup_test_homeserver()
|
||||||
self.db_pool = hs.get_db_pool()
|
self.db_pool = hs.get_db_pool()
|
||||||
|
|
||||||
self.store = RegistrationStore(hs)
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
self.user_id = "@my-user:test"
|
self.user_id = "@my-user:test"
|
||||||
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
|
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
|
||||||
|
@ -60,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||||||
name, db_pool=db_pool, config=config,
|
name, db_pool=db_pool, config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine("sqlite3"),
|
database_engine=create_engine("sqlite3"),
|
||||||
|
get_db_conn=db_pool.get_db_conn,
|
||||||
**kargs
|
**kargs
|
||||||
)
|
)
|
||||||
|
hs.setup()
|
||||||
else:
|
else:
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=None, datastore=datastore, config=config,
|
name, db_pool=None, datastore=datastore, config=config,
|
||||||
@ -280,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
|
|||||||
lambda conn: prepare_database(conn, engine)
|
lambda conn: prepare_database(conn, engine)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_db_conn(self):
|
||||||
|
conn = self.connect()
|
||||||
|
engine = create_engine("sqlite3")
|
||||||
|
prepare_database(conn, engine)
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
class MemoryDataStore(object):
|
class MemoryDataStore(object):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user