Fix deadlock

Erik Johnston 2015-05-15 10:54:04 +01:00
parent 1d566edb81
commit a2c4f3f150
7 changed files with 118 additions and 80 deletions

View File

@@ -222,7 +222,7 @@ class FederationClient(FederationBase):
                     for p in transaction_data["pdus"]
                 ]
 
-                if pdu_list:
+                if pdu_list and pdu_list[0]:
                     pdu = pdu_list[0]
 
                     # Check signatures are correct.
@@ -255,7 +255,7 @@ class FederationClient(FederationBase):
                 )
                 continue
 
-        if self._get_pdu_cache is not None:
+        if self._get_pdu_cache is not None and pdu:
             self._get_pdu_cache[event_id] = pdu
 
         defer.returnValue(pdu)
@@ -475,6 +475,9 @@ class FederationClient(FederationBase):
             limit (int): Maximum number of events to return.
             min_depth (int): Minimum depth of events tor return.
         """
+        logger.debug("get_missing_events: latest_events: %r", latest_events)
+        logger.debug("get_missing_events: earliest_events_ids: %r", earliest_events_ids)
+
         try:
             content = yield self.transport_layer.get_missing_events(
                 destination=destination,
@@ -485,6 +488,8 @@ class FederationClient(FederationBase):
                 min_depth=min_depth,
             )
 
+            logger.debug("get_missing_events: Got content: %r", content)
+
             events = [
                 self.event_from_pdu_json(e)
                 for e in content.get("events", [])
@@ -494,6 +499,8 @@ class FederationClient(FederationBase):
                 destination, events, outlier=False
             )
 
+            logger.debug("get_missing_events: signed_events: %r", signed_events)
+
            have_gotten_all_from_destination = True
         except HttpResponseException as e:
             if not e.code == 400:
@@ -518,6 +525,8 @@ class FederationClient(FederationBase):
         # Are we missing any?
 
         seen_events = set(earliest_events_ids)
+
+        logger.debug("get_missing_events: signed_events2: %r", signed_events)
         seen_events.update(e.event_id for e in signed_events)
 
         missing_events = {}
@@ -561,7 +570,7 @@ class FederationClient(FederationBase):
         res = yield defer.DeferredList(deferreds, consumeErrors=True)
 
         for (result, val), (e_id, _) in zip(res, ordered_missing):
-            if result:
+            if result and val:
                 signed_events.append(val)
             else:
                 failed_to_fetch.add(e_id)
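Note on the `if result and val:` guard above: defer.DeferredList with consumeErrors=True fires with a list of (success, value) pairs, and a lookup can succeed with a value of None (get_pdu now returns None when no server could supply the event). A minimal sketch of the pattern, independent of the Synapse code (collect() and the names in it are illustrative):

    from twisted.internet import defer

    def collect(deferreds, event_ids):
        # DeferredList fires with (success, value) pairs, in input order.
        d = defer.DeferredList(deferreds, consumeErrors=True)

        def on_done(results):
            fetched, failed = [], set()
            for (success, value), e_id in zip(results, event_ids):
                if success and value:  # success alone is not enough: value may be None
                    fetched.append(value)
                else:
                    failed.add(e_id)
            return fetched, failed

        return d.addCallback(on_done)

Checking only the success flag would let None slip into signed_events and fail later when event fields are read.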

View File

@@ -415,6 +415,8 @@ class FederationServer(FederationBase):
             pdu.internal_metadata.outlier = True
         elif min_depth and pdu.depth > min_depth:
             if get_missing and prevs - seen:
+                logger.debug("We're missing: %r", prevs-seen)
+
                 latest = yield self.store.get_latest_event_ids_in_room(
                     pdu.room_id
                 )

View File

@@ -303,18 +303,27 @@ class MessageHandler(BaseHandler):
             if event.membership != Membership.JOIN:
                 return
             try:
-                (messages, token), current_state = yield defer.gatherResults(
-                    [
-                        self.store.get_recent_events_for_room(
+                # (messages, token), current_state = yield defer.gatherResults(
+                #     [
+                #         self.store.get_recent_events_for_room(
+                #             event.room_id,
+                #             limit=limit,
+                #             end_token=now_token.room_key,
+                #         ),
+                #         self.state_handler.get_current_state(
+                #             event.room_id
+                #         ),
+                #     ]
+                # ).addErrback(unwrapFirstError)
+                messages, token = yield self.store.get_recent_events_for_room(
                     event.room_id,
                     limit=limit,
                     end_token=now_token.room_key,
-                    ),
-                    self.state_handler.get_current_state(
+                )
+
+                current_state = yield self.state_handler.get_current_state(
                     event.room_id
-                    ),
-                ]
-                ).addErrback(unwrapFirstError)
+                )
 
                 start_token = now_token.copy_and_replace("room_key", token[0])
                 end_token = now_token.copy_and_replace("room_key", token[1])
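The hunk above swaps defer.gatherResults, which runs both storage calls concurrently, for two sequential yields. A hedged sketch of the difference (room_snapshot is an illustrative wrapper, not from the codebase; the store and state_handler methods are the ones named in the hunk):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def room_snapshot(store, state_handler, room_id, limit, end_token):
        # Concurrent form: both calls are outstanding at once, so two
        # entry points can contend for the same event-fetch machinery:
        #
        #     (messages, token), current_state = yield defer.gatherResults([...])
        #
        # Sequential form: the first call completes, releasing any shared
        # locks, before the second starts.
        messages, token = yield store.get_recent_events_for_room(
            room_id, limit=limit, end_token=end_token,
        )
        current_state = yield state_handler.get_current_state(room_id)
        defer.returnValue((messages, token, current_state))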

View File

@@ -301,10 +301,12 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
-        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_lock = threading.Lock()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+        self._pending_ds = []
+
         self.database_engine = hs.database_engine
 
         self._stream_id_gen = StreamIdGenerator()
@@ -344,8 +346,7 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    @contextlib.contextmanager
-    def _new_transaction(self, conn, desc, after_callbacks):
+    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
         start = time.time() * 1000
         txn_id = self._TXN_ID
@@ -366,6 +367,9 @@ class SQLBaseStore(object):
                 txn = LoggingTransaction(
                     txn, name, self.database_engine, after_callbacks
                 )
+                r = func(txn, *args, **kwargs)
+                conn.commit()
+                return r
             except self.database_engine.module.OperationalError as e:
                 # This can happen if the database disappears mid
                 # transaction.
@@ -398,17 +402,6 @@ class SQLBaseStore(object):
                         )
                         continue
                 raise
-
-            try:
-                yield txn
-                conn.commit()
-                return
-            except:
-                try:
-                    conn.rollback()
-                except:
-                    pass
-                raise
         except Exception as e:
             logger.debug("[TXN FAIL] {%s} %s", name, e)
             raise
@@ -440,8 +433,9 @@ class SQLBaseStore(object):
                     conn.reconnect()
 
                 current_context.copy_to(context)
-                with self._new_transaction(conn, desc, after_callbacks) as txn:
-                    return func(txn, *args, **kwargs)
+                return self._new_transaction(
+                    conn, desc, after_callbacks, func, *args, **kwargs
+                )
 
         result = yield preserve_context_over_fn(
             self._db_pool.runWithConnection,
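The _new_transaction change above is the heart of the fix in this file: instead of a @contextlib.contextmanager that yields the transaction and commits afterwards, the runner now calls func itself and commits before returning, so the retry-on-OperationalError loop and the commit live in one place and the work function is actually re-run on a retry. A simplified standalone sketch of that shape (run_transaction and the retry limit are illustrative; the real method also does logging, timing, and after-callbacks):

    import sqlite3

    def run_transaction(conn, func, *args, **kwargs):
        # The runner owns the cursor, the commit, and the retry loop.
        # Because it calls func directly, a retry is just another call;
        # a generator-based context manager cannot re-enter its with-body.
        attempts = 0
        while True:
            try:
                txn = conn.cursor()
                result = func(txn, *args, **kwargs)
                conn.commit()
                return result
            except sqlite3.OperationalError:
                conn.rollback()
                attempts += 1
                if attempts >= 5:
                    raise

    # Usage:
    #     conn = sqlite3.connect(":memory:")
    #     run_transaction(conn, lambda txn: txn.execute("CREATE TABLE t (x)"))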

View File

@@ -420,12 +420,14 @@ class EventsStore(SQLBaseStore):
         ])
 
         if not txn:
+            logger.debug("enqueue before")
             missing_events = yield self._enqueue_events(
                 missing_events_ids,
                 check_redacted=check_redacted,
                 get_prev_content=get_prev_content,
                 allow_rejected=allow_rejected,
             )
+            logger.debug("enqueue after")
         else:
             missing_events = self._fetch_events_txn(
                 txn,
@@ -498,41 +500,39 @@ class EventsStore(SQLBaseStore):
             allow_rejected=allow_rejected,
         ))
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True,
-                        get_prev_content=False, allow_rejected=False):
-        if not events:
-            defer.returnValue({})
-
-        def do_fetch(conn):
+    def _do_fetch(self, conn):
         event_list = []
-        while True:
         try:
+            while True:
+                logger.debug("do_fetch getting lock")
                 with self._event_fetch_lock:
-                    i = 0
-                    while not self._event_fetch_list:
-                        self._event_fetch_ongoing -= 1
-                        return
-
+                    logger.debug("do_fetch go lock: %r", self._event_fetch_list)
                     event_list = self._event_fetch_list
                     self._event_fetch_list = []
+
+                    if not event_list:
+                        self._event_fetch_ongoing -= 1
+                        return
 
                 event_id_lists = zip(*event_list)[0]
                 event_ids = [
                     item for sublist in event_id_lists for item in sublist
                 ]
 
-                with self._new_transaction(conn, "do_fetch", []) as txn:
-                    rows = self._fetch_event_rows(txn, event_ids)
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], self._fetch_event_rows, event_ids
+                )
 
                 row_dict = {
                     r["event_id"]: r
                     for r in rows
                 }
 
-                for ids, d in event_list:
-                    def fire():
+                logger.debug("do_fetch got events: %r", row_dict.keys())
+
+                def fire(evs):
+                    for ids, d in evs:
                         if not d.called:
+                            try:
                                 d.callback(
                                     [
                                         row_dict[i]
@@ -540,32 +540,51 @@ class EventsStore(SQLBaseStore):
                                         if i in row_dict
                                     ]
                                 )
-                    reactor.callFromThread(fire)
+                            except:
+                                logger.exception("Failed to callback")
+
+                reactor.callFromThread(fire, event_list)
         except Exception as e:
             logger.exception("do_fetch")
-            for _, d in event_list:
-                if not d.called:
-                    reactor.callFromThread(d.errback, e)
-            with self._event_fetch_lock:
-                self._event_fetch_ongoing -= 1
-                return
+
+            def fire(evs):
+                for _, d in evs:
+                    if not d.called:
+                        d.errback(e)
+
+            if event_list:
+                reactor.callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True,
+                        get_prev_content=False, allow_rejected=False):
+        if not events:
+            defer.returnValue({})
 
         events_d = defer.Deferred()
+        try:
+            logger.debug("enqueueueueue getting lock")
             with self._event_fetch_lock:
+                logger.debug("enqueue go lock")
                 self._event_fetch_list.append(
                     (events, events_d)
                 )
 
-                self._event_fetch_lock.notify_all()
-
+                # if self._event_fetch_ongoing < 5:
                 self._event_fetch_ongoing += 1
                 self.runWithConnection(
-                    do_fetch
+                    self._do_fetch
                 )
+        except Exception as e:
+            if not events_d.called:
+                events_d.errback(e)
 
+        logger.debug("events_d before")
+        try:
             rows = yield events_d
+        except:
+            logger.exception("events_d")
+        logger.debug("events_d after")
 
         res = yield defer.gatherResults(
             [
@@ -580,6 +599,7 @@ class EventsStore(SQLBaseStore):
             ],
             consumeErrors=True
         )
+        logger.debug("gatherResults after")
 
         defer.returnValue({
             e.event_id: e
@@ -639,7 +659,8 @@ class EventsStore(SQLBaseStore):
                     rejected_reason=row["rejects"],
                 )
                 for row in rows
-            ]
+            ],
+            consumeErrors=True,
         )
 
         defer.returnValue({
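The restructured _do_fetch/_enqueue_events pair is where the deadlock appears to have lived: the old code parked fetch requests behind a threading.Condition and a capped pool of fetchers, a combination that could leave a queued request with no fetcher left to service it while its caller blocked on the deferred. The new shape uses a plain Lock, spawns a fetcher per enqueue (the `if self._event_fetch_ongoing < 5` cap is commented out), and has each fetcher decide to exit only while holding the lock, after finding the queue empty. A standalone sketch of that queue discipline using plain threads (FetchQueue and everything in it is illustrative, not Synapse code):

    import threading

    class FetchQueue(object):
        def __init__(self):
            self._lock = threading.Lock()
            self._pending = []   # [(request, results_list)]
            self._ongoing = 0

        def _worker(self, fetch):
            while True:
                with self._lock:
                    batch, self._pending = self._pending, []
                    if not batch:
                        # The exit decision is taken under the lock, so a
                        # request can never be queued after the last worker
                        # has decided to quit but before it has quit.
                        self._ongoing -= 1
                        return
                for request, results in batch:
                    results.append(fetch(request))

        def enqueue(self, request, fetch):
            results = []
            with self._lock:
                self._pending.append((request, results))
                # One worker per enqueue; idle workers drain the queue
                # and exit, so over-provisioning is short-lived.
                self._ongoing += 1
                threading.Thread(target=self._worker, args=(fetch,)).start()
            return results

The real code hands results back to the reactor thread with reactor.callFromThread rather than appending to a plain list, but the locking structure is the same.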

View File

@@ -357,10 +357,12 @@ class StreamStore(SQLBaseStore):
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+        logger.debug("stream before")
         events = yield self._get_events(
             [r["event_id"] for r in rows],
             get_prev_content=True
         )
+        logger.debug("stream after")
 
         self._set_before_and_after(events, rows)

View File

@@ -33,8 +33,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def setUp(self):
         self.db_pool = Mock(spec=["runInteraction"])
         self.mock_txn = Mock()
-        self.mock_conn = Mock(spec_set=["cursor"])
+        self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
         self.mock_conn.cursor.return_value = self.mock_txn
+        self.mock_conn.rollback.return_value = None
 
         # Our fake runInteraction just runs synchronously inline
         def runInteraction(func, *args, **kwargs):