Performance and stability improvements.

This commit is contained in:
Aaron Heise 2023-02-12 22:29:16 -06:00
parent 7d711cde6d
commit 5e755acad4
6 changed files with 303 additions and 113 deletions

View file

@ -34,6 +34,7 @@ import sys
import termios
import threading
import time
import tty
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
@ -58,20 +59,20 @@ _allow_all = False
_allowed_identity_hashes = []
_cmd: [str] = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event | None = None
_finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0):
await process.event_wait(_finished, timeout=timeout)
return await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, frame):
global _finished
log = _get_logger("_sigint_handler")
log.debug("SIGINT")
log.debug(signal.Signals(sig).name)
if _finished is not None:
_finished.set()
else:
@ -107,8 +108,6 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
exit(0)
# hack_goes_here
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
allowed=None, disable_auth=None, announce_period=900):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
@ -116,8 +115,8 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_cmd = command
targetloglevel = RNS.LOG_INFO + verbosity - quietness
rnslogging.RnsHandler.set_global_log_level(__logging.INFO)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
_prepare_identity(identitypath)
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
@ -151,6 +150,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes
)
@ -158,10 +158,12 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_ALL,
)
await _check_finished()
if await _check_finished():
return
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
@ -171,23 +173,23 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
last = time.time()
try:
while True:
while not await _check_finished(1.0):
if announce_period and 0 < announce_period < time.time() - last:
last = time.time()
_destination.announce()
await _check_finished(1.0)
except KeyboardInterrupt:
finally:
log.warning("Shutting down")
for link in list(_destination.links):
with exception.permit(SystemExit):
with exception.permit(SystemExit, KeyboardInterrupt):
proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running:
proc.process.terminate()
await asyncio.sleep(1)
await asyncio.sleep(0)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status != RNS.Link.CLOSED:
link.teardown()
await asyncio.sleep(0)
_PROTOCOL_MAGIC = 0xdeadbeef
@ -334,21 +336,22 @@ class Session:
@staticmethod
def default_request(stdin_fd: int | None) -> [any]:
global _tr
global _PROTOCOL_VERSION_0
request: list[any] = [
_PROTOCOL_VERSION_0, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
].copy()
if stdin_fd is not None:
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
@ -372,6 +375,7 @@ class Session:
if term_state != self._term_state:
self._term_state = term_state
self._update_winsz()
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0])
if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
@ -397,11 +401,11 @@ class Session:
global _PROTOCOL_VERSION_0
response: list[any] = [
_PROTOCOL_VERSION_0, # 0: Protocol version
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr
None, # 5: Timestamp
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr
None, # 5: Timestamp
].copy()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response
@ -539,7 +543,8 @@ def _initiator_identified(link, identity):
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer, _loop
log = _get_logger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
log.debug(
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id)
@ -553,7 +558,8 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
if not remote_version == _PROTOCOL_VERSION_0:
response = Session.default_response()
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(
"Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
@ -565,12 +571,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
if session is None:
log.debug(f"Process not found for link {link}")
session = _listen_start_proc(link=link,
term=term,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
term=term,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
# leave significant headroom for metadata and encoding
result = session.process_request(data, link.MDU * 4 // 3)
result = session.process_request(data, link.MDU * 1 // 2)
return result
# return ProcessState.default_response()
except Exception as e:
@ -589,7 +595,8 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
timeout += time.time()
while (timeout is None or time.time() < timeout) and not until():
await _check_finished(0.01)
if await _check_finished(0.001):
raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout:
return False
else:
@ -638,8 +645,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if _reticulum is None:
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
rnslogging.RnsHandler.set_global_log_level(__logging.ERROR)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
if _identity is None:
_prepare_identity(identitypath)
@ -661,6 +668,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
)
if _link is None or _link.status == RNS.Link.PENDING:
log.debug("No link")
_link = RNS.Link(_destination)
_link.did_identify = False
@ -668,6 +676,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
log.debug("Have link")
if not noid and not _link.did_identify:
_link.identify(_identity)
_link.did_identify = True
@ -680,6 +689,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
log.debug("Sending request")
request_receipt = _link.request(
path="data",
data=request,
@ -687,6 +697,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
)
timeout += 0.5
log.debug("Waiting for delivery")
await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and
@ -732,15 +743,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
_tr.raw()
if stdout is not None:
# log.debug(f"stdout: {stdout}")
_tr.raw()
log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout)
sys.stdout.flush()
sys.stdout.flush()
sys.stderr.flush()
log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}")
got_bytes = len(stdout) if stdout is not None else 0
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0:
_new_data.set()
@ -761,11 +771,6 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
data_buffer = bytearray()
def sigint_handler():
log.debug("KeyboardInterrupt")
data_buffer.extend("\x03".encode("utf-8"))
_new_data.set()
def sigwinch_handler():
# log.debug("WindowChanged")
if _new_data is not None:
@ -773,7 +778,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
def stdin():
data = process.tty_read(sys.stdin.fileno())
# log.debug(f"stdin {data}")
log.debug(f"stdin {data}")
if data is not None:
data_buffer.extend(data)
_new_data.set()
@ -782,10 +787,12 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
await _check_finished()
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
# leave a lot of overhead
mdu = 64
rtt = 5
first_loop = True
while True:
while not await _check_finished():
try:
log.debug("top of client loop")
stdin = data_buffer[:mdu]
@ -806,22 +813,27 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if first_loop:
first_loop = False
mdu = _link.MDU * 4 // 3
loop.remove_signal_handler(signal.SIGINT)
loop.add_signal_handler(signal.SIGINT, sigint_handler)
mdu = _link.MDU // 2
_new_data.set()
if _link:
rtt = _link.rtt
if return_code is not None:
log.debug(f"received return code {return_code}, exiting")
with exception.permit(SystemExit):
with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown()
return return_code
except asyncio.CancelledError:
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 0
except RemoteExecutionError as e:
print(e.msg)
return 255
await process.event_wait(_new_data, 5)
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
_T = TypeVar("_T")
@ -835,6 +847,11 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
return arr, []
def _loop_set_signal(sig, loop):
loop.remove_signal_handler(sig)
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
async def _rnsh_cli_main():
global _tr, _finished, _loop
import docopt
@ -842,16 +859,16 @@ async def _rnsh_cli_main():
_loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop)
_finished = asyncio.Event()
_loop.remove_signal_handler(signal.SIGINT)
_loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None))
_loop_set_signal(signal.SIGINT, _loop)
_loop_set_signal(signal.SIGTERM, _loop)
usage = '''
Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h
rnsh --version
@ -869,6 +886,12 @@ Options:
-m --mirror Client returns with code of remote process
-w TIME --timeout TIME Specify client connect and request timeout in seconds
-v --verbose Increase verbosity
DEFAULT LEVEL
CRITICAL
Initiator -> ERROR
WARNING
Listener -> INFO
DEBUG
-q --quiet Increase quietness
--version Show version
-h --help Show this help
@ -930,36 +953,40 @@ Options:
)
if args_destination is not None and args_service_name is not None:
try:
return_code = await _initiate(
configdir=args_config,
identitypath=args_identity,
verbosity=args_verbose,
quietness=args_quiet,
noid=args_no_id,
destination=args_destination,
service_name=args_service_name,
timeout=args_timeout,
)
return return_code if args_mirror else 0
finally:
_tr.restore()
return_code = await _initiate(
configdir=args_config,
identitypath=args_identity,
verbosity=args_verbose,
quietness=args_quiet,
noid=args_no_id,
destination=args_destination,
service_name=args_service_name,
timeout=args_timeout,
)
return return_code if args_mirror else 0
else:
print("")
print(args)
print("")
def _noop():
pass
# RNS.exit = _noop
def rnsh_cli():
global _tr, _retry_timer
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
with retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
sys.exit(return_code or 255)
# RNS.Reticulum.exit_handler()
# time.sleep(0.5)
sys.exit(return_code if return_code is not None else 255)
if __name__ == "__main__":