Performance and stability improvements.

This commit is contained in:
Aaron Heise 2023-02-12 22:29:16 -06:00
parent 7d711cde6d
commit 5e755acad4
6 changed files with 303 additions and 113 deletions

View File

@ -66,10 +66,10 @@ rnsh a5f72aefc2cb3cdba648f73f77c4e887
Usage: Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>] rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...) [-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg>...] [--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>] rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash> [-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h rnsh -h
rnsh --version rnsh --version
@ -79,13 +79,20 @@ Options:
-s NAME --service NAME Listen on/connect to specific service name if not default -s NAME --service NAME Listen on/connect to specific service name if not default
-p --print-identity Print identity information and exit -p --print-identity Print identity information and exit
-l --listen Listen (server) mode -l --listen Listen (server) mode
-b --no-announce Do not announce service -b --announce PERIOD Announce on startup and every PERIOD seconds
Specify 0 for PERIOD to announce on startup only.
-a HASH --allowed HASH Specify identities allowed to connect -a HASH --allowed HASH Specify identities allowed to connect
-n --no-auth Disable authentication -n --no-auth Disable authentication
-N --no-id Disable identify on connect -N --no-id Disable identify on connect
-m --mirror Client returns with code of remote process -m --mirror Client returns with code of remote process
-w TIME --timeout TIME Specify client connect and request timeout in seconds -w TIME --timeout TIME Specify client connect and request timeout in seconds
-v --verbose Increase verbosity -v --verbose Increase verbosity
DEFAULT LEVEL
CRITICAL
Initiator -> ERROR
WARNING
Listener -> INFO
DEBUG
-q --quiet Increase quietness -q --quiet Increase quietness
--version Show version --version Show version
-h --help Show this help -h --help Show this help

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "rnsh" name = "rnsh"
version = "0.0.2" version = "0.0.3"
description = "Shell over Reticulum" description = "Shell over Reticulum"
authors = ["acehoss <acehoss@acehoss.net>"] authors = ["acehoss <acehoss@acehoss.net>"]
license = "MIT" license = "MIT"

View File

@ -22,6 +22,7 @@
import asyncio import asyncio
import contextlib import contextlib
import copy
import errno import errno
import fcntl import fcntl
import functools import functools
@ -77,24 +78,27 @@ def tty_read(fd: int) -> bytes:
if fd_is_closed(fd): if fd_is_closed(fd):
return None return None
run = True try:
result = bytearray() run = True
while run and not fd_is_closed(fd): result = bytearray()
ready, _, _ = select.select([fd], [], [], 0) while run and not fd_is_closed(fd):
if len(ready) == 0: ready, _, _ = select.select([fd], [], [], 0)
break if len(ready) == 0:
for f in ready: break
try: for f in ready:
data = os.read(f, 512) try:
except OSError as e: data = os.read(f, 4096)
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK: except OSError as e:
raise if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
else: raise
if not data: # EOF else:
run = False if not data: # EOF
if data is not None and len(data) > 0: run = False
result.extend(data) if data is not None and len(data) > 0:
return result result.extend(data)
return result
except Exception as ex:
module_logger.error("tty_read error: {ex}")
def fd_is_closed(fd: int) -> bool: def fd_is_closed(fd: int) -> bool:
@ -169,7 +173,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
ATTR_IDX_LFLAG = 4 ATTR_IDX_LFLAG = 4
ATTR_IDX_CC = 5 ATTR_IDX_CC = 5
def __init__(self, fd: int): def __init__(self, fd: int, suppress_logs=False):
""" """
Saves termios attributes for a tty for later restoration. Saves termios attributes for a tty for later restoration.
@ -183,29 +187,37 @@ class TTYRestorer(contextlib.AbstractContextManager):
:param fd: file descriptor of tty :param fd: file descriptor of tty
""" """
self._log = module_logger.getChild(self.__class__.__name__)
self._fd = fd self._fd = fd
self._tattr = None self._tattr = None
with contextlib.suppress(termios.error): self._suppress_logs = suppress_logs
termios.tcgetattr(self._fd) self._tattr = self.current_attr()
if not self._tattr and not self._suppress_logs:
self._log.warning(f"Could not get attrs for fd {fd}")
def raw(self): def raw(self):
""" """
Set raw mode on tty Set raw mode on tty
""" """
if not self._fd: if self._fd is None:
return return
with contextlib.suppress(termios.error):
tty.setraw(self._fd, termios.TCSANOW)
tty.setraw(self._fd, termios.TCSADRAIN) def original_attr(self) -> [any]:
return copy.deepcopy(self._tattr)
def current_attr(self) -> [any]: def current_attr(self) -> [any]:
""" """
Get the current termios attributes for the wrapped fd. Get the current termios attributes for the wrapped fd.
:return: attribute array :return: attribute array
""" """
if not self._fd: if self._fd is None:
return None return None
return termios.tcgetattr(self._fd) with contextlib.suppress(termios.error):
return copy.deepcopy(termios.tcgetattr(self._fd))
return None
def set_attr(self, attr: [any], when: int = termios.TCSANOW): def set_attr(self, attr: [any], when: int = termios.TCSANOW):
""" """
@ -213,7 +225,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
:param attr: attribute list to set :param attr: attribute list to set
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH) :param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
""" """
if not attr or not self._fd: if not attr or self._fd is None:
return return
with contextlib.suppress(termios.error): with contextlib.suppress(termios.error):
@ -228,7 +240,64 @@ class TTYRestorer(contextlib.AbstractContextManager):
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool: __traceback: types.TracebackType) -> bool:
self.restore() self.restore()
return __exc_type is not None and issubclass(__exc_type, termios.error) return False #__exc_type is not None and issubclass(__exc_type, termios.error)
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
if not loop:
loop = asyncio.get_running_loop()
#TODO: this is hacky
async def wait():
while not evt.is_set():
await asyncio.sleep(0.1)
return True
return loop.create_task(wait())
class AggregateException(Exception):
def __init__(self, inner_exceptions: [Exception]):
super().__init__()
self.inner_exceptions = inner_exceptions
def __str__(self):
return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions))
async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any):
tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts))
# try:
finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks),
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED)
if len(unfinished) > 0:
for task in unfinished:
task.cancel()
await asyncio.wait(unfinished)
# exceptions = []
#
# for f in finished:
# ex = f.exception()
# if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError):
# exceptions.append(ex)
#
# if len(exceptions) > 0:
# raise AggregateException(exceptions)
return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None)
# finally:
# unfinished = []
# for task in map(lambda t: t[1], tasks):
# if task.done():
# if not task.cancelled():
# task.exception()
# else:
# task.cancel()
# unfinished.append(task)
# if len(unfinished) > 0:
# await asyncio.wait(unfinished)
async def event_wait(evt: asyncio.Event, timeout: float) -> bool: async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
@ -238,9 +307,7 @@ async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
:param timeout: maximum number of seconds to wait. :param timeout: maximum number of seconds to wait.
:return: True if event was set, False if timeout expired :return: True if event was set, False if timeout expired
""" """
# suppress TimeoutError because we'll return False in case of timeout await event_wait_any([evt], timeout=timeout)
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(evt.wait(), timeout)
return evt.is_set() return evt.is_set()
@ -408,6 +475,8 @@ class CallbackSubprocess:
# self.log.debug("poll") # self.log.debug("poll")
try: try:
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG) pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
if self._return_code is not None:
self._return_code = self._return_code & 0xff
if self._return_code is not None and not process_exists(self._pid): if self._return_code is not None and not process_exists(self._pid):
self._log.debug(f"polled return code {self._return_code}") self._log.debug(f"polled return code {self._return_code}")
self._terminated_cb(self._return_code) self._terminated_cb(self._return_code)

View File

@ -34,6 +34,7 @@ import sys
import termios import termios
import threading import threading
import time import time
import tty
from typing import Callable, TypeVar from typing import Callable, TypeVar
import RNS import RNS
import rnsh.exception as exception import rnsh.exception as exception
@ -58,20 +59,20 @@ _allow_all = False
_allowed_identity_hashes = [] _allowed_identity_hashes = []
_cmd: [str] = None _cmd: [str] = None
DATA_AVAIL_MSG = "data available" DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event | None = None _finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None _retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None _destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None _loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0): async def _check_finished(timeout: float = 0):
await process.event_wait(_finished, timeout=timeout) return await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, frame): def _sigint_handler(sig, frame):
global _finished global _finished
log = _get_logger("_sigint_handler") log = _get_logger("_sigint_handler")
log.debug("SIGINT") log.debug(signal.Signals(sig).name)
if _finished is not None: if _finished is not None:
_finished.set() _finished.set()
else: else:
@ -107,8 +108,6 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
exit(0) exit(0)
# hack_goes_here
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0, async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
allowed=None, disable_auth=None, announce_period=900): allowed=None, disable_auth=None, announce_period=900):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
@ -116,8 +115,8 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_cmd = command _cmd = command
targetloglevel = RNS.LOG_INFO + verbosity - quietness targetloglevel = RNS.LOG_INFO + verbosity - quietness
rnslogging.RnsHandler.set_global_log_level(__logging.INFO)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel) _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
_prepare_identity(identitypath) _prepare_identity(identitypath)
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name) _destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
@ -151,6 +150,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_destination.register_request_handler( _destination.register_request_handler(
path="data", path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()), response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_LIST, allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes allowed_list=_allowed_identity_hashes
) )
@ -158,10 +158,12 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_destination.register_request_handler( _destination.register_request_handler(
path="data", path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()), response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_ALL, allow=RNS.Destination.ALLOW_ALL,
) )
await _check_finished() if await _check_finished():
return
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash)) log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
@ -171,23 +173,23 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
last = time.time() last = time.time()
try: try:
while True: while not await _check_finished(1.0):
if announce_period and 0 < announce_period < time.time() - last: if announce_period and 0 < announce_period < time.time() - last:
last = time.time() last = time.time()
_destination.announce() _destination.announce()
await _check_finished(1.0) finally:
except KeyboardInterrupt:
log.warning("Shutting down") log.warning("Shutting down")
for link in list(_destination.links): for link in list(_destination.links):
with exception.permit(SystemExit): with exception.permit(SystemExit, KeyboardInterrupt):
proc = Session.get_for_tag(link.link_id) proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running: if proc is not None and proc.process.running:
proc.process.terminate() proc.process.terminate()
await asyncio.sleep(1) await asyncio.sleep(0)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links)) links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active: for link in links_still_active:
if link.status != RNS.Link.CLOSED: if link.status != RNS.Link.CLOSED:
link.teardown() link.teardown()
await asyncio.sleep(0)
_PROTOCOL_MAGIC = 0xdeadbeef _PROTOCOL_MAGIC = 0xdeadbeef
@ -334,21 +336,22 @@ class Session:
@staticmethod @staticmethod
def default_request(stdin_fd: int | None) -> [any]: def default_request(stdin_fd: int | None) -> [any]:
global _tr
global _PROTOCOL_VERSION_0 global _PROTOCOL_VERSION_0
request: list[any] = [ request: list[any] = [
_PROTOCOL_VERSION_0, # 0 Protocol Version _PROTOCOL_VERSION_0, # 0 Protocol Version
None, # 1 Stdin None, # 1 Stdin
None, # 2 TERM variable None, # 2 TERM variable
None, # 3 termios attributes or something None, # 3 termios attributes or something
None, # 4 terminal rows None, # 4 terminal rows
None, # 5 terminal cols None, # 5 terminal cols
None, # 6 terminal horizontal pixels None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels None, # 7 terminal vertical pixels
].copy() ].copy()
if stdin_fd is not None: if stdin_fd is not None:
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None) request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd) request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_ROWS], \ request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \ request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \ request[Session.REQUEST_IDX_HPIX], \
@ -372,6 +375,7 @@ class Session:
if term_state != self._term_state: if term_state != self._term_state:
self._term_state = term_state self._term_state = term_state
self._update_winsz() self._update_winsz()
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0])
if stdin is not None and len(stdin) > 0: if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin) stdin = base64.b64decode(stdin)
self.process.write(stdin) self.process.write(stdin)
@ -397,11 +401,11 @@ class Session:
global _PROTOCOL_VERSION_0 global _PROTOCOL_VERSION_0
response: list[any] = [ response: list[any] = [
_PROTOCOL_VERSION_0, # 0: Protocol version _PROTOCOL_VERSION_0, # 0: Protocol version
False, # 1: Process running False, # 1: Process running
None, # 2: Return value None, # 2: Return value
0, # 3: Number of outstanding bytes 0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr None, # 4: Stdout/Stderr
None, # 5: Timestamp None, # 5: Timestamp
].copy() ].copy()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time() response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response return response
@ -539,7 +543,8 @@ def _initiator_identified(link, identity):
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at): def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer, _loop global _destination, _retry_timer, _loop
log = _get_logger("_listen_request") log = _get_logger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}") log.debug(
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1: if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid") raise Exception("Request data invalid")
_retry_timer.complete(link_id) _retry_timer.complete(link_id)
@ -553,7 +558,8 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
if not remote_version == _PROTOCOL_VERSION_0: if not remote_version == _PROTOCOL_VERSION_0:
response = Session.default_response() response = Session.default_response()
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8")) response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(
"Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_RETCODE] = 255 response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0 response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response return response
@ -565,12 +571,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
if session is None: if session is None:
log.debug(f"Process not found for link {link}") log.debug(f"Process not found for link {link}")
session = _listen_start_proc(link=link, session = _listen_start_proc(link=link,
term=term, term=term,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""), remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop) loop=_loop)
# leave significant headroom for metadata and encoding # leave significant headroom for metadata and encoding
result = session.process_request(data, link.MDU * 4 // 3) result = session.process_request(data, link.MDU * 1 // 2)
return result return result
# return ProcessState.default_response() # return ProcessState.default_response()
except Exception as e: except Exception as e:
@ -589,7 +595,8 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
timeout += time.time() timeout += time.time()
while (timeout is None or time.time() < timeout) and not until(): while (timeout is None or time.time() < timeout) and not until():
await _check_finished(0.01) if await _check_finished(0.001):
raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout: if timeout is not None and time.time() > timeout:
return False return False
else: else:
@ -638,8 +645,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if _reticulum is None: if _reticulum is None:
targetloglevel = RNS.LOG_ERROR + verbosity - quietness targetloglevel = RNS.LOG_ERROR + verbosity - quietness
rnslogging.RnsHandler.set_global_log_level(__logging.ERROR)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel) _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
if _identity is None: if _identity is None:
_prepare_identity(identitypath) _prepare_identity(identitypath)
@ -661,6 +668,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
) )
if _link is None or _link.status == RNS.Link.PENDING: if _link is None or _link.status == RNS.Link.PENDING:
log.debug("No link")
_link = RNS.Link(_destination) _link = RNS.Link(_destination)
_link.did_identify = False _link.did_identify = False
@ -668,6 +676,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout): if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash)) raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
log.debug("Have link")
if not noid and not _link.did_identify: if not noid and not _link.did_identify:
_link.identify(_identity) _link.identify(_identity)
_link.did_identify = True _link.did_identify = True
@ -680,6 +689,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
# TODO: Tune # TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace timeout = timeout + _link.rtt * 4 + _remote_exec_grace
log.debug("Sending request")
request_receipt = _link.request( request_receipt = _link.request(
path="data", path="data",
data=request, data=request,
@ -687,6 +697,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
) )
timeout += 0.5 timeout += 0.5
log.debug("Waiting for delivery")
await _spin( await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or ( until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and request_receipt.status != RNS.RequestReceipt.FAILED and
@ -732,15 +743,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
except Exception as e: except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e raise RemoteExecutionError(f"Received invalid response") from e
_tr.raw()
if stdout is not None: if stdout is not None:
# log.debug(f"stdout: {stdout}") _tr.raw()
log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout) os.write(sys.stdout.fileno(), stdout)
sys.stdout.flush()
sys.stdout.flush() got_bytes = len(stdout) if stdout is not None else 0
sys.stderr.flush() log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0: if ready_bytes > 0:
_new_data.set() _new_data.set()
@ -761,11 +771,6 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
data_buffer = bytearray() data_buffer = bytearray()
def sigint_handler():
log.debug("KeyboardInterrupt")
data_buffer.extend("\x03".encode("utf-8"))
_new_data.set()
def sigwinch_handler(): def sigwinch_handler():
# log.debug("WindowChanged") # log.debug("WindowChanged")
if _new_data is not None: if _new_data is not None:
@ -773,7 +778,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
def stdin(): def stdin():
data = process.tty_read(sys.stdin.fileno()) data = process.tty_read(sys.stdin.fileno())
# log.debug(f"stdin {data}") log.debug(f"stdin {data}")
if data is not None: if data is not None:
data_buffer.extend(data) data_buffer.extend(data)
_new_data.set() _new_data.set()
@ -782,10 +787,12 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
await _check_finished() await _check_finished()
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler) loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
# leave a lot of overhead # leave a lot of overhead
mdu = 64 mdu = 64
rtt = 5
first_loop = True first_loop = True
while True: while not await _check_finished():
try: try:
log.debug("top of client loop") log.debug("top of client loop")
stdin = data_buffer[:mdu] stdin = data_buffer[:mdu]
@ -806,22 +813,27 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if first_loop: if first_loop:
first_loop = False first_loop = False
mdu = _link.MDU * 4 // 3 mdu = _link.MDU // 2
loop.remove_signal_handler(signal.SIGINT)
loop.add_signal_handler(signal.SIGINT, sigint_handler)
_new_data.set() _new_data.set()
if _link:
rtt = _link.rtt
if return_code is not None: if return_code is not None:
log.debug(f"received return code {return_code}, exiting") log.debug(f"received return code {return_code}, exiting")
with exception.permit(SystemExit): with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown() _link.teardown()
return return_code return return_code
except asyncio.CancelledError:
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 0
except RemoteExecutionError as e: except RemoteExecutionError as e:
print(e.msg) print(e.msg)
return 255 return 255
await process.event_wait(_new_data, 5) await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
_T = TypeVar("_T") _T = TypeVar("_T")
@ -835,6 +847,11 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
return arr, [] return arr, []
def _loop_set_signal(sig, loop):
loop.remove_signal_handler(sig)
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
async def _rnsh_cli_main(): async def _rnsh_cli_main():
global _tr, _finished, _loop global _tr, _finished, _loop
import docopt import docopt
@ -842,16 +859,16 @@ async def _rnsh_cli_main():
_loop = asyncio.get_running_loop() _loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop) rnslogging.set_main_loop(_loop)
_finished = asyncio.Event() _finished = asyncio.Event()
_loop.remove_signal_handler(signal.SIGINT) _loop_set_signal(signal.SIGINT, _loop)
_loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None)) _loop_set_signal(signal.SIGTERM, _loop)
usage = ''' usage = '''
Usage: Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>] rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...) [-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg> ...] [--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>] rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash> [-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h rnsh -h
rnsh --version rnsh --version
@ -869,6 +886,12 @@ Options:
-m --mirror Client returns with code of remote process -m --mirror Client returns with code of remote process
-w TIME --timeout TIME Specify client connect and request timeout in seconds -w TIME --timeout TIME Specify client connect and request timeout in seconds
-v --verbose Increase verbosity -v --verbose Increase verbosity
DEFAULT LEVEL
CRITICAL
Initiator -> ERROR
WARNING
Listener -> INFO
DEBUG
-q --quiet Increase quietness -q --quiet Increase quietness
--version Show version --version Show version
-h --help Show this help -h --help Show this help
@ -930,36 +953,40 @@ Options:
) )
if args_destination is not None and args_service_name is not None: if args_destination is not None and args_service_name is not None:
try: return_code = await _initiate(
return_code = await _initiate( configdir=args_config,
configdir=args_config, identitypath=args_identity,
identitypath=args_identity, verbosity=args_verbose,
verbosity=args_verbose, quietness=args_quiet,
quietness=args_quiet, noid=args_no_id,
noid=args_no_id, destination=args_destination,
destination=args_destination, service_name=args_service_name,
service_name=args_service_name, timeout=args_timeout,
timeout=args_timeout, )
) return return_code if args_mirror else 0
return return_code if args_mirror else 0
finally:
_tr.restore()
else: else:
print("") print("")
print(args) print(args)
print("") print("")
def _noop():
pass
# RNS.exit = _noop
def rnsh_cli(): def rnsh_cli():
global _tr, _retry_timer global _tr, _retry_timer
with process.TTYRestorer(sys.stdin.fileno()) as _tr: with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
with retry.RetryThread() as _retry_timer: return_code = asyncio.run(_rnsh_cli_main())
return_code = asyncio.run(_rnsh_cli_main())
with exception.permit(SystemExit): with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno()) process.tty_unset_reader_callbacks(sys.stdin.fileno())
sys.exit(return_code or 255) # RNS.Reticulum.exit_handler()
# time.sleep(0.5)
sys.exit(return_code if return_code is not None else 255)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -60,10 +60,29 @@ class RnsHandler(Handler):
return RNS.LOG_DEBUG return RNS.LOG_DEBUG
return RNS.LOG_DEBUG return RNS.LOG_DEBUG
def get_logging_loglevel(rnsloglevel: int) -> int:
if rnsloglevel == RNS.LOG_CRITICAL:
return logging.CRITICAL
if rnsloglevel == RNS.LOG_ERROR:
return logging.ERROR
if rnsloglevel == RNS.LOG_WARNING:
return logging.WARNING
if rnsloglevel == RNS.LOG_NOTICE:
return logging.INFO
if rnsloglevel == RNS.LOG_INFO:
return logging.INFO
if rnsloglevel >= RNS.LOG_VERBOSE:
return RNS.LOG_DEBUG
return RNS.LOG_DEBUG
@classmethod @classmethod
def set_global_log_level(cls, log_level: int): def set_log_level_with_rns_level(cls, rns_log_level: int):
logging.getLogger().setLevel(log_level) logging.getLogger().setLevel(RnsHandler.get_logging_loglevel(rns_log_level))
RNS.loglevel = cls.get_rns_loglevel(log_level) RNS.loglevel = rns_log_level
def set_log_level_with_logging_level(cls, logging_log_level: int):
logging.getLogger().setLevel(logging_log_level)
RNS.loglevel = cls.get_rns_loglevel(logging_log_level)
def emit(self, record): def emit(self, record):
""" """
@ -107,13 +126,16 @@ _rns_log_orig = RNS.log
def _rns_log(msg, level=3, _override_destination=False): def _rns_log(msg, level=3, _override_destination=False):
if RNS.loglevel < level:
return
if not RNS.compact_log_fmt: if not RNS.compact_log_fmt:
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
def _rns_log_inner(): def _rns_log_inner():
nonlocal msg, level, _override_destination nonlocal msg, level, _override_destination
try: try:
with process.TTYRestorer(sys.stdin.fileno()) as tr: with process.TTYRestorer(sys.stdin.fileno(), suppress_logs=True) as tr:
attr = tr.current_attr() attr = tr.current_attr()
if attr: if attr:
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \ attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
@ -123,12 +145,13 @@ def _rns_log(msg, level=3, _override_destination=False):
except ValueError: except ValueError:
_rns_log_orig(msg, level, _override_destination) _rns_log_orig(msg, level, _override_destination)
# TODO: figure out if forcing this to the main thread actually helps.
try: try:
if _loop and _loop.is_running(): if _loop and _loop.is_running():
_loop.call_soon_threadsafe(_rns_log_inner) _loop.call_soon_threadsafe(_rns_log_inner)
else: else:
_rns_log_inner() _rns_log_inner()
except RuntimeError: except:
_rns_log_inner() _rns_log_inner()

View File

@ -9,6 +9,7 @@ import os
import threading import threading
import types import types
import typing import typing
import multiprocessing.pool
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
@ -26,6 +27,7 @@ class State(contextlib.AbstractContextManager):
loop=self.loop, loop=self.loop,
stdout_callback=self._stdout_cb, stdout_callback=self._stdout_cb,
terminated_callback=self._terminated_cb) terminated_callback=self._terminated_cb)
def _stdout_cb(self, data): def _stdout_cb(self, data):
with self._lock: with self._lock:
self._stdout.extend(data) self._stdout.extend(data)
@ -74,3 +76,65 @@ async def test_echo():
decoded = data.decode("utf-8") decoded = data.decode("utf-8")
assert decoded == message.replace("\n", "\r\n") * 2 assert decoded == message.replace("\n", "\r\n") * 2
assert not state.process.running assert not state.process.running
@pytest.mark.skip_ci
@pytest.mark.asyncio
async def test_echo_live():
"""
Check for immediate echo
"""
loop = asyncio.get_running_loop()
with State(argv=["/bin/cat"],
loop=loop) as state:
state.start()
assert state.process is not None
assert state.process.running
message = "t"
state.process.write(message.encode("utf-8"))
await asyncio.sleep(0.1)
data = state.read()
state.process.write(rnsh.process.CTRL_C)
await asyncio.sleep(0.1)
assert len(data) > 0
decoded = data.decode("utf-8")
assert decoded == message
assert not state.process.running
@pytest.mark.asyncio
async def test_event_wait_any():
delay = 0.1
with multiprocessing.pool.ThreadPool() as pool:
loop = asyncio.get_running_loop()
evt1 = asyncio.Event()
evt2 = asyncio.Event()
def assert_between(min, max, val):
assert min <= val <= max
# test 1: both timeout
ts = time.time()
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay*2)
assert_between(delay*2, delay*2.1, time.time() - ts)
assert finished is None
assert not evt1.is_set()
assert not evt2.is_set()
#test 2: evt1 set, evt2 not set
hits = 0
def test2_bg():
nonlocal hits
hits += 1
time.sleep(delay)
evt1.set()
ts = time.time()
pool.apply_async(test2_bg)
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay * 2)
assert_between(delay * 0.5, delay * 1.5, time.time() - ts)
assert hits == 1
assert evt1.is_set()
assert not evt2.is_set()
assert finished == evt1