Merge pull request #613 from matrix-org/markjh/yield

Load the current id in the IdGenerator constructor
This commit is contained in:
Mark Haines 2016-03-01 14:54:29 +00:00
commit d50ca1b1ed
12 changed files with 52 additions and 77 deletions

View File

@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream", "stream_id" db_conn, "presence_stream", "stream_id"
) )
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
events_max = self._stream_id_gen.get_max_token() events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(

View File

@ -163,12 +163,12 @@ class AccountDataStore(SQLBaseStore):
) )
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id "add_room_account_data", add_account_data_txn, next_id
) )
result = yield self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -202,12 +202,12 @@ class AccountDataStore(SQLBaseStore):
) )
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id "add_user_account_data", add_account_data_txn, next_id
) )
result = yield self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id): def _update_max_stream_id(self, txn, next_id):

View File

@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
yield stream_orderings yield stream_orderings
stream_ordering_manager = stream_ordering_manager() stream_ordering_manager = stream_ordering_manager()
else: else:
stream_ordering_manager = yield self._stream_id_gen.get_next_mult( stream_ordering_manager = self._stream_id_gen.get_next_mult(
self, len(events_and_contexts) len(events_and_contexts)
) )
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore):
stream_ordering = self.min_stream_token stream_ordering = self.min_stream_token
if stream_ordering is None: if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self) stream_ordering_manager = self._stream_id_gen.get_next()
else: else:
@contextmanager @contextmanager
def stream_ordering_manager(): def stream_ordering_manager():

View File

@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState",
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def update_presence(self, presence_states): def update_presence(self, presence_states):
stream_ordering_manager = yield self._presence_id_gen.get_next_mult( stream_ordering_manager = self._presence_id_gen.get_next_mult(
self, len(presence_states) len(presence_states)
) )
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:

View File

@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore):
if txn.rowcount == 0: if txn.rowcount == 0:
# We didn't update a row with the given rule_id so insert one # We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next_txn(txn) push_rule_id = self._push_rule_id_gen.get_next()
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(ret) defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn) new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
"push_rules_enable", "push_rules_enable",

View File

@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore):
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data, profile_tag=""): pushkey, pushkey_ts, lang, data, profile_tag=""):
try: try:
next_id = yield self._pushers_id_gen.get_next() next_id = self._pushers_id_gen.get_next()
yield self._simple_upsert( yield self._simple_upsert(
"pushers", "pushers",
dict( dict(

View File

@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", graph_to_linear
) )
stream_id_manager = yield self._receipts_id_gen.get_next(self) stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id: with stream_id_manager as stream_id:
have_persisted = yield self.runInteraction( have_persisted = yield self.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data room_id, receipt_type, user_id, event_ids, data
) )
max_persisted_id = yield self._stream_id_gen.get_max_token() max_persisted_id = self._stream_id_gen.get_max_token()
defer.returnValue((stream_id, max_persisted_id)) defer.returnValue((stream_id, max_persisted_id))

View File

@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
next_id = yield self._access_tokens_id_gen.get_next() next_id = self._access_tokens_id_gen.get_next()
yield self._simple_insert( yield self._simple_insert(
"access_tokens", "access_tokens",
@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
next_id = yield self._refresh_tokens_id_gen.get_next() next_id = self._refresh_tokens_id_gen.get_next()
yield self._simple_insert( yield self._simple_insert(
"refresh_tokens", "refresh_tokens",
@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time()) now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn) next_id = self._access_tokens_id_gen.get_next()
try: try:
if was_guest: if was_guest:

View File

@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
if event.is_state(): if event.is_state():
state_events[(event.type, event.state_key)] = event state_events[(event.type, event.state_key)] = event
state_group = self._state_groups_id_gen.get_next_txn(txn) state_group = self._state_groups_id_gen.get_next()
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",

View File

@ -142,12 +142,12 @@ class TagsStore(SQLBaseStore):
) )
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id) yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -164,12 +164,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag)) txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id) yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id): def _update_revision_txn(self, txn, user_id, room_id, next_id):

View File

@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination, def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts): origin_server_ts):
next_id = self._transaction_id_gen.get_next_txn(txn) next_id = self._transaction_id_gen.get_next()
# First we find out what the prev_txns should be. # First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time, # Since we know that we are only sending one transaction at a time,

View File

@ -13,51 +13,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from collections import deque from collections import deque
import contextlib import contextlib
import threading import threading
class IdGenerator(object): class IdGenerator(object):
def __init__(self, table, column, store): def __init__(self, db_conn, table, column):
self.table = table self.table = table
self.column = column self.column = column
self.store = store
self._lock = threading.Lock() self._lock = threading.Lock()
self._next_id = None cur = db_conn.cursor()
self._next_id = self._load_next_id(cur)
cur.close()
def _load_next_id(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
val, = txn.fetchone()
return val + 1 if val else 1
@defer.inlineCallbacks
def get_next(self): def get_next(self):
if self._next_id is None:
yield self.store.runInteraction(
"IdGenerator_%s" % (self.table,),
self.get_next_txn,
)
with self._lock: with self._lock:
i = self._next_id
self._next_id += 1
defer.returnValue(i)
def get_next_txn(self, txn):
with self._lock:
if self._next_id:
i = self._next_id i = self._next_id
self._next_id += 1 self._next_id += 1
return i return i
else:
txn.execute(
"SELECT MAX(%s) FROM %s" % (self.column, self.table,)
)
val, = txn.fetchone()
cur = val or 0
cur += 1
self._next_id = cur + 1
return cur
class StreamIdGenerator(object): class StreamIdGenerator(object):
@ -69,7 +48,7 @@ class StreamIdGenerator(object):
persistence of events can complete out of order. persistence of events can complete out of order.
Usage: Usage:
with stream_id_gen.get_next_txn(txn) as stream_id: with stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column): def __init__(self, db_conn, table, column):
@ -79,15 +58,21 @@ class StreamIdGenerator(object):
self._lock = threading.Lock() self._lock = threading.Lock()
cur = db_conn.cursor() cur = db_conn.cursor()
self._current_max = self._get_or_compute_current_max(cur) self._current_max = self._load_current_max(cur)
cur.close() cur.close()
self._unfinished_ids = deque() self._unfinished_ids = deque()
def get_next(self, store): def _load_current_max(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
return int(val) if val else 1
def get_next(self):
""" """
Usage: Usage:
with yield stream_id_gen.get_next as stream_id: with stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
with self._lock: with self._lock:
@ -106,10 +91,10 @@ class StreamIdGenerator(object):
return manager() return manager()
def get_next_mult(self, store, n): def get_next_mult(self, n):
""" """
Usage: Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids: with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ... # ... persist events ...
""" """
with self._lock: with self._lock:
@ -139,13 +124,3 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1 return self._unfinished_ids[0] - 1
return self._current_max return self._current_max
def _get_or_compute_current_max(self, txn):
with self._lock:
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
self._current_max = int(val) if val else 1
return self._current_max