From 3183923c8c9f70315916d7eb95a89bcf7cc98710 Mon Sep 17 00:00:00 2001 From: Aaron Heise Date: Fri, 10 Feb 2023 03:08:58 -0600 Subject: [PATCH] Lots more stuff, still more debugging to do --- rnsh/__version.py | 1 + rnsh/process.py | 40 +++ rnsh/retry.py | 69 ++-- rnsh/rnsh.py | 816 ++++++++++++++++++++++----------------------- rnsh/rnslogging.py | 31 +- 5 files changed, 515 insertions(+), 442 deletions(-) create mode 100644 rnsh/__version.py diff --git a/rnsh/__version.py b/rnsh/__version.py new file mode 100644 index 0000000..b3c06d4 --- /dev/null +++ b/rnsh/__version.py @@ -0,0 +1 @@ +__version__ = "0.0.1" \ No newline at end of file diff --git a/rnsh/process.py b/rnsh/process.py index 1de6a3f..628f9f6 100644 --- a/rnsh/process.py +++ b/rnsh/process.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import errno import functools import signal @@ -144,6 +145,20 @@ class TtyRestorer: """ termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr) + +async def event_wait(evt: asyncio.Event, timeout: float) -> bool: + """ + Wait for event to be set, or timeout to expire. + :param evt: asyncio.Event to wait on + :param timeout: maximum number of seconds to wait. + :return: True if event was set, False if timeout expired + """ + # suppress TimeoutError because we'll return False in case of timeout + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(evt.wait(), timeout) + return evt.is_set() + + class CallbackSubprocess: # time between checks of child process PROCESS_POLL_TIME: float = 0.1 @@ -309,6 +324,31 @@ class CallbackSubprocess: def return_code(self) -> int | None: return self._return_code +# # from https://gist.github.com/bruce-shi/fd0e3f5e2360c64bc9ce2efb254744f7 +# from collections import defaultdict +# class disable_signals(object): +# def __init__(self, disabled_signals=None): +# self.stashed_signals = defaultdict(list) +# self.disabled_signals = disabled_signals or [] +# +# def __enter__(self): +# for signal in self.disabled_signals: +# self.disconnect(signal) +# +# def __exit__(self, exc_type, exc_val, exc_tb): +# for signal in list(self.stashed_signals): +# self.reconnect(signal) +# +# def disconnect(self, signal): +# self.stashed_signals[signal] = signal.receivers +# signal.receivers = [] +# +# def reconnect(self, signal): +# signal.receivers = self.stashed_signals.get(signal, []) +# del self.stashed_signals[signal] +# signal.sender_receivers_cache.clear() + + async def main(): """ A test driver for the CallbackProcess class. 
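Note (not part of the patch): the event_wait() helper added to process.py above wraps asyncio.wait_for() so callers get a boolean result instead of having to catch TimeoutError themselves. A minimal usage sketch follows, assuming the patched rnsh/process.py is importable as process; the demo coroutine and variable names are invented for illustration.

import asyncio
import process  # rnsh/process.py as modified by this patch

async def _set_soon(evt: asyncio.Event) -> None:
    await asyncio.sleep(0.1)
    evt.set()

async def _demo() -> None:
    evt = asyncio.Event()
    asyncio.create_task(_set_soon(evt))
    # Unlike a bare asyncio.wait_for(evt.wait(), ...), event_wait() suppresses
    # the TimeoutError and reports the outcome as a bool.
    was_set = await process.event_wait(evt, timeout=5.0)                      # True: set within timeout
    timed_out = not await process.event_wait(asyncio.Event(), timeout=0.05)   # True: never set
    print(was_set, timed_out)

asyncio.run(_demo())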
diff --git a/rnsh/retry.py b/rnsh/retry.py index fac8045..759acc0 100644 --- a/rnsh/retry.py +++ b/rnsh/retry.py @@ -2,11 +2,15 @@ import asyncio import threading import time import logging as __logging +from typing import Callable + module_logger = __logging.getLogger(__name__) + class RetryStatus: - def __init__(self, id: any, try_limit: int, wait_delay: float, retry_callback: callable[any, int], timeout_callback: callable[any], tries: int = 1): - self.id = id + def __init__(self, tag: any, try_limit: int, wait_delay: float, retry_callback: Callable[[any, int], any], + timeout_callback: Callable[[any, int], None], tries: int = 1): + self.tag = tag self.try_limit = try_limit self.tries = tries self.wait_delay = wait_delay @@ -25,22 +29,23 @@ class RetryStatus: def timeout(self): self.completed = True - self.timeout_callback(self.id) + self.timeout_callback(self.tag, self.tries) def retry(self): self.tries += 1 - self.retry_callback(self.id, self.tries) + self.retry_callback(self.tag, self.tries) + class RetryThread: def __init__(self, loop_period: float = 0.25): self._log = module_logger.getChild(self.__class__.__name__) self._loop_period = loop_period self._statuses: list[RetryStatus] = [] - self._id_counter = 0 + self._tag_counter = 0 self._lock = threading.RLock() - self._thread = threading.Thread(target=self._thread_run()) self._run = True self._finished: asyncio.Future | None = None + self._thread = threading.Thread(target=self._thread_run) self._thread.start() def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None: @@ -52,6 +57,7 @@ class RetryThread: else: self._finished = loop.create_future() return self._finished + def _thread_run(self): last_run = time.monotonic() while self._run and self._finished is None: @@ -65,52 +71,59 @@ class RetryThread: try: if not retry.completed: if retry.timed_out: - self._log.debug(f"timed out {retry.id} after {retry.try_limit} tries") + self._log.debug(f"timed out {retry.tag} after {retry.try_limit} tries") retry.timeout() prune.append(retry) else: - self._log.debug(f"retrying {retry.id}, try {retry.tries + 1}/{retry.try_limit}") + self._log.debug(f"retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}") retry.retry() except Exception as e: - self._log.error(f"error processing retry id {retry.id}: {e}") + self._log.error(f"error processing retry id {retry.tag}: {e}") prune.append(retry) with self._lock: for retry in prune: - self._log.debug(f"pruned retry {retry.id}, retry count {retry.tries}/{retry.try_limit}") + self._log.debug(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}") self._statuses.remove(retry) if self._finished is not None: self._finished.set_result(None) - def _get_id(self): - self._id_counter += 1 - return self._id_counter + def _get_next_tag(self): + self._tag_counter += 1 + return self._tag_counter - def begin(self, try_limit: int, wait_delay: float, try_callback: callable[[any | None, int], any], timeout_callback: callable[any, int], id: int | None = None) -> any: + def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any], + timeout_callback: Callable[[any, int], None], tag: int | None = None) -> any: self._log.debug(f"running first try") - id = try_callback(id, 1) - self._log.debug(f"first try success, got id {id}") + tag = try_callback(tag, 1) + self._log.debug(f"first try success, got id {tag}") with self._lock: - if id is None: - id = self._get_id() - self._statuses.append(RetryStatus(id=id, + if tag is None: + tag = 
self._get_next_tag() + self._statuses.append(RetryStatus(tag=tag, tries=1, try_limit=try_limit, wait_delay=wait_delay, retry_callback=try_callback, timeout_callback=timeout_callback)) - self._log.debug(f"added retry timer for {id}") - def complete(self, id: any): - assert id is not None + self._log.debug(f"added retry timer for {tag}") + + def complete(self, tag: any): + assert tag is not None + status: RetryStatus | None = None with self._lock: - status = next(filter(lambda l: l.id == id, self._statuses)) - assert status is not None - status.completed = True - self._statuses.remove(status) - self._log.debug(f"completed {id}") + status = next(filter(lambda l: l.tag == tag, self._statuses)) + if status is not None: + status.completed = True + self._statuses.remove(status) + if status is not None: + self._log.debug(f"completed {tag}") + else: + self._log.debug(f"status not found to complete {tag}") + def complete_all(self): with self._lock: for status in self._statuses: status.completed = True - self._log.debug(f"completed {status.id}") + self._log.debug(f"completed {status.tag}") self._statuses.clear() diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index 9a41379..5d35cc9 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 import functools +from typing import Callable + +import termios # MIT License # @@ -32,20 +35,22 @@ import sys import os import datetime import base64 -import RNS.vendor.umsgpack as umsgpack import process import asyncio import threading import signal import retry - +from multiprocessing.pool import ThreadPool +from __version import __version__ import logging as __logging + module_logger = __logging.getLogger(__name__) -def _getLogger(name: str): + + +def _get_logger(name: str): global module_logger return module_logger.getChild(name) -from RNS._version import __version__ APP_NAME = "rnsh" _identity = None @@ -57,55 +62,77 @@ DATA_AVAIL_MSG = "data available" _finished: asyncio.Future | None = None _retry_timer = retry.RetryThread() _destination: RNS.Destination | None = None +_pool: ThreadPool = ThreadPool(10) -def _handle_sigint_with_async_int(): +async def _pump_int(timeout: float = 0): + if timeout == 0: + if _finished.done(): + raise _finished.exception() + return + try: + await asyncio.wait_for(_finished, timeout=timeout) + except asyncio.exceptions.CancelledError: + pass + except TimeoutError: + pass + +def _handle_sigint_with_async_int(signal, frame): if _finished is not None: _finished.set_exception(KeyboardInterrupt()) else: raise KeyboardInterrupt() + signal.signal(signal.SIGINT, _handle_sigint_with_async_int) + def _prepare_identity(identity_path): global _identity - log = _getLogger("_prepare_identity") - if identity_path == None: - identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME + log = _get_logger("_prepare_identity") + if identity_path is None: + identity_path = RNS.Reticulum.identitypath + "/" + APP_NAME if os.path.isfile(identity_path): _identity = RNS.Identity.from_file(identity_path) - if _identity == None: + if _identity is None: log.info("No valid saved identity found, creating new...") _identity = RNS.Identity() _identity.to_file(identity_path) -async def _listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0, - allowed = [], print_identity = False, disable_auth = None, disable_announce=False): + +def _print_identity(configdir, identitypath, service_name, include_destination: bool): + _reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO) + 
_prepare_identity(identitypath) + destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name) + print("Identity : " + str(_identity)) + if include_destination: + print("Listening on : " + RNS.prettyhexrep(destination.hash)) + exit(0) + + +async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0, + allowed=[], disable_auth=None, disable_announce=False): global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination - log = _getLogger("_listen") + log = _get_logger("_listen") _cmd = command - targetloglevel = 3+verbosity-quietness + targetloglevel = 3 + verbosity - quietness _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel) - _prepare_identity(identitypath) _destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name) - if print_identity: - log.info("Identity : " + str(_identity)) - log.info("Listening on : " + RNS.prettyhexrep(_destination.hash)) - exit(0) - if disable_auth: _allow_all = True else: - if allowed != None: + if allowed is not None: for a in allowed: try: - dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 if len(a) != dest_len: - raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2)) + raise ValueError( + "Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format( + hex=dest_len, byte=dest_len // 2)) try: destination_hash = bytes.fromhex(a) _allowed_identity_hashes.append(destination_hash) @@ -122,18 +149,20 @@ async def _listen(configdir, command, identitypath = None, service_name ="defaul if not _allow_all: _destination.register_request_handler( - path = service_name, - response_generator = _listen_request, - allow = RNS.Destination.ALLOW_LIST, - allowed_list = _allowed_identity_hashes + path="data", + response_generator=_listen_request, + allow=RNS.Destination.ALLOW_LIST, + allowed_list=_allowed_identity_hashes ) else: _destination.register_request_handler( - path = service_name, - response_generator = _listen_request, - allow = RNS.Destination.ALLOW_ALL, + path="data", + response_generator=_listen_request, + allow=RNS.Destination.ALLOW_ALL, ) + await _pump_int() + log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash)) if not disable_announce: @@ -146,10 +175,7 @@ async def _listen(configdir, command, identitypath = None, service_name ="defaul if not disable_announce and time.monotonic() - last > 900: # TODO: make parameter last = datetime.datetime.now() _destination.announce() - try: - await asyncio.wait_for(_finished, timeout=1.0) - except TimeoutError: - pass + await _pump_int(1.0) except KeyboardInterrupt: log.warning("Shutting down") for link in list(_destination.links): @@ -164,6 +190,7 @@ async def _listen(configdir, command, identitypath = None, service_name ="defaul if link.status != RNS.Link.CLOSED: link.teardown() + class ProcessState: def __init__(self, cmd: str, @@ -173,7 +200,7 @@ class ProcessState: term: str | None, loop: asyncio.AbstractEventLoop = None): - self._log = _getLogger(self.__class__.__name__) + self._log = _get_logger(self.__class__.__name__) self._mdu = mdu self._loop = loop if loop is not None else asyncio.get_running_loop() self._process = process.CallbackSubprocess(argv=shlex.split(cmd), @@ -187,7 +214,7 @@ class ProcessState: 
self._terminated_cb = terminated_callback self._pending_receipt: RNS.PacketReceipt | None = None self._process.start() - self._term_state: [int] | None = None + self._term_state: [int] = None @property def mdu(self) -> int: @@ -224,12 +251,11 @@ class ProcessState: def read(self, count: int) -> bytes: with self.lock: - take = self._data_buffer[:count-1] + take = self._data_buffer[:count - 1] self._data_buffer = self._data_buffer[count:] return take def _stdout_data(self, data: bytes): - total_available = 0 with self.lock: self._data_buffer.extend(data) total_available = len(self._data_buffer) @@ -238,28 +264,55 @@ class ProcessState: except Exception as e: self._log.error(f"Error calling ProcessState data_available_callback {e}") - def _update_winsz(self): - self.process.set_winsize(self._term_state[3], - self._term_state[4], - self._term_state[5], - self._term_state[6]) + TERMSTATE_IDX_TERM = 0 + TERMSTATE_IDX_TIOS = 1 + TERMSTATE_IDX_ROWS = 2 + TERMSTATE_IDX_COLS = 3 + TERMSTATE_IDX_HPIX = 4 + TERMSTATE_IDX_VPIX = 5 + def _update_winsz(self): + self.process.set_winsize(self._term_state[ProcessState.TERMSTATE_IDX_ROWS], + self._term_state[ProcessState.TERMSTATE_IDX_COLS], + self._term_state[ProcessState.TERMSTATE_IDX_HPIX], + self._term_state[ProcessState.TERMSTATE_IDX_VPIX]) REQUEST_IDX_STDIN = 0 - REQUEST_IDX_TERM = 1 - REQUEST_IDX_TIOS = 2 - REQUEST_IDX_ROWS = 3 - REQUEST_IDX_COLS = 4 - REQUEST_IDX_HPIX = 5 - REQUEST_IDX_VPIX = 6 + REQUEST_IDX_TERM = 1 + REQUEST_IDX_TIOS = 2 + REQUEST_IDX_ROWS = 3 + REQUEST_IDX_COLS = 4 + REQUEST_IDX_HPIX = 5 + REQUEST_IDX_VPIX = 6 + + @staticmethod + def default_request(stdin_fd: int | None) -> [any]: + request = [ + None, # Stdin + None, # TERM variable + None, # termios attributes or something + None, # terminal rows + None, # terminal cols + None, # terminal horizontal pixels + None, # terminal vertical pixels + ] + if stdin_fd is not None: + request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None) + request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd) + request[ProcessState.REQUEST_IDX_ROWS], \ + request[ProcessState.REQUEST_IDX_COLS], \ + request[ProcessState.REQUEST_IDX_HPIX], \ + request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd) + return request + def process_request(self, data: [any], read_size: int) -> [any]: stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin - term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable - tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr - rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows - cols = data[ProcessState.REQUEST_IDX_COLS] # window cols - hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels - vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels + # term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable + # tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr + # rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows + # cols = data[ProcessState.REQUEST_IDX_COLS] # window cols + # hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels + # vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX] response = ProcessState.default_response() @@ -282,35 +335,37 @@ class ProcessState: RESPONSE_IDX_RUNNING = 0 RESPONSE_IDX_RETCODE = 1 RESPONSE_IDX_RDYBYTE = 2 - RESPONSE_IDX_STDOUT = 3 + RESPONSE_IDX_STDOUT = 3 RESPONSE_IDX_TMSTAMP = 4 + @staticmethod def 
default_response() -> [any]: return [ - False, # 0: Process running - None, # 1: Return value - 0, # 2: Number of outstanding bytes - None, # 3: Stdout/Stderr - time.time(), # 4: Timestamp + False, # 0: Process running + None, # 1: Return value + 0, # 2: Number of outstanding bytes + None, # 3: Stdout/Stderr + time.time(), # 4: Timestamp ] def _subproc_data_ready(link: RNS.Link, chars_available: int): global _retry_timer - log = _getLogger("_subproc_data_ready") + log = _get_logger("_subproc_data_ready") process_state: ProcessState = link.process - def send(timeout: bool, id: any, tries: int) -> any: + def send(timeout: bool, tag: any, tries: int) -> any: try: pr = process_state.pending_receipt_take() if pr is not None and pr.get_status() != RNS.PacketReceipt.SENT and pr.get_status() != RNS.PacketReceipt.DELIVERED: if not timeout: - _retry_timer.complete(id) - log.debug(f"Packet {id} completed with status {pr.status} on link {link}") + _retry_timer.complete(tag) + log.debug(f"Notification completed with status {pr.status} on link {link}") return link.link_id if not timeout: - log.info(f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})") + log.info( + f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})") packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8")) packet.send() pr = packet.receipt @@ -330,18 +385,20 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int): wait_delay=link.rtt * 3 if link.rtt is not None else 1, try_callback=functools.partial(send, False), timeout_callback=functools.partial(send, True), - id=None) + tag=None) else: log.debug(f"Notification already pending for link {link}") + def _subproc_terminated(link: RNS.Link, return_code: int): - log = _getLogger("_subproc_terminated") + log = _get_logger("_subproc_terminated") log.info(f"Subprocess terminated ({return_code} for link {link}") link.teardown() + def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None: global _cmd - log = _getLogger("_listen_start_proc") + log = _get_logger("_listen_start_proc") try: link.process = ProcessState(cmd=_cmd, term=term, @@ -352,17 +409,20 @@ def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None: log.error("Failed to launch process: " + str(e)) link.teardown() return None + + def _listen_link_established(link): global _allow_all - log = _getLogger("_listen_link_established") + log = _get_logger("_listen_link_established") link.set_remote_identified_callback(_initiator_identified) link.set_link_closed_callback(_listen_link_closed) - log.info("Link "+str(link)+" established") + log.info("Link " + str(link) + " established") + def _listen_link_closed(link: RNS.Link): - log = _getLogger("_listen_link_closed") + log = _get_logger("_listen_link_closed") # async def cleanup(): - log.info("Link "+str(link)+" closed") + log.info("Link " + str(link) + " closed") proc: ProcessState | None = link.proc if hasattr(link, "process") else None if proc is None: log.warning(f"No process for link {link}") @@ -372,17 +432,19 @@ def _listen_link_closed(link: RNS.Link): log.error(f"Error closing process for link {link}") # asyncio.get_running_loop().call_soon(cleanup) + def _initiator_identified(link, identity): global _allow_all, _cmd - log = _getLogger("_initiator_identified") - log.info("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash)) + log = _get_logger("_initiator_identified") + log.info("Initiator of link " + str(link) + " 
identified as " + RNS.prettyhexrep(identity.hash)) if not _allow_all and not identity.hash in _allowed_identity_hashes: - log.warning("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING) + log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING) link.teardown() + def _listen_request(path, data, request_id, link_id, remote_identity, requested_at): global _destination, _retry_timer - log = _getLogger("_listen_request") + log = _get_logger("_listen_request") log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}") _retry_timer.complete(link_id) link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links)) @@ -391,14 +453,13 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_ return process_state: ProcessState | None = None try: - term = data[1] + term = data[ProcessState.REQUEST_IDX_TERM] process_state = link.process if hasattr(link, "process") else None if process_state is None: log.debug(f"process not found for link {link}") process_state = _listen_start_proc(link, term) - - # leave significant overhead for metadata and encoding + # leave significant headroom for metadata and encoding result = process_state.process_request(data, link.MDU * 3 // 2) except Exception as e: result = ProcessState.default_response() @@ -411,115 +472,70 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_ return result -def spin(until=None, msg=None, timeout=None): - i = 0 - syms = "⢄⢂⢁⡁⡈⡐⡠" - if timeout != None: - timeout = time.time()+timeout - # print(msg+" ", end=" ") - while (timeout == None or time.time() bool: + global _pool + finished = asyncio.get_running_loop().create_future() - # print("\r"+" "*len(msg)+" \r", end="") + def inner_spin(): + while (timeout is None or time.time() < timeout) and not until(): + if _finished.exception(): + finished.set_exception(_finished.exception()) + break + time.sleep(0.005) + if timeout is not None and time.time() > timeout: + finished.set_result(False) + else: + finished.set_result(True) - if timeout != None and time.time() > timeout: - return False + _pool.apply_async(inner_spin) + return await finished + +_link: RNS.Link | None = None +_remote_exec_grace = 2.0 +_new_data: asyncio.Event | None = None +_tr = process.TtyRestorer(sys.stdin.fileno()) + + +def _client_packet_handler(message, packet): + global _new_data + log = _get_logger("_client_packet_handler") + if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG and _new_data is not None: + _new_data.set() else: - return True + log.error(f"received unhandled packet") -current_progress = 0.0 -stats = [] -speed = 0.0 -def spin_stat(until=None, timeout=None): - global current_progress, response_transfer_size, speed - i = 0 - syms = "⢄⢂⢁⡁⡈⡐⡠" - if timeout != None: - timeout = time.time()+timeout - while (timeout == None or time.time() timeout: - return False - else: - return True - -def remote_execution_done(request_receipt): - pass - -def remote_execution_progress(request_receipt): - stats_max = 32 - global current_progress, response_transfer_size, speed - current_progress = request_receipt.progress - response_transfer_size = request_receipt.response_transfer_size - now = time.time() - got = current_progress*response_transfer_size - entry = [now, got] - stats.append(entry) - while len(stats) > stats_max: - stats.pop(0) - - span = now - stats[0][0] - if span == 0: - speed = 0 - else: - diff = 
got - stats[0][1] - speed = diff/span - -link = None -listener_destination = None -remote_exec_grace = 2.0 -new_data = False - -def client_packet_handler(message, packet): - global new_data - if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG: - new_data = True -def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT): - global _identity, _reticulum, link, listener_destination, remote_exec_grace +async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None, + service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT): + global _identity, _reticulum, _link, _destination, _remote_exec_grace + log = _get_logger("_execute") + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 + if len(destination) != dest_len: + raise ValueError( + "Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format( + hex=dest_len, byte=dest_len // 2)) try: - dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 - if len(destination) != dest_len: - raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2)) - try: - destination_hash = bytes.fromhex(destination) - except Exception as e: - raise ValueError("Invalid destination entered. Check your input.") + destination_hash = bytes.fromhex(destination) except Exception as e: - print(str(e)) - return 241 + raise ValueError("Invalid destination entered. Check your input.") - if _reticulum == None: - targetloglevel = 3+verbosity-quietness + if _reticulum is None: + targetloglevel = 2 + verbosity - quietness _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel) - if _identity == None: + if _identity is None: _prepare_identity(identitypath) if not RNS.Transport.has_path(destination_hash): RNS.Transport.request_path(destination_hash) - if not spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Path to "+RNS.prettyhexrep(destination_hash)+" requested", timeout=timeout): - print("Path not found") - return 242 + if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), timeout=timeout): + raise Exception("Path not found") - if listener_destination == None: + if _destination is None: listener_identity = RNS.Identity.recall(destination_hash) - listener_destination = RNS.Destination( + _destination = RNS.Destination( listener_identity, RNS.Destination.OUT, RNS.Destination.SINGLE, @@ -527,285 +543,263 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = service_name ) - if link == None or link.status == RNS.Link.PENDING: - link = RNS.Link(listener_destination) - link.did_identify = False + if _link is None or _link.status == RNS.Link.PENDING: + _link = RNS.Link(_destination) + _link.did_identify = False - if not spin(until=lambda: link.status == RNS.Link.ACTIVE, msg="Establishing link with "+RNS.prettyhexrep(destination_hash), timeout=timeout): - print("Could not establish link with "+RNS.prettyhexrep(destination_hash)) - return 243 + if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout): + raise Exception("Could not establish link with " + RNS.prettyhexrep(destination_hash)) - if not noid and not link.did_identify: - link.identify(_identity) - link.did_identify = True + if not noid and not _link.did_identify: + 
_link.identify(_identity) + _link.did_identify = True - link.set_packet_callback(client_packet_handler) + _link.set_packet_callback(_client_packet_handler) - # if stdin != None: - # stdin = stdin.encode("utf-8") - - request_data = [ - (base64.b64encode(stdin) if stdin is not None else None), # Data passed to stdin - ] + request = ProcessState.default_request(sys.stdin.fileno()) + request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None) # TODO: Tune - rexec_timeout = timeout+link.rtt*4+remote_exec_grace + timeout = timeout + _link.rtt * 4 + _remote_exec_grace - request_receipt = link.request( - path=service_name, - data=request_data, - response_callback=remote_execution_done, - failed_callback=remote_execution_done, - progress_callback=remote_execution_progress, - timeout=rexec_timeout + request_receipt = _link.request( + path="data", + data=request, + timeout=timeout + ) + timeout += 0.5 + + await _spin( + until=lambda: _link.status == RNS.Link.CLOSED or ( + request_receipt.status != RNS.RequestReceipt.FAILED and request_receipt.status != RNS.RequestReceipt.SENT), + timeout=timeout ) - spin( - until=lambda:link.status == RNS.Link.CLOSED or (request_receipt.status != RNS.RequestReceipt.FAILED and request_receipt.status != RNS.RequestReceipt.SENT), - msg="Sending execution request", - timeout=rexec_timeout+0.5 - ) - - if link.status == RNS.Link.CLOSED: - print("Could not request remote execution, link was closed") - return 244 + if _link.status == RNS.Link.CLOSED: + raise Exception("Could not request remote execution, link was closed") if request_receipt.status == RNS.RequestReceipt.FAILED: - print("Could not request remote execution") - return 245 + raise Exception("Could not request remote execution") - spin( - until=lambda:request_receipt.status != RNS.RequestReceipt.DELIVERED, - msg="Command delivered, awaiting result", + await _spin( + until=lambda: request_receipt.status != RNS.RequestReceipt.DELIVERED, timeout=timeout ) if request_receipt.status == RNS.RequestReceipt.FAILED: - print("No result was received") - return 246 - - # spin_stat( - # until=lambda:request_receipt.status != RNS.RequestReceipt.RECEIVING, - # timeout=result_timeout - # ) + raise Exception("No result was received") if request_receipt.status == RNS.RequestReceipt.FAILED: - print("Receiving result failed") - return 247 + raise Exception("Receiving result failed") - if request_receipt.response != None: + if request_receipt.response is not None: try: - running = request_receipt.response[0] - retval = request_receipt.response[1] - stdout = request_receipt.response[2] - stderr = request_receipt.response[3] - timestamp = request_receipt.response[4] - # print("data: " + (stdout.decode("utf-8") if stdout is not None else "")) + running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] + return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE] + ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] + stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT] + timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP] + # log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else "")) except Exception as e: - print("Received invalid result: " + str(e)) - return 248 + raise Exception(f"Received invalid response") from e if stdout is not None: stdout = base64.b64decode(stdout) - # print(f"stdout: {stdout}") - os.write(sys.stdout.buffer.fileno(), stdout) - # print(stdout.decode("utf-8"), end="") - if 
stderr is not None: - stderr = base64.b64decode(stderr) - # print(f"stderr: {stderr}") - os.write(sys.stderr.buffer.fileno(), stderr) - # print(stderr.decode("utf-8"), file=sys.stderr, end="") + # log.debug(f"stdout: {stdout}") + os.write(sys.stdout.fileno(), stdout) - sys.stdout.buffer.flush() sys.stdout.flush() - sys.stderr.buffer.flush() sys.stderr.flush() - if not running and retval is not None: - return retval + if not running and return_code is not None: + return return_code return None -def main(): - global new_data - parser = argparse.ArgumentParser(description="Reticulum Remote Execution Utility") - parser.add_argument("destination", nargs="?", default=None, help="hexadecimal hash of the listener", type=str) - parser.add_argument("-c", "--command", nargs="?", default="/bin/zsh", help="command to be execute", type=str) - parser.add_argument("--config", metavar="path", action="store", default=None, help="path to alternative Reticulum config directory", type=str) - parser.add_argument("-s", "--service-name", action="store", default="default", help="service name for connection") - parser.add_argument('-v', '--verbose', action='count', default=0, help="increase verbosity") - parser.add_argument('-q', '--quiet', action='count', default=0, help="decrease verbosity") - parser.add_argument('-p', '--print-identity', action='store_true', default=False, help="print identity and destination info and exit") - parser.add_argument("-l", '--listen', action='store_true', default=False, help="listen for incoming commands") - parser.add_argument('-i', metavar="identity", action='store', dest="identity", default=None, help="path to identity to use", type=str) - parser.add_argument("-x", '--interactive', action='store_true', default=False, help="enter interactive mode") - parser.add_argument("-b", '--no-announce', action='store_true', default=False, help="don't announce at program start") - parser.add_argument('-a', metavar="allowed_hash", dest="allowed", action='append', help="accept from this identity", type=str) - parser.add_argument('-n', '--noauth', action='store_true', default=False, help="accept commands from anyone") - parser.add_argument('-N', '--noid', action='store_true', default=False, help="don't identify to listener") - parser.add_argument("-d", '--detailed', action='store_true', default=False, help="show detailed result output") - parser.add_argument("-m", action='store_true', dest="mirror", default=False, help="mirror exit code of remote command") - parser.add_argument("-w", action="store", metavar="seconds", type=float, help="connect and request timeout before giving up", default=RNS.Transport.PATH_REQUEST_TIMEOUT) - parser.add_argument("-W", action="store", metavar="seconds", type=float, help="max result download time", default=None) - parser.add_argument("--stdin", action='store', default=None, help="pass input to stdin", type=str) - parser.add_argument("--stdout", action='store', default=None, help="max size in bytes of returned stdout", type=int) - parser.add_argument("--stderr", action='store', default=None, help="max size in bytes of returned stderr", type=int) - parser.add_argument("--version", action="version", version="rnx {version}".format(version=__version__)) - args = parser.parse_args() +async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str, + service_name: str, timeout: float): + global _new_data, _finished, _tr + loop = asyncio.get_event_loop() + _new_data = asyncio.Event() - if args.listen or 
args.print_identity: - RNS.log("command " + args.command) - _listen( - configdir=args.config, - command=args.command, - identitypath=args.identity, - service_name=args.service_name, - verbosity=args.verbose, - quietness=args.quiet, - allowed=args.allowed, - print_identity=args.print_identity, - disable_auth=args.noauth, - disable_announce=args.no_announce, + def stdout(data: bytes): + # log.debug(f"stdout {data}") + os.write(sys.stdout.fileno(), data) + + def terminated(rc: int): + # log.debug(f"terminated {rc}") + return_code.set_result(rc) + + data_buffer = bytearray() + + def sigint_handler(signal, frame): + # log.debug("KeyboardInterrupt") + data_buffer.extend("\x03".encode("utf-8")) + + def sigwinch_handler(signal, frame): + # log.debug("WindowChanged") + if _new_data is not None: + _new_data.set() + + _tr.raw() + def stdin(): + data = process.tty_read(sys.stdin.fileno()) + # log.debug(f"stdin {data}") + if data is not None: + data_buffer.extend(data) + + process.tty_add_reader_callback(sys.stdin.fileno(), stdin) + + await _pump_int() + signal.signal(signal.SIGWINCH, sigwinch_handler) + signal.signal(signal.SIGINT, sigint_handler) + while True: + stdin = data_buffer.copy() + data_buffer.clear() + _new_data.clear() + + return_code = await _execute( + configdir=configdir, + identitypath=identitypath, + verbosity=verbosity, + quietness=quietness, + noid=noid, + destination=destination, + service_name=service_name, + stdin=stdin, + timeout=timeout, + ) + if return_code is not None: + _link.teardown() + return return_code + + await process.event_wait(_new_data, 5) + + +# def _print_help(): +# # retrieve subparsers from parser +# subparsers_actions = [ +# action for action in parser._actions +# if isinstance(action, argparse._SubParsersAction)] +# # there will probably only be one subparser_action, +# # but better safe than sorry +# for subparsers_action in subparsers_actions: +# # get all subparsers and print help +# for choice, subparser in subparsers_action.choices.items(): +# print("Subparser '{}'".format(choice)) +# print(subparser.format_help()) +# return 0 + +import docopt +import json + + +async def main(): + global _tr, _finished + log = _get_logger("main") + _finished = asyncio.get_running_loop().create_future() + usage = ''' +Usage: + rnsh [--config ] [-i ] [-s ] [-l] -p + rnsh -l [--config ] [-i ] [-s ] [-v...] [-q...] [-b] + (-n | -a [-a ]...) [...] + rnsh [--config ] [-i ] [-s ] [-v...] [-q...] 
[-N] [-m] + [-w ] + rnsh -h + rnsh --version + +Options: + --config FILE Alternate Reticulum config file to use + -i FILE --identity FILE Specific identity file to use + -s NAME --service NAME Listen on/connect to specific service name if not default + -p --print-identity Print identity information and exit + -l --listen Listen (server) mode + -b --no-announce Do not announce service + -a HASH --allowed HASH Specify identities allowed to connect + -n --no-auth Disable authentication + -N --no-id Disable identify on connect + -m --mirror Client returns with code of remote process + -w TIME --timeout TIME Specify client connect and request timeout in seconds + -v --verbose Increase verbosity + -q --quiet Increase quietness + --version Show version + -h --help Show this help + ''' + args = docopt.docopt(usage, version=f"rnsh {__version__}") + # json.dump(args, sys.stdout) + + args_service_name = args.get("--service", None) or "default" + args_listen = args.get("--listen", None) or False + args_identity = args.get("--identity", None) + args_config = args.get("--config", None) + args_print_identity = args.get("--print-identity", None) or False + args_verbose = args.get("--verbose", None) or 0 + args_quiet = args.get("--quiet", None) or 0 + args_no_announce = args.get("--no-announce", None) or False + args_no_auth = args.get("--no-auth", None) or False + args_allowed = args.get("--allowed", None) or [] + args_program = args.get("", None) + args_program_args = args.get("", None) or [] + args_no_id = args.get("--no-id", None) or False + args_mirror = args.get("--mirror", None) or False + args_timeout = args.get("--timeout", None) + args_destination = args.get("", None) + args_help = args.get("--help", None) or False + + if args_help: + return 0 + + if args_print_identity: + _print_identity(args_config, args_identity, args_service_name, args_listen) + return 0 + + if args_listen: + # log.info("command " + args.command) + await _listen( + configdir=args_config, + command=[args_program].extend(args_program_args), + identitypath=args_identity, + service_name=args_service_name, + verbosity=args_verbose, + quietness=args_quiet, + allowed=args_allowed, + disable_auth=args_no_auth, + disable_announce=args_no_announce, ) - if args.destination is not None and args.service_name is not None: - # command_history_max = 5000 - # command_history = [] - # command_current = "" - # history_idx = 0 - # tty.setcbreak(sys.stdin.fileno()) - - fr = execute( - configdir=args.config, - identitypath=args.identity, - verbosity=args.verbose, - quietness=args.quiet, - noid=args.noid, - destination=args.destination, - service_name=args.service_name, - stdin=os.environ["TERM"].encode("utf-8"), - timeout=args.w, - ) - - if fr is not None: - print(f"Remote returned result {fr}") - exit(1) - - last = datetime.datetime.now() - #reader = NonBlockingStreamReader(sys.stdin.fileno()) - while True: # reader.is_open() and (link is None or link.status != RNS.Link.CLOSED): - stdin = bytearray() - # try: - # try: - # # while True: - # # got = reader.read() - # # if got is None: - # # break - # # stdin.extend(got.encode("utf-8")) - # - # except: - # pass - # - # except KeyboardInterrupt: - # stdin.extend("\x03".encode("utf-8")) - # except EOFError: - # stdin.extend("\x04".encode("utf-8")) - - if new_data or (datetime.datetime.now() - last).total_seconds() > 5 or link is None or (stdin is not None and len(stdin) > 0): - last = datetime.datetime.now() - new_data = False - result = execute( - configdir=args.config, - identitypath=args.identity, 
- verbosity=args.verbose, - quietness=args.quiet, - noid=args.noid, - destination=args.destination, - service_name=args.service_name, - stdin=stdin, - timeout=args.w, - ) - # print("|", end="") - if result is not None: - break - time.sleep(0.010) - if link is not None: - link.teardown() - + if args_destination is not None and args_service_name is not None: + try: + return_code = await _initiate( + configdir=args_config, + identitypath=args_identity, + verbosity=args_verbose, + quietness=args_quiet, + noid=args_no_id, + destination=args_destination, + service_name=args_service_name, + timeout=args_timeout, + ) + return return_code if args_mirror else 0 + except: + _tr.restore() + raise else: print("") - parser.print_help() + print(args) print("") - # except KeyboardInterrupt: - # pass - # # tty.setnocbreak(sys.stdin.fileno()) - # print("") - # if link != None: - # link.teardown() - # exit() - -def size_str(num, suffix='B'): - units = ['','K','M','G','T','P','E','Z'] - last_unit = 'Y' - - if suffix == 'b': - num *= 8 - units = ['','K','M','G','T','P','E','Z'] - last_unit = 'Y' - - for unit in units: - if abs(num) < 1000.0: - if unit == "": - return "%.0f %s%s" % (num, unit, suffix) - else: - return "%.2f %s%s" % (num, unit, suffix) - num /= 1000.0 - - return "%.2f%s%s" % (num, last_unit, suffix) - -def pretty_time(time, verbose=False): - days = int(time // (24 * 3600)) - time = time % (24 * 3600) - hours = int(time // 3600) - time %= 3600 - minutes = int(time // 60) - time %= 60 - seconds = round(time, 2) - - ss = "" if seconds == 1 else "s" - sm = "" if minutes == 1 else "s" - sh = "" if hours == 1 else "s" - sd = "" if days == 1 else "s" - - components = [] - if days > 0: - components.append(str(days)+" day"+sd if verbose else str(days)+"d") - - if hours > 0: - components.append(str(hours)+" hour"+sh if verbose else str(hours)+"h") - - if minutes > 0: - components.append(str(minutes)+" minute"+sm if verbose else str(minutes)+"m") - - if seconds > 0: - components.append(str(seconds)+" second"+ss if verbose else str(seconds)+"s") - - i = 0 - tstr = "" - for c in components: - i += 1 - if i == 1: - pass - elif i < len(components): - tstr += ", " - elif i == len(components): - tstr += " and " - - tstr += c - - return tstr if __name__ == "__main__": - main() + return_code = 1 + try: + return_code = asyncio.run(main()) + finally: + try: + process.tty_unset_reader_callbacks(sys.stdin.fileno()) + except: + pass + _tr.restore() + _pool.close() + _retry_timer.close() + sys.exit(return_code) diff --git a/rnsh/rnslogging.py b/rnsh/rnslogging.py index 5c017ee..e7a35fb 100644 --- a/rnsh/rnslogging.py +++ b/rnsh/rnslogging.py @@ -2,8 +2,11 @@ import logging from logging import Handler, getLevelName from types import GenericAlias import os - +import tty +import termios +import sys import RNS +import json class RnsHandler(Handler): """ @@ -37,7 +40,12 @@ class RnsHandler(Handler): """ try: msg = self.format(record) + + # tattr = termios.tcgetattr(sys.stdin.fileno()) + # json.dump(tattr, sys.stdout) + # termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr | termios.ONLRET | termios.ONLCR | termios.OPOST) RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno)) + # termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) except RecursionError: # See issue 36272 raise except Exception: @@ -49,7 +57,6 @@ class RnsHandler(Handler): __class_getitem__ = classmethod(GenericAlias) - log_format = '%(name)-40s %(message)s [%(threadName)s]' logging.basicConfig( @@ -57,4 +64,22 @@ logging.basicConfig( 
#format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s', format=log_format, datefmt='%Y-%m-%d %H:%M:%S', - handlers=[RnsHandler()]) \ No newline at end of file + handlers=[RnsHandler()]) + +#hack for temporarily overriding term settings to make debug print right +_rns_log_orig = RNS.log + +def _rns_log(msg, level=3, _override_destination = False): + tattr = termios.tcgetattr(sys.stdin.fileno()) + tattr_orig = tattr.copy() + # tcflag_t c_iflag; /* input modes */ + # tcflag_t c_oflag; /* output modes */ + # tcflag_t c_cflag; /* control modes */ + # tcflag_t c_lflag; /* local modes */ + # cc_t c_cc[NCCS]; /* special characters */ + tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) + _rns_log_orig(msg, level, _override_destination) + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig) + +RNS.log = _rns_log \ No newline at end of file
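Note (not part of the patch): the _rns_log wrapper above saves the terminal attributes, re-enables output post-processing for the duration of the RNS.log call, then restores the original settings, so log lines render correctly while the client holds the tty in raw mode. A sketch of the same save/patch/restore idea as a reusable context manager, using only the termios calls already present in this file; the cooked_output name is illustrative.

import contextlib
import sys
import termios

@contextlib.contextmanager
def cooked_output(fd: int | None = None):
    """Temporarily set OPOST/ONLCR/ONLRET so multi-line output prints normally
    while the terminal is otherwise in raw mode."""
    fd = sys.stdin.fileno() if fd is None else fd
    tattr_orig = termios.tcgetattr(fd)
    tattr = tattr_orig.copy()
    tattr[1] |= termios.OPOST | termios.ONLCR | termios.ONLRET  # index 1 is c_oflag
    termios.tcsetattr(fd, termios.TCSANOW, tattr)
    try:
        yield
    finally:
        termios.tcsetattr(fd, termios.TCSANOW, tattr_orig)

# usage sketch:
# with cooked_output():
#     RNS.log("message spanning\nmultiple lines")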