mirror of
https://github.com/markqvist/rnsh.git
synced 2024-10-01 01:15:37 -04:00
Performance and stability improvements.
This commit is contained in:
parent
7d711cde6d
commit
5e755acad4
15
README.md
15
README.md
@ -66,10 +66,10 @@ rnsh a5f72aefc2cb3cdba648f73f77c4e887
|
||||
Usage:
|
||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
|
||||
[--] <program> [<arg>...]
|
||||
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||
[--] <program> [<arg> ...]
|
||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
rnsh -h
|
||||
rnsh --version
|
||||
|
||||
@ -79,13 +79,20 @@ Options:
|
||||
-s NAME --service NAME Listen on/connect to specific service name if not default
|
||||
-p --print-identity Print identity information and exit
|
||||
-l --listen Listen (server) mode
|
||||
-b --no-announce Do not announce service
|
||||
-b --announce PERIOD Announce on startup and every PERIOD seconds
|
||||
Specify 0 for PERIOD to announce on startup only.
|
||||
-a HASH --allowed HASH Specify identities allowed to connect
|
||||
-n --no-auth Disable authentication
|
||||
-N --no-id Disable identify on connect
|
||||
-m --mirror Client returns with code of remote process
|
||||
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
||||
-v --verbose Increase verbosity
|
||||
DEFAULT LEVEL
|
||||
CRITICAL
|
||||
Initiator -> ERROR
|
||||
WARNING
|
||||
Listener -> INFO
|
||||
DEBUG
|
||||
-q --quiet Increase quietness
|
||||
--version Show version
|
||||
-h --help Show this help
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "rnsh"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
description = "Shell over Reticulum"
|
||||
authors = ["acehoss <acehoss@acehoss.net>"]
|
||||
license = "MIT"
|
||||
|
129
rnsh/process.py
129
rnsh/process.py
@ -22,6 +22,7 @@
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import errno
|
||||
import fcntl
|
||||
import functools
|
||||
@ -77,24 +78,27 @@ def tty_read(fd: int) -> bytes:
|
||||
if fd_is_closed(fd):
|
||||
return None
|
||||
|
||||
run = True
|
||||
result = bytearray()
|
||||
while run and not fd_is_closed(fd):
|
||||
ready, _, _ = select.select([fd], [], [], 0)
|
||||
if len(ready) == 0:
|
||||
break
|
||||
for f in ready:
|
||||
try:
|
||||
data = os.read(f, 512)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
|
||||
raise
|
||||
else:
|
||||
if not data: # EOF
|
||||
run = False
|
||||
if data is not None and len(data) > 0:
|
||||
result.extend(data)
|
||||
return result
|
||||
try:
|
||||
run = True
|
||||
result = bytearray()
|
||||
while run and not fd_is_closed(fd):
|
||||
ready, _, _ = select.select([fd], [], [], 0)
|
||||
if len(ready) == 0:
|
||||
break
|
||||
for f in ready:
|
||||
try:
|
||||
data = os.read(f, 4096)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
|
||||
raise
|
||||
else:
|
||||
if not data: # EOF
|
||||
run = False
|
||||
if data is not None and len(data) > 0:
|
||||
result.extend(data)
|
||||
return result
|
||||
except Exception as ex:
|
||||
module_logger.error("tty_read error: {ex}")
|
||||
|
||||
|
||||
def fd_is_closed(fd: int) -> bool:
|
||||
@ -169,7 +173,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
||||
ATTR_IDX_LFLAG = 4
|
||||
ATTR_IDX_CC = 5
|
||||
|
||||
def __init__(self, fd: int):
|
||||
def __init__(self, fd: int, suppress_logs=False):
|
||||
"""
|
||||
Saves termios attributes for a tty for later restoration.
|
||||
|
||||
@ -183,29 +187,37 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
||||
|
||||
:param fd: file descriptor of tty
|
||||
"""
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._fd = fd
|
||||
self._tattr = None
|
||||
with contextlib.suppress(termios.error):
|
||||
termios.tcgetattr(self._fd)
|
||||
self._suppress_logs = suppress_logs
|
||||
self._tattr = self.current_attr()
|
||||
if not self._tattr and not self._suppress_logs:
|
||||
self._log.warning(f"Could not get attrs for fd {fd}")
|
||||
|
||||
def raw(self):
|
||||
"""
|
||||
Set raw mode on tty
|
||||
"""
|
||||
if not self._fd:
|
||||
if self._fd is None:
|
||||
return
|
||||
with contextlib.suppress(termios.error):
|
||||
tty.setraw(self._fd, termios.TCSANOW)
|
||||
|
||||
tty.setraw(self._fd, termios.TCSADRAIN)
|
||||
def original_attr(self) -> [any]:
|
||||
return copy.deepcopy(self._tattr)
|
||||
|
||||
def current_attr(self) -> [any]:
|
||||
"""
|
||||
Get the current termios attributes for the wrapped fd.
|
||||
:return: attribute array
|
||||
"""
|
||||
if not self._fd:
|
||||
if self._fd is None:
|
||||
return None
|
||||
|
||||
return termios.tcgetattr(self._fd)
|
||||
with contextlib.suppress(termios.error):
|
||||
return copy.deepcopy(termios.tcgetattr(self._fd))
|
||||
return None
|
||||
|
||||
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
|
||||
"""
|
||||
@ -213,7 +225,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
||||
:param attr: attribute list to set
|
||||
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
|
||||
"""
|
||||
if not attr or not self._fd:
|
||||
if not attr or self._fd is None:
|
||||
return
|
||||
|
||||
with contextlib.suppress(termios.error):
|
||||
@ -228,7 +240,64 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||
__traceback: types.TracebackType) -> bool:
|
||||
self.restore()
|
||||
return __exc_type is not None and issubclass(__exc_type, termios.error)
|
||||
return False #__exc_type is not None and issubclass(__exc_type, termios.error)
|
||||
|
||||
|
||||
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
|
||||
if not loop:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
#TODO: this is hacky
|
||||
async def wait():
|
||||
while not evt.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
return True
|
||||
|
||||
return loop.create_task(wait())
|
||||
|
||||
|
||||
class AggregateException(Exception):
|
||||
def __init__(self, inner_exceptions: [Exception]):
|
||||
super().__init__()
|
||||
self.inner_exceptions = inner_exceptions
|
||||
|
||||
def __str__(self):
|
||||
return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions))
|
||||
|
||||
async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any):
|
||||
tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts))
|
||||
# try:
|
||||
finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks),
|
||||
timeout=timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if len(unfinished) > 0:
|
||||
for task in unfinished:
|
||||
task.cancel()
|
||||
await asyncio.wait(unfinished)
|
||||
|
||||
# exceptions = []
|
||||
#
|
||||
# for f in finished:
|
||||
# ex = f.exception()
|
||||
# if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError):
|
||||
# exceptions.append(ex)
|
||||
#
|
||||
# if len(exceptions) > 0:
|
||||
# raise AggregateException(exceptions)
|
||||
|
||||
return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None)
|
||||
# finally:
|
||||
# unfinished = []
|
||||
# for task in map(lambda t: t[1], tasks):
|
||||
# if task.done():
|
||||
# if not task.cancelled():
|
||||
# task.exception()
|
||||
# else:
|
||||
# task.cancel()
|
||||
# unfinished.append(task)
|
||||
# if len(unfinished) > 0:
|
||||
# await asyncio.wait(unfinished)
|
||||
|
||||
|
||||
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
||||
@ -238,9 +307,7 @@ async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
||||
:param timeout: maximum number of seconds to wait.
|
||||
:return: True if event was set, False if timeout expired
|
||||
"""
|
||||
# suppress TimeoutError because we'll return False in case of timeout
|
||||
with contextlib.suppress(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(evt.wait(), timeout)
|
||||
await event_wait_any([evt], timeout=timeout)
|
||||
return evt.is_set()
|
||||
|
||||
|
||||
@ -408,6 +475,8 @@ class CallbackSubprocess:
|
||||
# self.log.debug("poll")
|
||||
try:
|
||||
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
||||
if self._return_code is not None:
|
||||
self._return_code = self._return_code & 0xff
|
||||
if self._return_code is not None and not process_exists(self._pid):
|
||||
self._log.debug(f"polled return code {self._return_code}")
|
||||
self._terminated_cb(self._return_code)
|
||||
|
173
rnsh/rnsh.py
173
rnsh/rnsh.py
@ -34,6 +34,7 @@ import sys
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
import tty
|
||||
from typing import Callable, TypeVar
|
||||
import RNS
|
||||
import rnsh.exception as exception
|
||||
@ -58,20 +59,20 @@ _allow_all = False
|
||||
_allowed_identity_hashes = []
|
||||
_cmd: [str] = None
|
||||
DATA_AVAIL_MSG = "data available"
|
||||
_finished: asyncio.Event | None = None
|
||||
_finished: asyncio.Event = None
|
||||
_retry_timer: retry.RetryThread | None = None
|
||||
_destination: RNS.Destination | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
async def _check_finished(timeout: float = 0):
|
||||
await process.event_wait(_finished, timeout=timeout)
|
||||
return await process.event_wait(_finished, timeout=timeout)
|
||||
|
||||
|
||||
def _sigint_handler(sig, frame):
|
||||
global _finished
|
||||
log = _get_logger("_sigint_handler")
|
||||
log.debug("SIGINT")
|
||||
log.debug(signal.Signals(sig).name)
|
||||
if _finished is not None:
|
||||
_finished.set()
|
||||
else:
|
||||
@ -107,8 +108,6 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
|
||||
exit(0)
|
||||
|
||||
|
||||
# hack_goes_here
|
||||
|
||||
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
||||
allowed=None, disable_auth=None, announce_period=900):
|
||||
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||
@ -116,8 +115,8 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
_cmd = command
|
||||
|
||||
targetloglevel = RNS.LOG_INFO + verbosity - quietness
|
||||
rnslogging.RnsHandler.set_global_log_level(__logging.INFO)
|
||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
|
||||
_prepare_identity(identitypath)
|
||||
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
||||
|
||||
@ -151,6 +150,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
# response_generator=_listen_request,
|
||||
allow=RNS.Destination.ALLOW_LIST,
|
||||
allowed_list=_allowed_identity_hashes
|
||||
)
|
||||
@ -158,10 +158,12 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
# response_generator=_listen_request,
|
||||
allow=RNS.Destination.ALLOW_ALL,
|
||||
)
|
||||
|
||||
await _check_finished()
|
||||
if await _check_finished():
|
||||
return
|
||||
|
||||
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||
|
||||
@ -171,23 +173,23 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
last = time.time()
|
||||
|
||||
try:
|
||||
while True:
|
||||
while not await _check_finished(1.0):
|
||||
if announce_period and 0 < announce_period < time.time() - last:
|
||||
last = time.time()
|
||||
_destination.announce()
|
||||
await _check_finished(1.0)
|
||||
except KeyboardInterrupt:
|
||||
finally:
|
||||
log.warning("Shutting down")
|
||||
for link in list(_destination.links):
|
||||
with exception.permit(SystemExit):
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
proc = Session.get_for_tag(link.link_id)
|
||||
if proc is not None and proc.process.running:
|
||||
proc.process.terminate()
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(0)
|
||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||
for link in links_still_active:
|
||||
if link.status != RNS.Link.CLOSED:
|
||||
link.teardown()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
_PROTOCOL_MAGIC = 0xdeadbeef
|
||||
@ -334,21 +336,22 @@ class Session:
|
||||
|
||||
@staticmethod
|
||||
def default_request(stdin_fd: int | None) -> [any]:
|
||||
global _tr
|
||||
global _PROTOCOL_VERSION_0
|
||||
request: list[any] = [
|
||||
_PROTOCOL_VERSION_0, # 0 Protocol Version
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
None, # 4 terminal rows
|
||||
None, # 5 terminal cols
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
None, # 4 terminal rows
|
||||
None, # 5 terminal cols
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
].copy()
|
||||
|
||||
if stdin_fd is not None:
|
||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
||||
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd)
|
||||
request[Session.REQUEST_IDX_ROWS], \
|
||||
request[Session.REQUEST_IDX_COLS], \
|
||||
request[Session.REQUEST_IDX_HPIX], \
|
||||
@ -372,6 +375,7 @@ class Session:
|
||||
if term_state != self._term_state:
|
||||
self._term_state = term_state
|
||||
self._update_winsz()
|
||||
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0])
|
||||
if stdin is not None and len(stdin) > 0:
|
||||
stdin = base64.b64decode(stdin)
|
||||
self.process.write(stdin)
|
||||
@ -397,11 +401,11 @@ class Session:
|
||||
global _PROTOCOL_VERSION_0
|
||||
response: list[any] = [
|
||||
_PROTOCOL_VERSION_0, # 0: Protocol version
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stdout/Stderr
|
||||
None, # 5: Timestamp
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stdout/Stderr
|
||||
None, # 5: Timestamp
|
||||
].copy()
|
||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||
return response
|
||||
@ -539,7 +543,8 @@ def _initiator_identified(link, identity):
|
||||
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||
global _destination, _retry_timer, _loop
|
||||
log = _get_logger("_listen_request")
|
||||
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
||||
log.debug(
|
||||
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
|
||||
if not hasattr(data, "__len__") or len(data) < 1:
|
||||
raise Exception("Request data invalid")
|
||||
_retry_timer.complete(link_id)
|
||||
@ -553,7 +558,8 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
|
||||
if not remote_version == _PROTOCOL_VERSION_0:
|
||||
response = Session.default_response()
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(
|
||||
"Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||
return response
|
||||
@ -565,12 +571,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
if session is None:
|
||||
log.debug(f"Process not found for link {link}")
|
||||
session = _listen_start_proc(link=link,
|
||||
term=term,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
term=term,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
|
||||
# leave significant headroom for metadata and encoding
|
||||
result = session.process_request(data, link.MDU * 4 // 3)
|
||||
result = session.process_request(data, link.MDU * 1 // 2)
|
||||
return result
|
||||
# return ProcessState.default_response()
|
||||
except Exception as e:
|
||||
@ -589,7 +595,8 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
||||
timeout += time.time()
|
||||
|
||||
while (timeout is None or time.time() < timeout) and not until():
|
||||
await _check_finished(0.01)
|
||||
if await _check_finished(0.001):
|
||||
raise asyncio.CancelledError()
|
||||
if timeout is not None and time.time() > timeout:
|
||||
return False
|
||||
else:
|
||||
@ -638,8 +645,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
|
||||
if _reticulum is None:
|
||||
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
|
||||
rnslogging.RnsHandler.set_global_log_level(__logging.ERROR)
|
||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
|
||||
|
||||
if _identity is None:
|
||||
_prepare_identity(identitypath)
|
||||
@ -661,6 +668,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
)
|
||||
|
||||
if _link is None or _link.status == RNS.Link.PENDING:
|
||||
log.debug("No link")
|
||||
_link = RNS.Link(_destination)
|
||||
_link.did_identify = False
|
||||
|
||||
@ -668,6 +676,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||
|
||||
log.debug("Have link")
|
||||
if not noid and not _link.did_identify:
|
||||
_link.identify(_identity)
|
||||
_link.did_identify = True
|
||||
@ -680,6 +689,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
# TODO: Tune
|
||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||
|
||||
log.debug("Sending request")
|
||||
request_receipt = _link.request(
|
||||
path="data",
|
||||
data=request,
|
||||
@ -687,6 +697,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
)
|
||||
timeout += 0.5
|
||||
|
||||
log.debug("Waiting for delivery")
|
||||
await _spin(
|
||||
until=lambda: _link.status == RNS.Link.CLOSED or (
|
||||
request_receipt.status != RNS.RequestReceipt.FAILED and
|
||||
@ -732,15 +743,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
except Exception as e:
|
||||
raise RemoteExecutionError(f"Received invalid response") from e
|
||||
|
||||
_tr.raw()
|
||||
if stdout is not None:
|
||||
# log.debug(f"stdout: {stdout}")
|
||||
_tr.raw()
|
||||
log.debug(f"stdout: {stdout}")
|
||||
os.write(sys.stdout.fileno(), stdout)
|
||||
sys.stdout.flush()
|
||||
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}")
|
||||
got_bytes = len(stdout) if stdout is not None else 0
|
||||
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
|
||||
|
||||
if ready_bytes > 0:
|
||||
_new_data.set()
|
||||
@ -761,11 +771,6 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
|
||||
data_buffer = bytearray()
|
||||
|
||||
def sigint_handler():
|
||||
log.debug("KeyboardInterrupt")
|
||||
data_buffer.extend("\x03".encode("utf-8"))
|
||||
_new_data.set()
|
||||
|
||||
def sigwinch_handler():
|
||||
# log.debug("WindowChanged")
|
||||
if _new_data is not None:
|
||||
@ -773,7 +778,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
|
||||
def stdin():
|
||||
data = process.tty_read(sys.stdin.fileno())
|
||||
# log.debug(f"stdin {data}")
|
||||
log.debug(f"stdin {data}")
|
||||
if data is not None:
|
||||
data_buffer.extend(data)
|
||||
_new_data.set()
|
||||
@ -782,10 +787,12 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
|
||||
await _check_finished()
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
|
||||
# leave a lot of overhead
|
||||
mdu = 64
|
||||
rtt = 5
|
||||
first_loop = True
|
||||
while True:
|
||||
while not await _check_finished():
|
||||
try:
|
||||
log.debug("top of client loop")
|
||||
stdin = data_buffer[:mdu]
|
||||
@ -806,22 +813,27 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
|
||||
if first_loop:
|
||||
first_loop = False
|
||||
mdu = _link.MDU * 4 // 3
|
||||
loop.remove_signal_handler(signal.SIGINT)
|
||||
loop.add_signal_handler(signal.SIGINT, sigint_handler)
|
||||
mdu = _link.MDU // 2
|
||||
_new_data.set()
|
||||
|
||||
if _link:
|
||||
rtt = _link.rtt
|
||||
|
||||
if return_code is not None:
|
||||
log.debug(f"received return code {return_code}, exiting")
|
||||
with exception.permit(SystemExit):
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
_link.teardown()
|
||||
|
||||
return return_code
|
||||
except asyncio.CancelledError:
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
_link.teardown()
|
||||
return 0
|
||||
except RemoteExecutionError as e:
|
||||
print(e.msg)
|
||||
return 255
|
||||
|
||||
await process.event_wait(_new_data, 5)
|
||||
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -835,6 +847,11 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
|
||||
return arr, []
|
||||
|
||||
|
||||
def _loop_set_signal(sig, loop):
|
||||
loop.remove_signal_handler(sig)
|
||||
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
|
||||
|
||||
|
||||
async def _rnsh_cli_main():
|
||||
global _tr, _finished, _loop
|
||||
import docopt
|
||||
@ -842,16 +859,16 @@ async def _rnsh_cli_main():
|
||||
_loop = asyncio.get_running_loop()
|
||||
rnslogging.set_main_loop(_loop)
|
||||
_finished = asyncio.Event()
|
||||
_loop.remove_signal_handler(signal.SIGINT)
|
||||
_loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None))
|
||||
_loop_set_signal(signal.SIGINT, _loop)
|
||||
_loop_set_signal(signal.SIGTERM, _loop)
|
||||
usage = '''
|
||||
Usage:
|
||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||
[--] <program> [<arg> ...]
|
||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
rnsh -h
|
||||
rnsh --version
|
||||
|
||||
@ -869,6 +886,12 @@ Options:
|
||||
-m --mirror Client returns with code of remote process
|
||||
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
||||
-v --verbose Increase verbosity
|
||||
DEFAULT LEVEL
|
||||
CRITICAL
|
||||
Initiator -> ERROR
|
||||
WARNING
|
||||
Listener -> INFO
|
||||
DEBUG
|
||||
-q --quiet Increase quietness
|
||||
--version Show version
|
||||
-h --help Show this help
|
||||
@ -930,36 +953,40 @@ Options:
|
||||
)
|
||||
|
||||
if args_destination is not None and args_service_name is not None:
|
||||
try:
|
||||
return_code = await _initiate(
|
||||
configdir=args_config,
|
||||
identitypath=args_identity,
|
||||
verbosity=args_verbose,
|
||||
quietness=args_quiet,
|
||||
noid=args_no_id,
|
||||
destination=args_destination,
|
||||
service_name=args_service_name,
|
||||
timeout=args_timeout,
|
||||
)
|
||||
return return_code if args_mirror else 0
|
||||
finally:
|
||||
_tr.restore()
|
||||
return_code = await _initiate(
|
||||
configdir=args_config,
|
||||
identitypath=args_identity,
|
||||
verbosity=args_verbose,
|
||||
quietness=args_quiet,
|
||||
noid=args_no_id,
|
||||
destination=args_destination,
|
||||
service_name=args_service_name,
|
||||
timeout=args_timeout,
|
||||
)
|
||||
return return_code if args_mirror else 0
|
||||
else:
|
||||
print("")
|
||||
print(args)
|
||||
print("")
|
||||
|
||||
|
||||
def _noop():
|
||||
pass
|
||||
|
||||
# RNS.exit = _noop
|
||||
|
||||
|
||||
def rnsh_cli():
|
||||
global _tr, _retry_timer
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
|
||||
with retry.RetryThread() as _retry_timer:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
|
||||
with exception.permit(SystemExit):
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
sys.exit(return_code or 255)
|
||||
# RNS.Reticulum.exit_handler()
|
||||
# time.sleep(0.5)
|
||||
sys.exit(return_code if return_code is not None else 255)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -60,10 +60,29 @@ class RnsHandler(Handler):
|
||||
return RNS.LOG_DEBUG
|
||||
return RNS.LOG_DEBUG
|
||||
|
||||
def get_logging_loglevel(rnsloglevel: int) -> int:
|
||||
if rnsloglevel == RNS.LOG_CRITICAL:
|
||||
return logging.CRITICAL
|
||||
if rnsloglevel == RNS.LOG_ERROR:
|
||||
return logging.ERROR
|
||||
if rnsloglevel == RNS.LOG_WARNING:
|
||||
return logging.WARNING
|
||||
if rnsloglevel == RNS.LOG_NOTICE:
|
||||
return logging.INFO
|
||||
if rnsloglevel == RNS.LOG_INFO:
|
||||
return logging.INFO
|
||||
if rnsloglevel >= RNS.LOG_VERBOSE:
|
||||
return RNS.LOG_DEBUG
|
||||
return RNS.LOG_DEBUG
|
||||
|
||||
@classmethod
|
||||
def set_global_log_level(cls, log_level: int):
|
||||
logging.getLogger().setLevel(log_level)
|
||||
RNS.loglevel = cls.get_rns_loglevel(log_level)
|
||||
def set_log_level_with_rns_level(cls, rns_log_level: int):
|
||||
logging.getLogger().setLevel(RnsHandler.get_logging_loglevel(rns_log_level))
|
||||
RNS.loglevel = rns_log_level
|
||||
|
||||
def set_log_level_with_logging_level(cls, logging_log_level: int):
|
||||
logging.getLogger().setLevel(logging_log_level)
|
||||
RNS.loglevel = cls.get_rns_loglevel(logging_log_level)
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
@ -107,13 +126,16 @@ _rns_log_orig = RNS.log
|
||||
|
||||
|
||||
def _rns_log(msg, level=3, _override_destination=False):
|
||||
if RNS.loglevel < level:
|
||||
return
|
||||
|
||||
if not RNS.compact_log_fmt:
|
||||
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
|
||||
|
||||
def _rns_log_inner():
|
||||
nonlocal msg, level, _override_destination
|
||||
try:
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as tr:
|
||||
with process.TTYRestorer(sys.stdin.fileno(), suppress_logs=True) as tr:
|
||||
attr = tr.current_attr()
|
||||
if attr:
|
||||
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
|
||||
@ -123,12 +145,13 @@ def _rns_log(msg, level=3, _override_destination=False):
|
||||
except ValueError:
|
||||
_rns_log_orig(msg, level, _override_destination)
|
||||
|
||||
# TODO: figure out if forcing this to the main thread actually helps.
|
||||
try:
|
||||
if _loop and _loop.is_running():
|
||||
_loop.call_soon_threadsafe(_rns_log_inner)
|
||||
else:
|
||||
_rns_log_inner()
|
||||
except RuntimeError:
|
||||
except:
|
||||
_rns_log_inner()
|
||||
|
||||
|
||||
|
@ -9,6 +9,7 @@ import os
|
||||
import threading
|
||||
import types
|
||||
import typing
|
||||
import multiprocessing.pool
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@ -26,6 +27,7 @@ class State(contextlib.AbstractContextManager):
|
||||
loop=self.loop,
|
||||
stdout_callback=self._stdout_cb,
|
||||
terminated_callback=self._terminated_cb)
|
||||
|
||||
def _stdout_cb(self, data):
|
||||
with self._lock:
|
||||
self._stdout.extend(data)
|
||||
@ -74,3 +76,65 @@ async def test_echo():
|
||||
decoded = data.decode("utf-8")
|
||||
assert decoded == message.replace("\n", "\r\n") * 2
|
||||
assert not state.process.running
|
||||
|
||||
|
||||
@pytest.mark.skip_ci
|
||||
@pytest.mark.asyncio
|
||||
async def test_echo_live():
|
||||
"""
|
||||
Check for immediate echo
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
with State(argv=["/bin/cat"],
|
||||
loop=loop) as state:
|
||||
state.start()
|
||||
assert state.process is not None
|
||||
assert state.process.running
|
||||
message = "t"
|
||||
state.process.write(message.encode("utf-8"))
|
||||
await asyncio.sleep(0.1)
|
||||
data = state.read()
|
||||
state.process.write(rnsh.process.CTRL_C)
|
||||
await asyncio.sleep(0.1)
|
||||
assert len(data) > 0
|
||||
decoded = data.decode("utf-8")
|
||||
assert decoded == message
|
||||
assert not state.process.running
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_wait_any():
|
||||
delay = 0.1
|
||||
with multiprocessing.pool.ThreadPool() as pool:
|
||||
loop = asyncio.get_running_loop()
|
||||
evt1 = asyncio.Event()
|
||||
evt2 = asyncio.Event()
|
||||
|
||||
def assert_between(min, max, val):
|
||||
assert min <= val <= max
|
||||
|
||||
# test 1: both timeout
|
||||
ts = time.time()
|
||||
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay*2)
|
||||
assert_between(delay*2, delay*2.1, time.time() - ts)
|
||||
assert finished is None
|
||||
assert not evt1.is_set()
|
||||
assert not evt2.is_set()
|
||||
|
||||
#test 2: evt1 set, evt2 not set
|
||||
hits = 0
|
||||
|
||||
def test2_bg():
|
||||
nonlocal hits
|
||||
hits += 1
|
||||
time.sleep(delay)
|
||||
evt1.set()
|
||||
|
||||
ts = time.time()
|
||||
pool.apply_async(test2_bg)
|
||||
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay * 2)
|
||||
assert_between(delay * 0.5, delay * 1.5, time.time() - ts)
|
||||
assert hits == 1
|
||||
assert evt1.is_set()
|
||||
assert not evt2.is_set()
|
||||
assert finished == evt1
|
||||
|
Loading…
Reference in New Issue
Block a user