SYN-377: Make sure that the StreamIdGenerator.get_next.__exit__ is called from the main thread after the transaction completes, not from database thread before the transaction completes.

This commit is contained in:
Mark Haines 2015-05-12 11:20:40 +01:00
parent d244fa9741
commit 5002056b16
2 changed files with 27 additions and 23 deletions

View File

@ -23,6 +23,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
from contextlib import contextmanager
import logging
@ -41,7 +42,15 @@ class EventsStore(SQLBaseStore):
self.min_token -= 1
stream_ordering = self.min_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
else:
@contextmanager
def stream_ordering_manager():
yield stream_ordering
try:
with stream_ordering_manager as stream_ordering:
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
@ -95,15 +104,6 @@ class EventsStore(SQLBaseStore):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
return self._persist_event_txn(
txn, event, context, backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:

View File

@ -78,14 +78,18 @@ class StreamIdGenerator(object):
self._current_max = None
self._unfinished_ids = deque()
def get_next_txn(self, txn):
@defer.inlineCallbacks
def get_next(self, store):
"""
Usage:
with stream_id_gen.get_next_txn(txn) as stream_id:
with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
if not self._current_max:
self._get_or_compute_current_max(txn)
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
self._current_max += 1
@ -101,7 +105,7 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
return manager()
defer.returnValue(manager())
@defer.inlineCallbacks
def get_max_token(self, store):