Improve exception handling for concurrent execution (#12109)

* fix incorrect unwrapFirstError import

this was being imported from the wrong place

* Refactor `concurrently_execute` to use `yieldable_gather_results`

* Improve exception handling in `yieldable_gather_results`

Try to avoid swallowing so many stack traces.

* mark unwrapFirstError deprecated

* changelog
This commit is contained in:
Richard van der Hoff 2022-03-01 09:34:30 +00:00 committed by GitHub
parent 952efd0bca
commit 9d11fee8f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 27 deletions

View file

@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from synapse.logging.context import (
SENTINEL_CONTEXT,
@ -21,7 +24,11 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
timeout_deferred,
)
from tests.unittest import TestCase
@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
class _TestException(Exception):
pass
class ConcurrentlyExecuteTest(TestCase):
def test_limits_runners(self):
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []
async def callback(v):
# when we first enter, bump the start count
nonlocal started
started += 1
# record the fact we got an item
processed.append(v)
# wait for the goahead before returning
d2 = Deferred()
waiters.append(d2)
await d2
# set it going
d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
# check we got exactly 3 processes
self.assertEqual(started, 3)
self.assertEqual(len(waiters), 3)
# let one finish
waiters.pop().callback(0)
# ... which should start another
self.assertEqual(started, 4)
self.assertEqual(len(waiters), 3)
# we still shouldn't be done
self.assertNoResult(d2)
# finish the job
while waiters:
waiters.pop().callback(0)
# check everything got done
self.assertEqual(started, 5)
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)
def test_preserves_stacktraces(self):
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
d1 = Deferred()
async def callback(v):
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")
async def caller():
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
tb = traceback.extract_tb(e.__traceback__)
# we expect to see "caller", "concurrently_execute" and "callback".
self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute")
self.assertEqual(tb[-1].name, "callback")
else:
self.fail("No exception thrown")
d2 = ensureDeferred(caller())
d1.callback(0)
self.successResultOf(d2)
def test_preserves_stacktraces_on_preformed_failure(self):
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
d1 = Deferred()
f = Failure(_TestException("bah"))
async def callback(v):
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)
async def caller():
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
tb = traceback.extract_tb(e.__traceback__)
# we expect to see "caller", "concurrently_execute", "callback",
# and some magic from inside ensureDeferred that happens when .fail
# is called.
self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute")
self.assertEqual(tb[-2].name, "callback")
else:
self.fail("No exception thrown")
d2 = ensureDeferred(caller())
d1.callback(0)
self.successResultOf(d2)