refactor thread_raise safety to use a context manager

This commit is contained in:
Noah Levitt 2017-04-24 19:51:51 -07:00
parent 7706bab8b8
commit 0953e6972e
2 changed files with 87 additions and 57 deletions

View File

@ -98,29 +98,67 @@ def behavior_script(url, template_parameters=None):
return script return script
return None return None
import threading
_thread_exception_gates = {}
_thread_exception_gates_lock = threading.Lock()
def thread_accept_exceptions(): def thread_accept_exceptions():
import threading '''
thread = threading.current_thread() Returns a context manager whose purpose is best explained with a snippet:
if hasattr(thread, 'thread_raise_lock'):
lock = thread.thread_raise_lock
else:
lock = threading.Lock()
with lock:
thread.thread_raise_lock = lock
thread.thread_raise_ok = True
def thread_block_exceptions(): # === thread1 ===
import threading
thread = threading.current_thread() # If thread2 calls `thread_raise(thread1, ...)` while do_something() is
if hasattr(thread, 'thread_raise_lock'): # executing, nothing will happen (no exception will be raised in
with thread.thread_raise_lock: # thread1).
thread.thread_raise_ok = False do_something()
try:
with thread_accept_exceptions():
# Now we're in the "runtime environment" (pep340) of the
# context manager. If thread2 calls `thread_raise(thread1,
# ...)` while do_something_else() is running, the exception
# will be raised here.
do_something_else()
# Here again if thread2 calls `thread_raise`, nothing happens.
do_yet_another_thing()
except:
handle_exception()
The context manager is reentrant, i.e. you can do this:
with thread_accept_exceptions():
with thread_accept_exceptions():
blah()
# `thread_raise` will still work here
toot()
'''
class ThreadExceptionGate:
def __init__(self):
self.lock = threading.Lock()
self.ok_to_raise = 0
def __enter__(self):
with self.lock:
self.ok_to_raise += 1
def __exit__(self, exc_type, exc_value, traceback):
with self.lock:
self.ok_to_raise -= 1
assert self.ok_to_raise >= 0
with _thread_exception_gates_lock:
if not threading.current_thread().ident in _thread_exception_gates:
_thread_exception_gates[
threading.current_thread().ident] = ThreadExceptionGate()
return _thread_exception_gates[threading.current_thread().ident]
def thread_raise(thread, exctype): def thread_raise(thread, exctype):
''' '''
If `thread` has declared itself willing to accept exceptions by calling Raises the exception `exctype` in the thread `thread`, if it is willing to
`thread_accept_exceptions`, raises the exception `exctype` in the thread accept exceptions (see `thread_accept_exceptions`).
`thread`.
Adapted from http://tomerfiliba.com/recipes/Thread2/ which explains: Adapted from http://tomerfiliba.com/recipes/Thread2/ which explains:
"The exception will be raised only when executing python bytecode. If your "The exception will be raised only when executing python bytecode. If your
@ -128,8 +166,9 @@ def thread_raise(thread, exctype):
raised only when execution returns to the python code." raised only when execution returns to the python code."
Returns: Returns:
True if exception was raised, False if the thread is not accepting True if exception was raised, False if `thread` is not accepting
exceptions or another thread is holding `thread.thread_raise_lock` exceptions, or another thread is in the middle of raising an exception
in `thread`
Raises: Raises:
threading.ThreadError if `thread` is not running threading.ThreadError if `thread` is not running
@ -143,29 +182,30 @@ def thread_raise(thread, exctype):
'cannot raise %s, only exception types can be raised (not ' 'cannot raise %s, only exception types can be raised (not '
'instances)' % exctype) 'instances)' % exctype)
if not hasattr(thread, 'thread_raise_lock'):
logging.warn(
'thread is not accepting exceptions (no member variable '
'"thread_raise_lock"), not raising %s in thread %s',
exctype, thread)
return False
got_lock = thread.thread_raise_lock.acquire(timeout=0.5)
if not got_lock:
logging.warn(
'could not get acquire "thread_raise_lock", not raising %s in '
'thread %s', exctype, thread)
return False
try:
if not thread.thread_raise_ok:
logging.warn(
'thread is not accepting exceptions (thread_raise_ok is '
'%s), not raising %s in thread %s',
thread.thread_raise_ok, exctype, thread)
return False
if not thread.is_alive(): if not thread.is_alive():
raise threading.ThreadError('thread %s is not running' % thread) raise threading.ThreadError('thread %s is not running' % thread)
gate = _thread_exception_gates.get(thread.ident)
if not gate:
logging.warn(
'thread is not accepting exceptions (gate not initialized), '
'not raising %s in thread %s', exctype, thread)
return False
got_lock = gate.lock.acquire(blocking=False)
if not got_lock:
logging.warn(
'could not get acquire thread exception gate lock, not '
'raising %s in thread %s', exctype, thread)
return False
try:
if not gate.ok_to_raise:
logging.warn(
'thread is not accepting exceptions (gate.ok_to_raise is '
'%s), not raising %s in thread %s',
gate.ok_to_raise, exctype, thread)
return False
logging.info('raising %s in thread %s', exctype, thread) logging.info('raising %s in thread %s', exctype, thread)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc( res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread.ident), ctypes.py_object(exctype)) ctypes.c_long(thread.ident), ctypes.py_object(exctype))
@ -179,7 +219,7 @@ def thread_raise(thread, exctype):
raise SystemError('PyThreadState_SetAsyncExc failed') raise SystemError('PyThreadState_SetAsyncExc failed')
return True return True
finally: finally:
thread.thread_raise_lock.release() gate.lock.release()
def sleep(duration): def sleep(duration):
''' '''
@ -228,4 +268,5 @@ from brozzler.cli import suggest_default_chrome_exe
__all__ = ['Page', 'Site', 'BrozzlerWorker', 'is_permitted_by_robots', __all__ = ['Page', 'Site', 'BrozzlerWorker', 'is_permitted_by_robots',
'RethinkDbFrontier', 'Browser', 'BrowserPool', 'BrowsingException', 'RethinkDbFrontier', 'Browser', 'BrowserPool', 'BrowsingException',
'new_job', 'new_site', 'Job', 'new_job_file', 'InvalidJobConf'] 'new_job', 'new_site', 'Job', 'new_job_file', 'InvalidJobConf',
'sleep', 'thread_accept_exceptions', 'thread_raise']

View File

@ -197,18 +197,19 @@ def test_thread_raise():
thread_preamble_done = threading.Event() thread_preamble_done = threading.Event()
thread_caught_exception = None thread_caught_exception = None
def thread_target(accept_exceptions=False, block_exceptions=False): def thread_target(accept_exceptions=False):
if accept_exceptions:
brozzler.thread_accept_exceptions()
if block_exceptions:
brozzler.thread_block_exceptions()
thread_preamble_done.set()
try: try:
logging.info('waiting') if accept_exceptions:
with brozzler.thread_accept_exceptions():
thread_preamble_done.set()
logging.info('waiting (accepting exceptions)')
let_thread_finish.wait()
else:
thread_preamble_done.set()
logging.info('waiting (not accepting exceptions)')
let_thread_finish.wait() let_thread_finish.wait()
except Exception as e: except Exception as e:
logging.info('caught exception %s', e) logging.info('caught exception %s', repr(e))
nonlocal thread_caught_exception nonlocal thread_caught_exception
thread_caught_exception = e thread_caught_exception = e
finally: finally:
@ -219,7 +220,7 @@ def test_thread_raise():
# test that thread_raise does not raise exception in a thread that has not # test that thread_raise does not raise exception in a thread that has not
# called thread_accept_exceptions # called thread_accept_exceptions
thread_caught_exception = None thread_caught_exception = None
th = threading.Thread(target=lambda: thread_target()) th = threading.Thread(target=lambda: thread_target(accept_exceptions=False))
th.start() th.start()
thread_preamble_done.wait() thread_preamble_done.wait()
with pytest.raises(TypeError): with pytest.raises(TypeError):
@ -244,15 +245,3 @@ def test_thread_raise():
with pytest.raises(threading.ThreadError): # thread is not running with pytest.raises(threading.ThreadError): # thread is not running
brozzler.thread_raise(th, Exception) brozzler.thread_raise(th, Exception)
# test that thread_raise does not raise exception in a thread that has
# called thread_block_exceptions
thread_caught_exception = None
th = threading.Thread(target=lambda: thread_target(block_exceptions=True))
th.start()
thread_preamble_done.wait()
assert brozzler.thread_raise(th, Exception) is False
let_thread_finish.set()
th.join()
assert thread_caught_exception is None