mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-04 08:20:49 -05:00
Merge pull request #613 from matrix-org/markjh/yield
Load the current id in the IdGenerator constructor
This commit is contained in:
commit
d50ca1b1ed
@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
db_conn, "presence_stream", "stream_id"
|
db_conn, "presence_stream", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
|
|
||||||
events_max = self._stream_id_gen.get_max_token()
|
events_max = self._stream_id_gen.get_max_token()
|
||||||
event_cache_prefill, min_event_val = self._get_cache_dict(
|
event_cache_prefill, min_event_val = self._get_cache_dict(
|
||||||
|
@ -163,12 +163,12 @@ class AccountDataStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
self._update_max_stream_id(txn, next_id)
|
self._update_max_stream_id(txn, next_id)
|
||||||
|
|
||||||
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"add_room_account_data", add_account_data_txn, next_id
|
"add_room_account_data", add_account_data_txn, next_id
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self._account_data_id_gen.get_max_token()
|
result = self._account_data_id_gen.get_max_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -202,12 +202,12 @@ class AccountDataStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
self._update_max_stream_id(txn, next_id)
|
self._update_max_stream_id(txn, next_id)
|
||||||
|
|
||||||
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"add_user_account_data", add_account_data_txn, next_id
|
"add_user_account_data", add_account_data_txn, next_id
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self._account_data_id_gen.get_max_token()
|
result = self._account_data_id_gen.get_max_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _update_max_stream_id(self, txn, next_id):
|
def _update_max_stream_id(self, txn, next_id):
|
||||||
|
@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
|
|||||||
yield stream_orderings
|
yield stream_orderings
|
||||||
stream_ordering_manager = stream_ordering_manager()
|
stream_ordering_manager = stream_ordering_manager()
|
||||||
else:
|
else:
|
||||||
stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
|
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||||
self, len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
with stream_ordering_manager as stream_orderings:
|
||||||
@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
stream_ordering = self.min_stream_token
|
stream_ordering = self.min_stream_token
|
||||||
|
|
||||||
if stream_ordering is None:
|
if stream_ordering is None:
|
||||||
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
stream_ordering_manager = self._stream_id_gen.get_next()
|
||||||
else:
|
else:
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def stream_ordering_manager():
|
def stream_ordering_manager():
|
||||||
|
@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState",
|
|||||||
class PresenceStore(SQLBaseStore):
|
class PresenceStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_presence(self, presence_states):
|
def update_presence(self, presence_states):
|
||||||
stream_ordering_manager = yield self._presence_id_gen.get_next_mult(
|
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||||
self, len(presence_states)
|
len(presence_states)
|
||||||
)
|
)
|
||||||
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
with stream_ordering_manager as stream_orderings:
|
||||||
|
@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
|
|
||||||
if txn.rowcount == 0:
|
if txn.rowcount == 0:
|
||||||
# We didn't update a row with the given rule_id so insert one
|
# We didn't update a row with the given rule_id so insert one
|
||||||
push_rule_id = self._push_rule_id_gen.get_next_txn(txn)
|
push_rule_id = self._push_rule_id_gen.get_next()
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
|
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
|
||||||
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
|
new_id = self._push_rules_enable_id_gen.get_next()
|
||||||
self._simple_upsert_txn(
|
self._simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
"push_rules_enable",
|
"push_rules_enable",
|
||||||
|
@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore):
|
|||||||
app_display_name, device_display_name,
|
app_display_name, device_display_name,
|
||||||
pushkey, pushkey_ts, lang, data, profile_tag=""):
|
pushkey, pushkey_ts, lang, data, profile_tag=""):
|
||||||
try:
|
try:
|
||||||
next_id = yield self._pushers_id_gen.get_next()
|
next_id = self._pushers_id_gen.get_next()
|
||||||
yield self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
"pushers",
|
"pushers",
|
||||||
dict(
|
dict(
|
||||||
|
@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
"insert_receipt_conv", graph_to_linear
|
"insert_receipt_conv", graph_to_linear
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_id_manager = yield self._receipts_id_gen.get_next(self)
|
stream_id_manager = self._receipts_id_gen.get_next()
|
||||||
with stream_id_manager as stream_id:
|
with stream_id_manager as stream_id:
|
||||||
have_persisted = yield self.runInteraction(
|
have_persisted = yield self.runInteraction(
|
||||||
"insert_linearized_receipt",
|
"insert_linearized_receipt",
|
||||||
@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
room_id, receipt_type, user_id, event_ids, data
|
room_id, receipt_type, user_id, event_ids, data
|
||||||
)
|
)
|
||||||
|
|
||||||
max_persisted_id = yield self._stream_id_gen.get_max_token()
|
max_persisted_id = self._stream_id_gen.get_max_token()
|
||||||
|
|
||||||
defer.returnValue((stream_id, max_persisted_id))
|
defer.returnValue((stream_id, max_persisted_id))
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem adding this.
|
StoreError if there was a problem adding this.
|
||||||
"""
|
"""
|
||||||
next_id = yield self._access_tokens_id_gen.get_next()
|
next_id = self._access_tokens_id_gen.get_next()
|
||||||
|
|
||||||
yield self._simple_insert(
|
yield self._simple_insert(
|
||||||
"access_tokens",
|
"access_tokens",
|
||||||
@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem adding this.
|
StoreError if there was a problem adding this.
|
||||||
"""
|
"""
|
||||||
next_id = yield self._refresh_tokens_id_gen.get_next()
|
next_id = self._refresh_tokens_id_gen.get_next()
|
||||||
|
|
||||||
yield self._simple_insert(
|
yield self._simple_insert(
|
||||||
"refresh_tokens",
|
"refresh_tokens",
|
||||||
@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
|
def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
|
||||||
now = int(self.clock.time())
|
now = int(self.clock.time())
|
||||||
|
|
||||||
next_id = self._access_tokens_id_gen.get_next_txn(txn)
|
next_id = self._access_tokens_id_gen.get_next()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if was_guest:
|
if was_guest:
|
||||||
|
@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
|
|||||||
if event.is_state():
|
if event.is_state():
|
||||||
state_events[(event.type, event.state_key)] = event
|
state_events[(event.type, event.state_key)] = event
|
||||||
|
|
||||||
state_group = self._state_groups_id_gen.get_next_txn(txn)
|
state_group = self._state_groups_id_gen.get_next()
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="state_groups",
|
table="state_groups",
|
||||||
|
@ -142,12 +142,12 @@ class TagsStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
yield self.runInteraction("add_tag", add_tag_txn, next_id)
|
yield self.runInteraction("add_tag", add_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
|
||||||
result = yield self._account_data_id_gen.get_max_token()
|
result = self._account_data_id_gen.get_max_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -164,12 +164,12 @@ class TagsStore(SQLBaseStore):
|
|||||||
txn.execute(sql, (user_id, room_id, tag))
|
txn.execute(sql, (user_id, room_id, tag))
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
|
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
|
||||||
result = yield self._account_data_id_gen.get_max_token()
|
result = self._account_data_id_gen.get_max_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _update_revision_txn(self, txn, user_id, room_id, next_id):
|
def _update_revision_txn(self, txn, user_id, room_id, next_id):
|
||||||
|
@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
|
|||||||
def _prep_send_transaction(self, txn, transaction_id, destination,
|
def _prep_send_transaction(self, txn, transaction_id, destination,
|
||||||
origin_server_ts):
|
origin_server_ts):
|
||||||
|
|
||||||
next_id = self._transaction_id_gen.get_next_txn(txn)
|
next_id = self._transaction_id_gen.get_next()
|
||||||
|
|
||||||
# First we find out what the prev_txns should be.
|
# First we find out what the prev_txns should be.
|
||||||
# Since we know that we are only sending one transaction at a time,
|
# Since we know that we are only sending one transaction at a time,
|
||||||
|
@ -13,51 +13,30 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import contextlib
|
import contextlib
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|
||||||
class IdGenerator(object):
|
class IdGenerator(object):
|
||||||
def __init__(self, table, column, store):
|
def __init__(self, db_conn, table, column):
|
||||||
self.table = table
|
self.table = table
|
||||||
self.column = column
|
self.column = column
|
||||||
self.store = store
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._next_id = None
|
cur = db_conn.cursor()
|
||||||
|
self._next_id = self._load_next_id(cur)
|
||||||
|
cur.close()
|
||||||
|
|
||||||
|
def _load_next_id(self, txn):
|
||||||
|
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
|
||||||
|
val, = txn.fetchone()
|
||||||
|
return val + 1 if val else 1
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_next(self):
|
def get_next(self):
|
||||||
if self._next_id is None:
|
|
||||||
yield self.store.runInteraction(
|
|
||||||
"IdGenerator_%s" % (self.table,),
|
|
||||||
self.get_next_txn,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
i = self._next_id
|
i = self._next_id
|
||||||
self._next_id += 1
|
self._next_id += 1
|
||||||
defer.returnValue(i)
|
return i
|
||||||
|
|
||||||
def get_next_txn(self, txn):
|
|
||||||
with self._lock:
|
|
||||||
if self._next_id:
|
|
||||||
i = self._next_id
|
|
||||||
self._next_id += 1
|
|
||||||
return i
|
|
||||||
else:
|
|
||||||
txn.execute(
|
|
||||||
"SELECT MAX(%s) FROM %s" % (self.column, self.table,)
|
|
||||||
)
|
|
||||||
|
|
||||||
val, = txn.fetchone()
|
|
||||||
cur = val or 0
|
|
||||||
cur += 1
|
|
||||||
self._next_id = cur + 1
|
|
||||||
|
|
||||||
return cur
|
|
||||||
|
|
||||||
|
|
||||||
class StreamIdGenerator(object):
|
class StreamIdGenerator(object):
|
||||||
@ -69,7 +48,7 @@ class StreamIdGenerator(object):
|
|||||||
persistence of events can complete out of order.
|
persistence of events can complete out of order.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
with stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
def __init__(self, db_conn, table, column):
|
def __init__(self, db_conn, table, column):
|
||||||
@ -79,15 +58,21 @@ class StreamIdGenerator(object):
|
|||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
self._current_max = self._get_or_compute_current_max(cur)
|
self._current_max = self._load_current_max(cur)
|
||||||
cur.close()
|
cur.close()
|
||||||
|
|
||||||
self._unfinished_ids = deque()
|
self._unfinished_ids = deque()
|
||||||
|
|
||||||
def get_next(self, store):
|
def _load_current_max(self, txn):
|
||||||
|
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
val, = rows[0]
|
||||||
|
return int(val) if val else 1
|
||||||
|
|
||||||
|
def get_next(self):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with yield stream_id_gen.get_next as stream_id:
|
with stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@ -106,10 +91,10 @@ class StreamIdGenerator(object):
|
|||||||
|
|
||||||
return manager()
|
return manager()
|
||||||
|
|
||||||
def get_next_mult(self, store, n):
|
def get_next_mult(self, n):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
with stream_id_gen.get_next(n) as stream_ids:
|
||||||
# ... persist events ...
|
# ... persist events ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@ -139,13 +124,3 @@ class StreamIdGenerator(object):
|
|||||||
return self._unfinished_ids[0] - 1
|
return self._unfinished_ids[0] - 1
|
||||||
|
|
||||||
return self._current_max
|
return self._current_max
|
||||||
|
|
||||||
def _get_or_compute_current_max(self, txn):
|
|
||||||
with self._lock:
|
|
||||||
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
|
|
||||||
rows = txn.fetchall()
|
|
||||||
val, = rows[0]
|
|
||||||
|
|
||||||
self._current_max = int(val) if val else 1
|
|
||||||
|
|
||||||
return self._current_max
|
|
||||||
|
Loading…
Reference in New Issue
Block a user