Use a stream id generator for backfilled ids

commit e36bfbab38
parent 03e406eefc
Author: Mark Haines
Date:   2016-04-01 13:29:05 +01:00

11 changed files with 69 additions and 61 deletions

--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py

@@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore,
         self.hs = hs
         self.database_engine = hs.database_engine
 
-        cur = db_conn.cursor()
-        try:
-            cur.execute("SELECT MIN(stream_ordering) FROM events",)
-            rows = cur.fetchall()
-            self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
-            self.min_stream_token = min(self.min_stream_token, -1)
-        finally:
-            cur.close()
-
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
@@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._stream_id_gen = StreamIdGenerator(
             db_conn, "events", "stream_ordering"
         )
+        self._backfill_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering", direction=-1
+        )
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
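Note: the new _backfill_id_gen reads the same events.stream_ordering column as the forward _stream_id_gen, but with direction=-1 it seeds from MIN(stream_ordering) (clamped to at most -1 by _load_current_id, below) and then hands out progressively more negative orderings. The forward generator starts at 1 or above, so the two streams cannot hand out colliding ids.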
@@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
             extra_tables=[("deleted_pushers", "stream_id")],
         )
 
-        events_max = self._stream_id_gen.get_max_token()
+        events_max = self._stream_id_gen.get_current_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
             db_conn, "events",
             entity_column="room_id",
@@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )
 
-        account_max = self._account_data_id_gen.get_max_token()
+        account_max = self._account_data_id_gen.get_current_token()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max,
         )
@@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "presence_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._presence_id_gen.get_max_token(),
+            max_value=self._presence_id_gen.get_current_token(),
         )
         self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", min_presence_val,
@@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "push_rules_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
         )
         self.push_rules_stream_cache = StreamChangeCache(

--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py

@@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
             "add_room_account_data", add_account_data_txn, next_id
         )
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
             "add_user_account_data", add_account_data_txn, next_id
         )
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     def _update_max_stream_id(self, txn, next_id):

--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py

@@ -24,7 +24,6 @@ from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 from canonicaljson import encode_canonical_json
 
-from contextlib import contextmanager
 from collections import namedtuple
 
 import logging
@@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore):
             return
 
         if backfilled:
-            start = self.min_stream_token - 1
-            self.min_stream_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_stream_token, -1)
-
-            @contextmanager
-            def stream_ordering_manager():
-                yield stream_orderings
-            stream_ordering_manager = stream_ordering_manager()
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
         else:
             stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
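get_next_mult() returns a context manager, mirroring the hand-rolled stream_ordering_manager it replaces: the caller receives a block of ids, and the generator marks them finished when the block exits. A minimal sketch of that protocol (fake_get_next_mult below is invented for illustration and is not part of Synapse):

    from contextlib import contextmanager

    # Stand-in for StreamIdGenerator.get_next_mult(): hand out a block of
    # ids; the real generator also records them as unfinished and pops them
    # off its deque when the block exits.
    @contextmanager
    def fake_get_next_mult(current, n, direction):
        yield list(range(current + direction,
                         current + direction * (n + 1),
                         direction))

    with fake_get_next_mult(0, 3, direction=-1) as stream_orderings:
        print(stream_orderings)  # [-1, -2, -3] for a backfill generator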
@@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore):
         except _RollbackButIsFineException:
             pass
 
-        max_persisted_id = yield self._stream_id_gen.get_max_token()
+        max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((stream_ordering, max_persisted_id))
 
     @defer.inlineCallbacks
@@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore):
     def get_current_backfill_token(self):
         """The current minimum token that backfilled events have reached"""
-
-        # TODO: Fix race with the persit_event txn by using one of the
-        # stream id managers
-        return -self.min_stream_token
+        return -self._backfill_id_gen.get_current_token()
 
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
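Because the backfill generator's ids grow downwards, its current token is negative; negating it keeps the externally visible backfill token positive, matching the old -self.min_stream_token behaviour. A tiny worked example with an invented value:

    current = -5          # say _backfill_id_gen.get_current_token() gives -5
    assert -current == 5  # get_current_backfill_token() then reports 5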

--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py

@@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
             self._update_presence_txn, stream_orderings, presence_states,
         )
 
-        defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
+        defer.returnValue((
+            stream_orderings[-1], self._presence_id_gen.get_current_token()
+        ))
 
     def _update_presence_txn(self, txn, stream_orderings, presence_states):
         for stream_id, state in zip(stream_orderings, presence_states):
@@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
         defer.returnValue([UserPresenceState(**row) for row in rows])
 
     def get_current_presence_token(self):
-        return self._presence_id_gen.get_max_token()
+        return self._presence_id_gen.get_current_token()
 
     def allow_presence_visible(self, observed_localpart, observer_userid):
         return self._simple_insert(

--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py

@@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
         """Get the position of the push rules stream.
 
         Returns a pair of a stream id for the push_rules stream and the
         room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_max_token()
+        return self._push_rules_stream_id_gen.get_current_token()
 
     def have_push_rules_changed_for_user(self, user_id, last_id):
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py

@@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
         defer.returnValue(rows)
 
     def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_max_token()
+        return self._pushers_id_gen.get_current_token()
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
         def get_all_updated_pushers_txn(txn):

--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py

@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
         super(ReceiptsStore, self).__init__(hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
         )
 
     @cached(num_args=2)
@@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
         defer.returnValue(results)
 
     def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_max_token()
+        return self._receipts_id_gen.get_current_token()
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
@@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
             room_id, receipt_type, user_id, event_ids, data
         )
 
-        max_persisted_id = self._stream_id_gen.get_max_token()
+        max_persisted_id = self._stream_id_gen.get_current_token()
 
         defer.returnValue((stream_id, max_persisted_id))

--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py

@@ -458,4 +458,4 @@ class StateStore(SQLBaseStore):
         )
 
     def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_max_token()
+        return self._state_groups_id_gen.get_current_token()

--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py

@@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, direction='f'):
-        token = yield self._stream_id_gen.get_max_token()
+        token = yield self._stream_id_gen.get_current_token()
         if direction != 'b':
             defer.returnValue("s%d" % (token,))
         else:

--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py

@@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
         Returns:
             A deferred int.
         """
-        return self._account_data_id_gen.get_max_token()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     def get_tags_for_user(self, user_id):
@@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     def _update_revision_txn(self, txn, user_id, room_id, next_id):

--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py

@@ -21,7 +21,7 @@ import threading
 class IdGenerator(object):
     def __init__(self, db_conn, table, column):
         self._lock = threading.Lock()
-        self._next_id = _load_max_id(db_conn, table, column)
+        self._next_id = _load_current_id(db_conn, table, column)
 
     def get_next(self):
         with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
             return self._next_id
 
 
-def _load_max_id(db_conn, table, column):
+def _load_current_id(db_conn, table, column, direction=1):
     cur = db_conn.cursor()
-    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    if direction == 1:
+        cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    else:
+        cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
     val, = cur.fetchone()
     cur.close()
-    return int(val) if val else 1
+    current_id = int(val) if val else direction
+    return (max if direction == 1 else min)(current_id, direction)
 
 
 class StreamIdGenerator(object):
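To see what the reworked helper returns, here is a standalone copy of _load_current_id exercised against a throwaway in-memory SQLite table (the table contents are invented for the demo):

    import sqlite3

    # Copy of the new helper from the diff above.
    def _load_current_id(db_conn, table, column, direction=1):
        cur = db_conn.cursor()
        if direction == 1:
            cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
        else:
            cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
        val, = cur.fetchone()
        cur.close()
        current_id = int(val) if val else direction
        return (max if direction == 1 else min)(current_id, direction)

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (stream_ordering INTEGER)")
    conn.executemany("INSERT INTO events VALUES (?)", [(-3,), (2,), (7,)])

    print(_load_current_id(conn, "events", "stream_ordering"))                # 7
    print(_load_current_id(conn, "events", "stream_ordering", direction=-1))  # -3
    # On an empty table the helper falls back to `direction` itself (1 or -1).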
@@ -45,17 +49,30 @@ class StreamIdGenerator(object):
     all ids less than or equal to it have completed. This handles the fact that
     persistence of events can complete out of order.
 
+    :param connection db_conn: A database connection to use to fetch the
+        initial value of the generator from.
+    :param str table: A database table to read the initial value of the id
+        generator from.
+    :param str column: The column of the database table to read the initial
+        value of the id generator from.
+    :param list extra_tables: List of pairs of database tables and columns to
+        use to source the initial value of the generator from. The value with
+        the largest magnitude is used.
+    :param int direction: which direction the stream ids grow in. +1 to grow
+        upwards, -1 to grow downwards.
+
     Usage:
         with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
-    def __init__(self, db_conn, table, column, extra_tables=[]):
+    def __init__(self, db_conn, table, column, extra_tables=[], direction=1):
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._direction = direction
+        self._current = _load_current_id(db_conn, table, column, direction)
         for table, column in extra_tables:
-            self._current_max = max(
-                self._current_max,
-                _load_max_id(db_conn, table, column)
+            self._current = (max if direction > 0 else min)(
+                self._current,
+                _load_current_id(db_conn, table, column, direction)
             )
         self._unfinished_ids = deque()
@@ -66,8 +83,8 @@ class StreamIdGenerator(object):
                 # ... persist event ...
         """
         with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
+            self._current += self._direction
+            next_id = self._current
 
             self._unfinished_ids.append(next_id)
@@ -88,8 +105,12 @@ class StreamIdGenerator(object):
                 # ... persist events ...
         """
         with self._lock:
-            next_ids = range(self._current_max + 1, self._current_max + n + 1)
-            self._current_max += n
+            next_ids = range(
+                self._current + self._direction,
+                self._current + self._direction * (n + 1),
+                self._direction
+            )
+            self._current += n * self._direction
 
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
@@ -105,15 +126,15 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - 1
+                return self._unfinished_ids[0] - self._direction
 
-            return self._current_max
+            return self._current
 
 
 class ChainedIdGenerator(object):
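Putting the pieces together, a minimal self-contained sketch of the direction-aware generator (simplified from the class above: no database seeding and no extra_tables; MiniStreamIdGenerator and its current argument are inventions of this sketch):

    import threading
    from collections import deque
    from contextlib import contextmanager

    class MiniStreamIdGenerator(object):
        """Simplified StreamIdGenerator: `current` replaces _load_current_id."""
        def __init__(self, current=0, direction=1):
            self._lock = threading.Lock()
            self._direction = direction
            self._current = current
            self._unfinished_ids = deque()

        def get_next(self):
            with self._lock:
                self._current += self._direction
                next_id = self._current
                self._unfinished_ids.append(next_id)

            @contextmanager
            def manager():
                try:
                    yield next_id
                finally:
                    with self._lock:
                        # the real class pops ids off the deque as they finish
                        self._unfinished_ids.remove(next_id)

            return manager()

        def get_current_token(self):
            with self._lock:
                if self._unfinished_ids:
                    # step back one id from the oldest unfinished write
                    return self._unfinished_ids[0] - self._direction
                return self._current

    forward = MiniStreamIdGenerator(current=10, direction=1)
    backfill = MiniStreamIdGenerator(current=-2, direction=-1)

    with forward.get_next() as stream_id:
        assert stream_id == 11
        assert forward.get_current_token() == 10  # write still in flight
    assert forward.get_current_token() == 11

    with backfill.get_next() as stream_id:
        assert stream_id == -3                    # ids move downwards
    assert backfill.get_current_token() == -3

With direction=-1 the generator walks downwards from its seed, and get_current_token() always trails the oldest in-flight write by one step, which is exactly what get_current_backfill_token() negates above.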
@@ -125,7 +146,7 @@ class ChainedIdGenerator(object):
 
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._current_max = _load_current_id(db_conn, table, column)
         self._unfinished_ids = deque()
 
     def get_next(self):
@@ -137,7 +158,7 @@ class ChainedIdGenerator(object):
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
-            chained_id = self.chained_generator.get_max_token()
+            chained_id = self.chained_generator.get_current_token()
 
             self._unfinished_ids.append((next_id, chained_id))
@@ -151,7 +172,7 @@ class ChainedIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
@@ -160,4 +181,4 @@ class ChainedIdGenerator(object):
             stream_id, chained_id = self._unfinished_ids[0]
             return (stream_id - 1, chained_id)
 
-        return (self._current_max, self.chained_generator.get_max_token())
+        return (self._current_max, self.chained_generator.get_current_token())