mirror of
https://github.com/markqvist/rnsh.git
synced 2025-06-20 20:14:15 -04:00
Performance and stability improvements.
This commit is contained in:
parent
7d711cde6d
commit
5e755acad4
6 changed files with 303 additions and 113 deletions
173
rnsh/rnsh.py
173
rnsh/rnsh.py
|
@ -34,6 +34,7 @@ import sys
|
|||
import termios
|
||||
import threading
|
||||
import time
|
||||
import tty
|
||||
from typing import Callable, TypeVar
|
||||
import RNS
|
||||
import rnsh.exception as exception
|
||||
|
@ -58,20 +59,20 @@ _allow_all = False
|
|||
_allowed_identity_hashes = []
|
||||
_cmd: [str] = None
|
||||
DATA_AVAIL_MSG = "data available"
|
||||
_finished: asyncio.Event | None = None
|
||||
_finished: asyncio.Event = None
|
||||
_retry_timer: retry.RetryThread | None = None
|
||||
_destination: RNS.Destination | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
async def _check_finished(timeout: float = 0):
|
||||
await process.event_wait(_finished, timeout=timeout)
|
||||
return await process.event_wait(_finished, timeout=timeout)
|
||||
|
||||
|
||||
def _sigint_handler(sig, frame):
|
||||
global _finished
|
||||
log = _get_logger("_sigint_handler")
|
||||
log.debug("SIGINT")
|
||||
log.debug(signal.Signals(sig).name)
|
||||
if _finished is not None:
|
||||
_finished.set()
|
||||
else:
|
||||
|
@ -107,8 +108,6 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
|
|||
exit(0)
|
||||
|
||||
|
||||
# hack_goes_here
|
||||
|
||||
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
||||
allowed=None, disable_auth=None, announce_period=900):
|
||||
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||
|
@ -116,8 +115,8 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
_cmd = command
|
||||
|
||||
targetloglevel = RNS.LOG_INFO + verbosity - quietness
|
||||
rnslogging.RnsHandler.set_global_log_level(__logging.INFO)
|
||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
|
||||
_prepare_identity(identitypath)
|
||||
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
||||
|
||||
|
@ -151,6 +150,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
# response_generator=_listen_request,
|
||||
allow=RNS.Destination.ALLOW_LIST,
|
||||
allowed_list=_allowed_identity_hashes
|
||||
)
|
||||
|
@ -158,10 +158,12 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
# response_generator=_listen_request,
|
||||
allow=RNS.Destination.ALLOW_ALL,
|
||||
)
|
||||
|
||||
await _check_finished()
|
||||
if await _check_finished():
|
||||
return
|
||||
|
||||
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||
|
||||
|
@ -171,23 +173,23 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
last = time.time()
|
||||
|
||||
try:
|
||||
while True:
|
||||
while not await _check_finished(1.0):
|
||||
if announce_period and 0 < announce_period < time.time() - last:
|
||||
last = time.time()
|
||||
_destination.announce()
|
||||
await _check_finished(1.0)
|
||||
except KeyboardInterrupt:
|
||||
finally:
|
||||
log.warning("Shutting down")
|
||||
for link in list(_destination.links):
|
||||
with exception.permit(SystemExit):
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
proc = Session.get_for_tag(link.link_id)
|
||||
if proc is not None and proc.process.running:
|
||||
proc.process.terminate()
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(0)
|
||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||
for link in links_still_active:
|
||||
if link.status != RNS.Link.CLOSED:
|
||||
link.teardown()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
_PROTOCOL_MAGIC = 0xdeadbeef
|
||||
|
@ -334,21 +336,22 @@ class Session:
|
|||
|
||||
@staticmethod
|
||||
def default_request(stdin_fd: int | None) -> [any]:
|
||||
global _tr
|
||||
global _PROTOCOL_VERSION_0
|
||||
request: list[any] = [
|
||||
_PROTOCOL_VERSION_0, # 0 Protocol Version
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
None, # 4 terminal rows
|
||||
None, # 5 terminal cols
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
None, # 4 terminal rows
|
||||
None, # 5 terminal cols
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
].copy()
|
||||
|
||||
if stdin_fd is not None:
|
||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
||||
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd)
|
||||
request[Session.REQUEST_IDX_ROWS], \
|
||||
request[Session.REQUEST_IDX_COLS], \
|
||||
request[Session.REQUEST_IDX_HPIX], \
|
||||
|
@ -372,6 +375,7 @@ class Session:
|
|||
if term_state != self._term_state:
|
||||
self._term_state = term_state
|
||||
self._update_winsz()
|
||||
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0])
|
||||
if stdin is not None and len(stdin) > 0:
|
||||
stdin = base64.b64decode(stdin)
|
||||
self.process.write(stdin)
|
||||
|
@ -397,11 +401,11 @@ class Session:
|
|||
global _PROTOCOL_VERSION_0
|
||||
response: list[any] = [
|
||||
_PROTOCOL_VERSION_0, # 0: Protocol version
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stdout/Stderr
|
||||
None, # 5: Timestamp
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stdout/Stderr
|
||||
None, # 5: Timestamp
|
||||
].copy()
|
||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||
return response
|
||||
|
@ -539,7 +543,8 @@ def _initiator_identified(link, identity):
|
|||
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||
global _destination, _retry_timer, _loop
|
||||
log = _get_logger("_listen_request")
|
||||
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
||||
log.debug(
|
||||
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
|
||||
if not hasattr(data, "__len__") or len(data) < 1:
|
||||
raise Exception("Request data invalid")
|
||||
_retry_timer.complete(link_id)
|
||||
|
@ -553,7 +558,8 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||
|
||||
if not remote_version == _PROTOCOL_VERSION_0:
|
||||
response = Session.default_response()
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(
|
||||
"Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||
return response
|
||||
|
@ -565,12 +571,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||
if session is None:
|
||||
log.debug(f"Process not found for link {link}")
|
||||
session = _listen_start_proc(link=link,
|
||||
term=term,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
term=term,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
|
||||
# leave significant headroom for metadata and encoding
|
||||
result = session.process_request(data, link.MDU * 4 // 3)
|
||||
result = session.process_request(data, link.MDU * 1 // 2)
|
||||
return result
|
||||
# return ProcessState.default_response()
|
||||
except Exception as e:
|
||||
|
@ -589,7 +595,8 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
|||
timeout += time.time()
|
||||
|
||||
while (timeout is None or time.time() < timeout) and not until():
|
||||
await _check_finished(0.01)
|
||||
if await _check_finished(0.001):
|
||||
raise asyncio.CancelledError()
|
||||
if timeout is not None and time.time() > timeout:
|
||||
return False
|
||||
else:
|
||||
|
@ -638,8 +645,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
|
||||
if _reticulum is None:
|
||||
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
|
||||
rnslogging.RnsHandler.set_global_log_level(__logging.ERROR)
|
||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
|
||||
|
||||
if _identity is None:
|
||||
_prepare_identity(identitypath)
|
||||
|
@ -661,6 +668,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
)
|
||||
|
||||
if _link is None or _link.status == RNS.Link.PENDING:
|
||||
log.debug("No link")
|
||||
_link = RNS.Link(_destination)
|
||||
_link.did_identify = False
|
||||
|
||||
|
@ -668,6 +676,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||
|
||||
log.debug("Have link")
|
||||
if not noid and not _link.did_identify:
|
||||
_link.identify(_identity)
|
||||
_link.did_identify = True
|
||||
|
@ -680,6 +689,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
# TODO: Tune
|
||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||
|
||||
log.debug("Sending request")
|
||||
request_receipt = _link.request(
|
||||
path="data",
|
||||
data=request,
|
||||
|
@ -687,6 +697,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
)
|
||||
timeout += 0.5
|
||||
|
||||
log.debug("Waiting for delivery")
|
||||
await _spin(
|
||||
until=lambda: _link.status == RNS.Link.CLOSED or (
|
||||
request_receipt.status != RNS.RequestReceipt.FAILED and
|
||||
|
@ -732,15 +743,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
except Exception as e:
|
||||
raise RemoteExecutionError(f"Received invalid response") from e
|
||||
|
||||
_tr.raw()
|
||||
if stdout is not None:
|
||||
# log.debug(f"stdout: {stdout}")
|
||||
_tr.raw()
|
||||
log.debug(f"stdout: {stdout}")
|
||||
os.write(sys.stdout.fileno(), stdout)
|
||||
sys.stdout.flush()
|
||||
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}")
|
||||
got_bytes = len(stdout) if stdout is not None else 0
|
||||
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
|
||||
|
||||
if ready_bytes > 0:
|
||||
_new_data.set()
|
||||
|
@ -761,11 +771,6 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||
|
||||
data_buffer = bytearray()
|
||||
|
||||
def sigint_handler():
|
||||
log.debug("KeyboardInterrupt")
|
||||
data_buffer.extend("\x03".encode("utf-8"))
|
||||
_new_data.set()
|
||||
|
||||
def sigwinch_handler():
|
||||
# log.debug("WindowChanged")
|
||||
if _new_data is not None:
|
||||
|
@ -773,7 +778,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||
|
||||
def stdin():
|
||||
data = process.tty_read(sys.stdin.fileno())
|
||||
# log.debug(f"stdin {data}")
|
||||
log.debug(f"stdin {data}")
|
||||
if data is not None:
|
||||
data_buffer.extend(data)
|
||||
_new_data.set()
|
||||
|
@ -782,10 +787,12 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||
|
||||
await _check_finished()
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
|
||||
# leave a lot of overhead
|
||||
mdu = 64
|
||||
rtt = 5
|
||||
first_loop = True
|
||||
while True:
|
||||
while not await _check_finished():
|
||||
try:
|
||||
log.debug("top of client loop")
|
||||
stdin = data_buffer[:mdu]
|
||||
|
@ -806,22 +813,27 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||
|
||||
if first_loop:
|
||||
first_loop = False
|
||||
mdu = _link.MDU * 4 // 3
|
||||
loop.remove_signal_handler(signal.SIGINT)
|
||||
loop.add_signal_handler(signal.SIGINT, sigint_handler)
|
||||
mdu = _link.MDU // 2
|
||||
_new_data.set()
|
||||
|
||||
if _link:
|
||||
rtt = _link.rtt
|
||||
|
||||
if return_code is not None:
|
||||
log.debug(f"received return code {return_code}, exiting")
|
||||
with exception.permit(SystemExit):
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
_link.teardown()
|
||||
|
||||
return return_code
|
||||
except asyncio.CancelledError:
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
_link.teardown()
|
||||
return 0
|
||||
except RemoteExecutionError as e:
|
||||
print(e.msg)
|
||||
return 255
|
||||
|
||||
await process.event_wait(_new_data, 5)
|
||||
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
@ -835,6 +847,11 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
|
|||
return arr, []
|
||||
|
||||
|
||||
def _loop_set_signal(sig, loop):
|
||||
loop.remove_signal_handler(sig)
|
||||
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
|
||||
|
||||
|
||||
async def _rnsh_cli_main():
|
||||
global _tr, _finished, _loop
|
||||
import docopt
|
||||
|
@ -842,16 +859,16 @@ async def _rnsh_cli_main():
|
|||
_loop = asyncio.get_running_loop()
|
||||
rnslogging.set_main_loop(_loop)
|
||||
_finished = asyncio.Event()
|
||||
_loop.remove_signal_handler(signal.SIGINT)
|
||||
_loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None))
|
||||
_loop_set_signal(signal.SIGINT, _loop)
|
||||
_loop_set_signal(signal.SIGTERM, _loop)
|
||||
usage = '''
|
||||
Usage:
|
||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||
[--] <program> [<arg> ...]
|
||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
rnsh -h
|
||||
rnsh --version
|
||||
|
||||
|
@ -869,6 +886,12 @@ Options:
|
|||
-m --mirror Client returns with code of remote process
|
||||
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
||||
-v --verbose Increase verbosity
|
||||
DEFAULT LEVEL
|
||||
CRITICAL
|
||||
Initiator -> ERROR
|
||||
WARNING
|
||||
Listener -> INFO
|
||||
DEBUG
|
||||
-q --quiet Increase quietness
|
||||
--version Show version
|
||||
-h --help Show this help
|
||||
|
@ -930,36 +953,40 @@ Options:
|
|||
)
|
||||
|
||||
if args_destination is not None and args_service_name is not None:
|
||||
try:
|
||||
return_code = await _initiate(
|
||||
configdir=args_config,
|
||||
identitypath=args_identity,
|
||||
verbosity=args_verbose,
|
||||
quietness=args_quiet,
|
||||
noid=args_no_id,
|
||||
destination=args_destination,
|
||||
service_name=args_service_name,
|
||||
timeout=args_timeout,
|
||||
)
|
||||
return return_code if args_mirror else 0
|
||||
finally:
|
||||
_tr.restore()
|
||||
return_code = await _initiate(
|
||||
configdir=args_config,
|
||||
identitypath=args_identity,
|
||||
verbosity=args_verbose,
|
||||
quietness=args_quiet,
|
||||
noid=args_no_id,
|
||||
destination=args_destination,
|
||||
service_name=args_service_name,
|
||||
timeout=args_timeout,
|
||||
)
|
||||
return return_code if args_mirror else 0
|
||||
else:
|
||||
print("")
|
||||
print(args)
|
||||
print("")
|
||||
|
||||
|
||||
def _noop():
|
||||
pass
|
||||
|
||||
# RNS.exit = _noop
|
||||
|
||||
|
||||
def rnsh_cli():
|
||||
global _tr, _retry_timer
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
|
||||
with retry.RetryThread() as _retry_timer:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
|
||||
with exception.permit(SystemExit):
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
sys.exit(return_code or 255)
|
||||
# RNS.Reticulum.exit_handler()
|
||||
# time.sleep(0.5)
|
||||
sys.exit(return_code if return_code is not None else 255)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue