Performance and stability improvements.

This commit is contained in:
Aaron Heise 2023-02-12 22:29:16 -06:00
parent 7d711cde6d
commit 5e755acad4
6 changed files with 303 additions and 113 deletions

View File

@ -66,10 +66,10 @@ rnsh a5f72aefc2cb3cdba648f73f77c4e887
Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h
rnsh --version
@ -79,13 +79,20 @@ Options:
-s NAME --service NAME Listen on/connect to specific service name if not default
-p --print-identity Print identity information and exit
-l --listen Listen (server) mode
-b --no-announce Do not announce service
-b --announce PERIOD Announce on startup and every PERIOD seconds
Specify 0 for PERIOD to announce on startup only.
-a HASH --allowed HASH Specify identities allowed to connect
-n --no-auth Disable authentication
-N --no-id Disable identify on connect
-m --mirror Client returns with code of remote process
-w TIME --timeout TIME Specify client connect and request timeout in seconds
-v --verbose Increase verbosity
DEFAULT LEVEL
CRITICAL
Initiator -> ERROR
WARNING
Listener -> INFO
DEBUG
-q --quiet Increase quietness
--version Show version
-h --help Show this help

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "rnsh"
version = "0.0.2"
version = "0.0.3"
description = "Shell over Reticulum"
authors = ["acehoss <acehoss@acehoss.net>"]
license = "MIT"

View File

@ -22,6 +22,7 @@
import asyncio
import contextlib
import copy
import errno
import fcntl
import functools
@ -77,6 +78,7 @@ def tty_read(fd: int) -> bytes:
if fd_is_closed(fd):
return None
try:
run = True
result = bytearray()
while run and not fd_is_closed(fd):
@ -85,7 +87,7 @@ def tty_read(fd: int) -> bytes:
break
for f in ready:
try:
data = os.read(f, 512)
data = os.read(f, 4096)
except OSError as e:
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
raise
@ -95,6 +97,8 @@ def tty_read(fd: int) -> bytes:
if data is not None and len(data) > 0:
result.extend(data)
return result
except Exception as ex:
module_logger.error("tty_read error: {ex}")
def fd_is_closed(fd: int) -> bool:
@ -169,7 +173,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
ATTR_IDX_LFLAG = 4
ATTR_IDX_CC = 5
def __init__(self, fd: int):
def __init__(self, fd: int, suppress_logs=False):
"""
Saves termios attributes for a tty for later restoration.
@ -183,29 +187,37 @@ class TTYRestorer(contextlib.AbstractContextManager):
:param fd: file descriptor of tty
"""
self._log = module_logger.getChild(self.__class__.__name__)
self._fd = fd
self._tattr = None
with contextlib.suppress(termios.error):
termios.tcgetattr(self._fd)
self._suppress_logs = suppress_logs
self._tattr = self.current_attr()
if not self._tattr and not self._suppress_logs:
self._log.warning(f"Could not get attrs for fd {fd}")
def raw(self):
"""
Set raw mode on tty
"""
if not self._fd:
if self._fd is None:
return
with contextlib.suppress(termios.error):
tty.setraw(self._fd, termios.TCSANOW)
tty.setraw(self._fd, termios.TCSADRAIN)
def original_attr(self) -> [any]:
return copy.deepcopy(self._tattr)
def current_attr(self) -> [any]:
"""
Get the current termios attributes for the wrapped fd.
:return: attribute array
"""
if not self._fd:
if self._fd is None:
return None
return termios.tcgetattr(self._fd)
with contextlib.suppress(termios.error):
return copy.deepcopy(termios.tcgetattr(self._fd))
return None
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
"""
@ -213,7 +225,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
:param attr: attribute list to set
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
"""
if not attr or not self._fd:
if not attr or self._fd is None:
return
with contextlib.suppress(termios.error):
@ -228,7 +240,64 @@ class TTYRestorer(contextlib.AbstractContextManager):
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.restore()
return __exc_type is not None and issubclass(__exc_type, termios.error)
return False #__exc_type is not None and issubclass(__exc_type, termios.error)
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
if not loop:
loop = asyncio.get_running_loop()
#TODO: this is hacky
async def wait():
while not evt.is_set():
await asyncio.sleep(0.1)
return True
return loop.create_task(wait())
class AggregateException(Exception):
def __init__(self, inner_exceptions: [Exception]):
super().__init__()
self.inner_exceptions = inner_exceptions
def __str__(self):
return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions))
async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any):
tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts))
# try:
finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks),
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED)
if len(unfinished) > 0:
for task in unfinished:
task.cancel()
await asyncio.wait(unfinished)
# exceptions = []
#
# for f in finished:
# ex = f.exception()
# if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError):
# exceptions.append(ex)
#
# if len(exceptions) > 0:
# raise AggregateException(exceptions)
return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None)
# finally:
# unfinished = []
# for task in map(lambda t: t[1], tasks):
# if task.done():
# if not task.cancelled():
# task.exception()
# else:
# task.cancel()
# unfinished.append(task)
# if len(unfinished) > 0:
# await asyncio.wait(unfinished)
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
@ -238,9 +307,7 @@ async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
:param timeout: maximum number of seconds to wait.
:return: True if event was set, False if timeout expired
"""
# suppress TimeoutError because we'll return False in case of timeout
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(evt.wait(), timeout)
await event_wait_any([evt], timeout=timeout)
return evt.is_set()
@ -408,6 +475,8 @@ class CallbackSubprocess:
# self.log.debug("poll")
try:
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
if self._return_code is not None:
self._return_code = self._return_code & 0xff
if self._return_code is not None and not process_exists(self._pid):
self._log.debug(f"polled return code {self._return_code}")
self._terminated_cb(self._return_code)

View File

@ -34,6 +34,7 @@ import sys
import termios
import threading
import time
import tty
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
@ -58,20 +59,20 @@ _allow_all = False
_allowed_identity_hashes = []
_cmd: [str] = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event | None = None
_finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0):
await process.event_wait(_finished, timeout=timeout)
return await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, frame):
global _finished
log = _get_logger("_sigint_handler")
log.debug("SIGINT")
log.debug(signal.Signals(sig).name)
if _finished is not None:
_finished.set()
else:
@ -107,8 +108,6 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
exit(0)
# hack_goes_here
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
allowed=None, disable_auth=None, announce_period=900):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
@ -116,8 +115,8 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_cmd = command
targetloglevel = RNS.LOG_INFO + verbosity - quietness
rnslogging.RnsHandler.set_global_log_level(__logging.INFO)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
_prepare_identity(identitypath)
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
@ -151,6 +150,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes
)
@ -158,10 +158,12 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_ALL,
)
await _check_finished()
if await _check_finished():
return
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
@ -171,23 +173,23 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
last = time.time()
try:
while True:
while not await _check_finished(1.0):
if announce_period and 0 < announce_period < time.time() - last:
last = time.time()
_destination.announce()
await _check_finished(1.0)
except KeyboardInterrupt:
finally:
log.warning("Shutting down")
for link in list(_destination.links):
with exception.permit(SystemExit):
with exception.permit(SystemExit, KeyboardInterrupt):
proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running:
proc.process.terminate()
await asyncio.sleep(1)
await asyncio.sleep(0)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status != RNS.Link.CLOSED:
link.teardown()
await asyncio.sleep(0)
_PROTOCOL_MAGIC = 0xdeadbeef
@ -334,6 +336,7 @@ class Session:
@staticmethod
def default_request(stdin_fd: int | None) -> [any]:
global _tr
global _PROTOCOL_VERSION_0
request: list[any] = [
_PROTOCOL_VERSION_0, # 0 Protocol Version
@ -348,7 +351,7 @@ class Session:
if stdin_fd is not None:
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
@ -372,6 +375,7 @@ class Session:
if term_state != self._term_state:
self._term_state = term_state
self._update_winsz()
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0])
if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
@ -539,7 +543,8 @@ def _initiator_identified(link, identity):
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer, _loop
log = _get_logger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
log.debug(
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id)
@ -553,7 +558,8 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
if not remote_version == _PROTOCOL_VERSION_0:
response = Session.default_response()
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(
"Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
@ -570,7 +576,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
loop=_loop)
# leave significant headroom for metadata and encoding
result = session.process_request(data, link.MDU * 4 // 3)
result = session.process_request(data, link.MDU * 1 // 2)
return result
# return ProcessState.default_response()
except Exception as e:
@ -589,7 +595,8 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
timeout += time.time()
while (timeout is None or time.time() < timeout) and not until():
await _check_finished(0.01)
if await _check_finished(0.001):
raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout:
return False
else:
@ -638,8 +645,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if _reticulum is None:
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
rnslogging.RnsHandler.set_global_log_level(__logging.ERROR)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
if _identity is None:
_prepare_identity(identitypath)
@ -661,6 +668,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
)
if _link is None or _link.status == RNS.Link.PENDING:
log.debug("No link")
_link = RNS.Link(_destination)
_link.did_identify = False
@ -668,6 +676,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
log.debug("Have link")
if not noid and not _link.did_identify:
_link.identify(_identity)
_link.did_identify = True
@ -680,6 +689,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
log.debug("Sending request")
request_receipt = _link.request(
path="data",
data=request,
@ -687,6 +697,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
)
timeout += 0.5
log.debug("Waiting for delivery")
await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and
@ -732,15 +743,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
_tr.raw()
if stdout is not None:
# log.debug(f"stdout: {stdout}")
_tr.raw()
log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout)
sys.stdout.flush()
sys.stderr.flush()
log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}")
got_bytes = len(stdout) if stdout is not None else 0
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0:
_new_data.set()
@ -761,11 +771,6 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
data_buffer = bytearray()
def sigint_handler():
log.debug("KeyboardInterrupt")
data_buffer.extend("\x03".encode("utf-8"))
_new_data.set()
def sigwinch_handler():
# log.debug("WindowChanged")
if _new_data is not None:
@ -773,7 +778,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
def stdin():
data = process.tty_read(sys.stdin.fileno())
# log.debug(f"stdin {data}")
log.debug(f"stdin {data}")
if data is not None:
data_buffer.extend(data)
_new_data.set()
@ -782,10 +787,12 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
await _check_finished()
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
# leave a lot of overhead
mdu = 64
rtt = 5
first_loop = True
while True:
while not await _check_finished():
try:
log.debug("top of client loop")
stdin = data_buffer[:mdu]
@ -806,22 +813,27 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if first_loop:
first_loop = False
mdu = _link.MDU * 4 // 3
loop.remove_signal_handler(signal.SIGINT)
loop.add_signal_handler(signal.SIGINT, sigint_handler)
mdu = _link.MDU // 2
_new_data.set()
if _link:
rtt = _link.rtt
if return_code is not None:
log.debug(f"received return code {return_code}, exiting")
with exception.permit(SystemExit):
with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown()
return return_code
except asyncio.CancelledError:
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 0
except RemoteExecutionError as e:
print(e.msg)
return 255
await process.event_wait(_new_data, 5)
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
_T = TypeVar("_T")
@ -835,6 +847,11 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
return arr, []
def _loop_set_signal(sig, loop):
loop.remove_signal_handler(sig)
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
async def _rnsh_cli_main():
global _tr, _finished, _loop
import docopt
@ -842,16 +859,16 @@ async def _rnsh_cli_main():
_loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop)
_finished = asyncio.Event()
_loop.remove_signal_handler(signal.SIGINT)
_loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None))
_loop_set_signal(signal.SIGINT, _loop)
_loop_set_signal(signal.SIGTERM, _loop)
usage = '''
Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h
rnsh --version
@ -869,6 +886,12 @@ Options:
-m --mirror Client returns with code of remote process
-w TIME --timeout TIME Specify client connect and request timeout in seconds
-v --verbose Increase verbosity
DEFAULT LEVEL
CRITICAL
Initiator -> ERROR
WARNING
Listener -> INFO
DEBUG
-q --quiet Increase quietness
--version Show version
-h --help Show this help
@ -930,7 +953,6 @@ Options:
)
if args_destination is not None and args_service_name is not None:
try:
return_code = await _initiate(
configdir=args_config,
identitypath=args_identity,
@ -942,24 +964,29 @@ Options:
timeout=args_timeout,
)
return return_code if args_mirror else 0
finally:
_tr.restore()
else:
print("")
print(args)
print("")
def _noop():
pass
# RNS.exit = _noop
def rnsh_cli():
global _tr, _retry_timer
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
with retry.RetryThread() as _retry_timer:
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
sys.exit(return_code or 255)
# RNS.Reticulum.exit_handler()
# time.sleep(0.5)
sys.exit(return_code if return_code is not None else 255)
if __name__ == "__main__":

View File

@ -60,10 +60,29 @@ class RnsHandler(Handler):
return RNS.LOG_DEBUG
return RNS.LOG_DEBUG
def get_logging_loglevel(rnsloglevel: int) -> int:
if rnsloglevel == RNS.LOG_CRITICAL:
return logging.CRITICAL
if rnsloglevel == RNS.LOG_ERROR:
return logging.ERROR
if rnsloglevel == RNS.LOG_WARNING:
return logging.WARNING
if rnsloglevel == RNS.LOG_NOTICE:
return logging.INFO
if rnsloglevel == RNS.LOG_INFO:
return logging.INFO
if rnsloglevel >= RNS.LOG_VERBOSE:
return RNS.LOG_DEBUG
return RNS.LOG_DEBUG
@classmethod
def set_global_log_level(cls, log_level: int):
logging.getLogger().setLevel(log_level)
RNS.loglevel = cls.get_rns_loglevel(log_level)
def set_log_level_with_rns_level(cls, rns_log_level: int):
logging.getLogger().setLevel(RnsHandler.get_logging_loglevel(rns_log_level))
RNS.loglevel = rns_log_level
def set_log_level_with_logging_level(cls, logging_log_level: int):
logging.getLogger().setLevel(logging_log_level)
RNS.loglevel = cls.get_rns_loglevel(logging_log_level)
def emit(self, record):
"""
@ -107,13 +126,16 @@ _rns_log_orig = RNS.log
def _rns_log(msg, level=3, _override_destination=False):
if RNS.loglevel < level:
return
if not RNS.compact_log_fmt:
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
def _rns_log_inner():
nonlocal msg, level, _override_destination
try:
with process.TTYRestorer(sys.stdin.fileno()) as tr:
with process.TTYRestorer(sys.stdin.fileno(), suppress_logs=True) as tr:
attr = tr.current_attr()
if attr:
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
@ -123,12 +145,13 @@ def _rns_log(msg, level=3, _override_destination=False):
except ValueError:
_rns_log_orig(msg, level, _override_destination)
# TODO: figure out if forcing this to the main thread actually helps.
try:
if _loop and _loop.is_running():
_loop.call_soon_threadsafe(_rns_log_inner)
else:
_rns_log_inner()
except RuntimeError:
except:
_rns_log_inner()

View File

@ -9,6 +9,7 @@ import os
import threading
import types
import typing
import multiprocessing.pool
logging.getLogger().setLevel(logging.DEBUG)
@ -26,6 +27,7 @@ class State(contextlib.AbstractContextManager):
loop=self.loop,
stdout_callback=self._stdout_cb,
terminated_callback=self._terminated_cb)
def _stdout_cb(self, data):
with self._lock:
self._stdout.extend(data)
@ -74,3 +76,65 @@ async def test_echo():
decoded = data.decode("utf-8")
assert decoded == message.replace("\n", "\r\n") * 2
assert not state.process.running
@pytest.mark.skip_ci
@pytest.mark.asyncio
async def test_echo_live():
"""
Check for immediate echo
"""
loop = asyncio.get_running_loop()
with State(argv=["/bin/cat"],
loop=loop) as state:
state.start()
assert state.process is not None
assert state.process.running
message = "t"
state.process.write(message.encode("utf-8"))
await asyncio.sleep(0.1)
data = state.read()
state.process.write(rnsh.process.CTRL_C)
await asyncio.sleep(0.1)
assert len(data) > 0
decoded = data.decode("utf-8")
assert decoded == message
assert not state.process.running
@pytest.mark.asyncio
async def test_event_wait_any():
delay = 0.1
with multiprocessing.pool.ThreadPool() as pool:
loop = asyncio.get_running_loop()
evt1 = asyncio.Event()
evt2 = asyncio.Event()
def assert_between(min, max, val):
assert min <= val <= max
# test 1: both timeout
ts = time.time()
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay*2)
assert_between(delay*2, delay*2.1, time.time() - ts)
assert finished is None
assert not evt1.is_set()
assert not evt2.is_set()
#test 2: evt1 set, evt2 not set
hits = 0
def test2_bg():
nonlocal hits
hits += 1
time.sleep(delay)
evt1.set()
ts = time.time()
pool.apply_async(test2_bg)
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay * 2)
assert_between(delay * 0.5, delay * 1.5, time.time() - ts)
assert hits == 1
assert evt1.is_set()
assert not evt2.is_set()
assert finished == evt1