diff --git a/rnsh/retry.py b/rnsh/retry.py index cf91a81..fa03415 100644 --- a/rnsh/retry.py +++ b/rnsh/retry.py @@ -64,10 +64,10 @@ class RetryStatus: self.completed = True self.timeout_callback(self.tag, self.tries) - def retry(self): + def retry(self) -> any: self.tries = self.tries + 1 self.try_time = time.time() - self.retry_callback(self.tag, self.tries) + return self.retry_callback(self.tag, self.tries) class RetryThread(AbstractContextManager): @@ -123,7 +123,9 @@ class RetryThread(AbstractContextManager): prune.append(retry) elif retry.ready: self._log.debug(f"retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}") - retry.retry() + should_continue = retry.retry() + if not should_continue: + self.complete(retry.tag) except Exception as e: self._log.error(f"error processing retry id {retry.tag}: {e}") prune.append(retry) @@ -145,10 +147,13 @@ class RetryThread(AbstractContextManager): return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any], - timeout_callback: Callable[[any, int], None], tag: int = None) -> any: + timeout_callback: Callable[[any, int], None]) -> any: self._log.debug(f"running first try") - tag = try_callback(tag, 1) + tag = try_callback(None, 1) self._log.debug(f"first try got id {tag}") + if not tag: + self._log.debug(f"callback returned None/False/0, considering complete.") + return None with self._lock: if tag is None: tag = self._get_next_tag() @@ -160,6 +165,7 @@ class RetryThread(AbstractContextManager): retry_callback=try_callback, timeout_callback=timeout_callback)) self._log.debug(f"added retry timer for {tag}") + return tag def complete(self, tag: any): assert tag is not None diff --git a/tests/test_retry.py b/tests/test_retry.py index fd9fe9c..24f5485 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -17,6 +17,7 @@ class State(AbstractContextManager): self.callbacks = 0 self.timed_out = False self.tag = str(uuid.uuid4()) + self.results = [self.tag, self.tag, self.tag] self.got_tag = None assert self.retry_thread.is_alive() @@ -30,7 +31,7 @@ class State(AbstractContextManager): self.tries = tries self.got_tag = tag self.callbacks += 1 - return self.tag + return self.results[tries - 1] def timeout(self, tag, tries): self.tries = tries @@ -47,11 +48,11 @@ class State(AbstractContextManager): def test_retry_timeout(): with State(0.1) as state: - state.retry_thread.begin(try_limit=3, - wait_delay=state.delay, - try_callback=state.retry, - timeout_callback=state.timeout) - + return_tag = state.retry_thread.begin(try_limit=3, + wait_delay=state.delay, + try_callback=state.retry, + timeout_callback=state.timeout) + assert return_tag == state.tag assert state.tries == 1 assert state.callbacks == 1 assert state.got_tag is None @@ -81,15 +82,62 @@ def test_retry_timeout(): assert state.tries == 3 -def test_retry_complete(): +def test_retry_immediate_complete(): with State(0.01) as state: - state.retry_thread.begin(try_limit=3, - wait_delay=state.delay, - try_callback=state.retry, - timeout_callback=state.timeout) - + state.results[0] = False + return_tag = state.retry_thread.begin(try_limit=3, + wait_delay=state.delay, + try_callback=state.retry, + timeout_callback=state.timeout) + assert not return_tag + assert state.callbacks == 1 + assert not state.got_tag + assert not state.timed_out + time.sleep(state.delay * 3) assert state.tries == 1 assert state.callbacks == 1 + assert not state.got_tag + assert not state.timed_out + + +def test_retry_return_complete(): + with State(0.01) as state: + state.results[1] = False + return_tag = state.retry_thread.begin(try_limit=3, + wait_delay=state.delay, + try_callback=state.retry, + timeout_callback=state.timeout) + assert return_tag == state.tag + assert state.callbacks == 1 + assert state.got_tag is None + assert not state.timed_out + time.sleep(state.delay / 2.0) + time.sleep(state.delay) + assert state.tries == 2 + assert state.callbacks == 2 + assert state.got_tag == state.tag + assert not state.timed_out + + time.sleep(state.delay) + assert state.tries == 2 + assert state.callbacks == 2 + assert state.got_tag == state.tag + assert not state.timed_out + + # check no more callbacks + time.sleep(state.delay * 3.0) + assert state.callbacks == 2 + assert state.tries == 2 + + +def test_retry_set_complete(): + with State(0.01) as state: + return_tag = state.retry_thread.begin(try_limit=3, + wait_delay=state.delay, + try_callback=state.retry, + timeout_callback=state.timeout) + assert return_tag == state.tag + assert state.callbacks == 1 assert state.got_tag is None assert not state.timed_out time.sleep(state.delay / 2.0)