From 2239f69953fdaa539c519b3c18643297da768377 Mon Sep 17 00:00:00 2001 From: Aaron Heise Date: Fri, 10 Feb 2023 15:29:43 -0600 Subject: [PATCH] Output and termination work --- rnsh/process.py | 47 +++-- rnsh/retry.py | 37 ++-- rnsh/rnsh.py | 443 +++++++++++++++++++++++++++------------------ rnsh/rnslogging.py | 66 ++++--- 4 files changed, 369 insertions(+), 224 deletions(-) diff --git a/rnsh/process.py b/rnsh/process.py index 21c6093..03abea6 100644 --- a/rnsh/process.py +++ b/rnsh/process.py @@ -11,6 +11,7 @@ import os import asyncio import sys import fcntl +import psutil import select import termios import logging as __logging @@ -178,7 +179,7 @@ class CallbackSubprocess: assert terminated_callback is not None, "terminated_callback should not be None" self._log = module_logger.getChild(self.__class__.__name__) - self._log.debug(f"__init__({argv},{term},...") + # self._log.debug(f"__init__({argv},{term},...") self._command = argv self._term = term self._loop = loop @@ -203,12 +204,13 @@ class CallbackSubprocess: pass def kill(): - self._log.debug("kill()") - try: - os.kill(self._pid, signal.SIGHUP) - os.kill(self._pid, signal.SIGKILL) - except: - pass + if process_exists(self._pid): + self._log.debug("kill()") + try: + os.kill(self._pid, signal.SIGHUP) + os.kill(self._pid, signal.SIGKILL) + except: + pass self._loop.call_later(kill_delay, kill) @@ -281,21 +283,35 @@ class CallbackSubprocess: Start the child process. """ self._log.debug("start()") - parentenv = os.environ.copy() - env = {"HOME": parentenv["HOME"], - "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"), - "LANG": parentenv.get("LANG"), - "SHELL": self._command[0]} + # parentenv = os.environ.copy() + # env = {"HOME": parentenv["HOME"], + # "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"), + # "LANG": parentenv.get("LANG"), + # "SHELL": self._command[0]} + + env = os.environ.copy() + if self._term is not None: + env["TERM"] = self._term self._pid, self._child_fd = pty.fork() if self._pid == 0: try: + p = psutil.Process() + for c in p.connections(kind='all'): + if c == sys.stdin.fileno() or c == sys.stdout.fileno() or c == sys.stderr.fileno(): + continue + try: + os.close(c.fd) + except: + pass os.setpgrp() os.execvpe(self._command[0], self._command, env) except Exception as err: - print(f"Child process error {err}") - sys.exit(0) + print(f"Child process error: {err}") + sys.stdout.flush() + # don't let any other modules get in our way. + os._exit(0) def poll(): # self.log.debug("poll") @@ -306,7 +322,8 @@ class CallbackSubprocess: self._terminated_cb(self._return_code) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) except Exception as e: - self._log.debug(f"Error in process poll: {e}") + if not hasattr(e, "errno") or e.errno != errno.ECHILD: + self._log.debug(f"Error in process poll: {e}") self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) def reader(fd: int, callback: callable): diff --git a/rnsh/retry.py b/rnsh/retry.py index 0d0066a..96d7ea4 100644 --- a/rnsh/retry.py +++ b/rnsh/retry.py @@ -10,6 +10,7 @@ module_logger = __logging.getLogger(__name__) class RetryStatus: def __init__(self, tag: any, try_limit: int, wait_delay: float, retry_callback: Callable[[any, int], any], timeout_callback: Callable[[any, int], None], tries: int = 1): + self._log = module_logger.getChild(self.__class__.__name__) self.tag = tag self.try_limit = try_limit self.tries = tries @@ -21,7 +22,9 @@ class RetryStatus: @property def ready(self): - return self.try_time + self.wait_delay < time.time() and not self.completed + ready = time.time() > self.try_time + self.wait_delay + # self._log.debug(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} next_try {self.try_time + self.wait_delay} now {time.time()} exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}") + return ready @property def timed_out(self): @@ -33,11 +36,12 @@ class RetryStatus: def retry(self): self.tries += 1 + self.try_time = time.time() self.retry_callback(self.tag, self.tries) class RetryThread: - def __init__(self, loop_period: float = 0.25): + def __init__(self, loop_period: float = 0.25, name: str = "retry thread"): self._log = module_logger.getChild(self.__class__.__name__) self._loop_period = loop_period self._statuses: list[RetryStatus] = [] @@ -45,7 +49,7 @@ class RetryThread: self._lock = threading.RLock() self._run = True self._finished: asyncio.Future | None = None - self._thread = threading.Thread(target=self._thread_run) + self._thread = threading.Thread(name=name, target=self._thread_run) self._thread.start() def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None: @@ -64,7 +68,7 @@ class RetryThread: ready: list[RetryStatus] = [] prune: list[RetryStatus] = [] with self._lock: - ready.extend(filter(lambda s: s.ready, self._statuses)) + ready.extend(list(filter(lambda s: s.ready, self._statuses))) for retry in ready: try: if not retry.completed: @@ -72,7 +76,7 @@ class RetryThread: self._log.debug(f"timed out {retry.tag} after {retry.try_limit} tries") retry.timeout() prune.append(retry) - else: + elif retry.ready: self._log.debug(f"retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}") retry.retry() except Exception as e: @@ -82,7 +86,10 @@ class RetryThread: with self._lock: for retry in prune: self._log.debug(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}") - self._statuses.remove(retry) + try: + self._statuses.remove(retry) + except: + pass if self._finished is not None: self._finished.set_result(None) @@ -90,14 +97,19 @@ class RetryThread: self._tag_counter += 1 return self._tag_counter + def has_tag(self, tag: any) -> bool: + with self._lock: + return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None + def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any], timeout_callback: Callable[[any, int], None], tag: int | None = None) -> any: self._log.debug(f"running first try") tag = try_callback(tag, 1) - self._log.debug(f"first try success, got id {tag}") + self._log.debug(f"first try got id {tag}") with self._lock: if tag is None: tag = self._get_next_tag() + self.complete(tag) self._statuses.append(RetryStatus(tag=tag, tries=1, try_limit=try_limit, @@ -108,16 +120,15 @@ class RetryThread: def complete(self, tag: any): assert tag is not None - status: RetryStatus | None = None with self._lock: - status = next(filter(lambda l: l.tag == tag, self._statuses)) + status = next(filter(lambda l: l.tag == tag, self._statuses), None) if status is not None: status.completed = True self._statuses.remove(status) - if status is not None: - self._log.debug(f"completed {tag}") - else: - self._log.debug(f"status not found to complete {tag}") + self._log.debug(f"completed {tag}") + return + + self._log.debug(f"status not found to complete {tag}") def complete_all(self): with self._lock: diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index 4331c1e..11b86b6 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -1,8 +1,4 @@ #!/usr/bin/env python3 -import functools -from typing import Callable - -import termios # MIT License # @@ -26,9 +22,12 @@ import termios # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from __future__ import annotations +import functools +from typing import Callable, TypeVar +import termios import rnslogging import RNS -import argparse import shlex import time import sys @@ -57,31 +56,29 @@ _identity = None _reticulum = None _allow_all = False _allowed_identity_hashes = [] -_cmd: str | None = None +_cmd: [str] = None DATA_AVAIL_MSG = "data available" -_finished: asyncio.Future | None = None +_finished: asyncio.Event | None = None _retry_timer = retry.RetryThread() _destination: RNS.Destination | None = None _pool: ThreadPool = ThreadPool(10) +_loop: asyncio.AbstractEventLoop | None = None -async def _pump_int(timeout: float = 0): - try: - await asyncio.wait_for(_finished, timeout=timeout) - except asyncio.exceptions.CancelledError: - pass - except TimeoutError: - pass +async def _check_finished(timeout: float = 0): + await process.event_wait(_finished, timeout=timeout) -def _handle_sigint_with_async_int(signal, frame): +def _sigint_handler(signal, frame): global _finished + log = _get_logger("_sigint_handler") + log.debug("SIGINT") if _finished is not None: - _finished.get_loop().call_soon_threadsafe(_finished.set_exception, KeyboardInterrupt()) + _finished.set() else: raise KeyboardInterrupt() -signal.signal(signal.SIGINT, _handle_sigint_with_async_int) +signal.signal(signal.SIGINT, _sigint_handler) def _prepare_identity(identity_path): @@ -98,7 +95,6 @@ def _prepare_identity(identity_path): _identity = RNS.Identity() _identity.to_file(identity_path) - def _print_identity(configdir, identitypath, service_name, include_destination: bool): _reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO) _prepare_identity(identitypath) @@ -110,7 +106,7 @@ def _print_identity(configdir, identitypath, service_name, include_destination: async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0, - allowed=[], disable_auth=None, disable_announce=False): + allowed=None, disable_auth=None, disable_announce=False): global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination log = _get_logger("_listen") _cmd = command @@ -159,7 +155,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default", allow=RNS.Destination.ALLOW_ALL, ) - await _pump_int() + await _check_finished() log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash)) @@ -173,13 +169,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default", if not disable_announce and time.time() - last > 900: # TODO: make parameter last = datetime.datetime.now() _destination.announce() - await _pump_int(1.0) + await _check_finished(1.0) except KeyboardInterrupt: log.warning("Shutting down") for link in list(_destination.links): try: - if link.process is not None and link.process.process.running: - link.process.process.terminate() + proc = ProcessState.get_for_tag(link.link_id) + if proc is not None and proc.process.running: + proc.process.terminate() except: pass await asyncio.sleep(1) @@ -190,8 +187,34 @@ async def _listen(configdir, command, identitypath=None, service_name="default", class ProcessState: + _processes: [(any, ProcessState)] = [] + _lock = threading.RLock() + + @classmethod + def get_for_tag(cls, tag: any) -> ProcessState | None: + with cls._lock: + return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None) + + @classmethod + def put_for_tag(cls, tag: any, ps: ProcessState): + with cls._lock: + cls.clear_tag(tag) + cls._processes.append((tag, ps)) + + + @classmethod + def clear_tag(cls, tag: any): + with cls._lock: + try: + cls._processes.remove(tag) + except: + pass + + + def __init__(self, - cmd: str, + tag: any, + cmd: [str], mdu: int, data_available_callback: callable, terminated_callback: callable, @@ -201,9 +224,9 @@ class ProcessState: self._log = _get_logger(self.__class__.__name__) self._mdu = mdu self._loop = loop if loop is not None else asyncio.get_running_loop() - self._process = process.CallbackSubprocess(argv=shlex.split(cmd), + self._process = process.CallbackSubprocess(argv=cmd, term=term, - loop=asyncio.get_running_loop(), + loop=loop, stdout_callback=self._stdout_data, terminated_callback=terminated_callback) self._data_buffer = bytearray() @@ -213,6 +236,7 @@ class ProcessState: self._pending_receipt: RNS.PacketReceipt | None = None self._process.start() self._term_state: [int] = None + ProcessState.put_for_tag(tag, self) @property def mdu(self) -> int: @@ -249,8 +273,10 @@ class ProcessState: def read(self, count: int) -> bytes: with self.lock: - take = self._data_buffer[:count - 1] - self._data_buffer = self._data_buffer[count:] + initial_len = len(self._data_buffer) + take = self._data_buffer[:count] + self._data_buffer = self._data_buffer[count:].copy() + self._log.debug(f"read {len(take)} bytes of {initial_len}, {len(self._data_buffer)} remaining") return take def _stdout_data(self, data: bytes): @@ -270,10 +296,14 @@ class ProcessState: TERMSTATE_IDX_VPIX = 5 def _update_winsz(self): - self.process.set_winsize(self._term_state[ProcessState.TERMSTATE_IDX_ROWS], - self._term_state[ProcessState.TERMSTATE_IDX_COLS], - self._term_state[ProcessState.TERMSTATE_IDX_HPIX], - self._term_state[ProcessState.TERMSTATE_IDX_VPIX]) + try: + self.process.set_winsize(self._term_state[ProcessState.TERMSTATE_IDX_ROWS], + self._term_state[ProcessState.TERMSTATE_IDX_COLS], + self._term_state[ProcessState.TERMSTATE_IDX_HPIX], + self._term_state[ProcessState.TERMSTATE_IDX_VPIX]) + except Exception as e: + self._log.debug(f"failed to update winsz: {e}") + REQUEST_IDX_STDIN = 0 REQUEST_IDX_TERM = 1 @@ -285,15 +315,16 @@ class ProcessState: @staticmethod def default_request(stdin_fd: int | None) -> [any]: - request = [ - None, # Stdin - None, # TERM variable - None, # termios attributes or something - None, # terminal rows - None, # terminal cols - None, # terminal horizontal pixels - None, # terminal vertical pixels - ] + request: list[any] = [ + None, # 0 Stdin + None, # 1 TERM variable + None, # 2 termios attributes or something + None, # 3 terminal rows + None, # 4 terminal cols + None, # 5 terminal horizontal pixels + None, # 6 terminal vertical pixels + ].copy() + if stdin_fd is not None: request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None) request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd) @@ -311,10 +342,11 @@ class ProcessState: # cols = data[ProcessState.REQUEST_IDX_COLS] # window cols # hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels # vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels - term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX] + # term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1] response = ProcessState.default_response() + term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX+1] - response[ProcessState.RESPONSE_IDX_RUNNING] = not self.process.running + response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running if self.process.running: if term_state != self._term_state: self._term_state = term_state @@ -323,89 +355,122 @@ class ProcessState: stdin = base64.b64decode(stdin) self.process.write(stdin) response[ProcessState.RESPONSE_IDX_RETCODE] = self.return_code - stdout = self.read(read_size) + with self.lock: + stdout = self.read(read_size) response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer) - response[ProcessState.RESPONSE_IDX_STDOUT] = \ - base64.b64encode(stdout).decode("utf-8") if stdout is not None and len(stdout) > 0 else None + + if stdout is not None and len(stdout) > 0: + response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8") return response RESPONSE_IDX_RUNNING = 0 RESPONSE_IDX_RETCODE = 1 RESPONSE_IDX_RDYBYTE = 2 - RESPONSE_IDX_STDOUT = 3 + RESPONSE_IDX_STDOUT = 3 RESPONSE_IDX_TMSTAMP = 4 @staticmethod def default_response() -> [any]: - return [ + response: list[any] = [ False, # 0: Process running None, # 1: Return value 0, # 2: Number of outstanding bytes None, # 3: Stdout/Stderr - time.time(), # 4: Timestamp - ] + None, # 4: Timestamp + ].copy() + response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time() + return response def _subproc_data_ready(link: RNS.Link, chars_available: int): global _retry_timer log = _get_logger("_subproc_data_ready") - process_state: ProcessState = link.process + process_state: ProcessState = ProcessState.get_for_tag(link.link_id) def send(timeout: bool, tag: any, tries: int) -> any: - try: - pr = process_state.pending_receipt_take() - if pr is not None and pr.get_status() != RNS.PacketReceipt.SENT and pr.get_status() != RNS.PacketReceipt.DELIVERED: - if not timeout: - _retry_timer.complete(tag) - log.debug(f"Notification completed with status {pr.status} on link {link}") - return link.link_id + # log.debug("send") + def inner(): + # log.debug("inner") + try: + if link.status != RNS.Link.ACTIVE: + _retry_timer.complete(link.link_id) + process_state.pending_receipt_take() + return - if not timeout: - log.info( - f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})") - packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8")) - packet.send() - pr = packet.receipt - process_state.pending_receipt_put(pr) - return link.link_id - else: - log.error(f"Retry count exceeded, terminating link {link}") - _retry_timer.complete(link.link_id) - link.teardown() - except Exception as e: - log.error("Error notifying client: " + str(e)) + pr = process_state.pending_receipt_take() + log.debug(f"send inner pr: {pr}") + if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED: + if not timeout: + _retry_timer.complete(tag) + log.debug(f"Notification completed with status {pr.status} on link {link}") + return + else: + if not timeout: + log.info( + f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})") + packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8")) + packet.send() + pr = packet.receipt + process_state.pending_receipt_put(pr) + else: + log.error(f"Retry count exceeded, terminating link {link}") + _retry_timer.complete(link.link_id) + link.teardown() + except Exception as e: + log.error("Error notifying client: " + str(e)) + + _loop.call_soon_threadsafe(inner) return link.link_id with process_state.lock: - if process_state.pending_receipt_peek() is None: + if not _retry_timer.has_tag(link.link_id): _retry_timer.begin(try_limit=15, - wait_delay=link.rtt * 3 if link.rtt is not None else 1, + wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1), try_callback=functools.partial(send, False), timeout_callback=functools.partial(send, True), tag=None) else: log.debug(f"Notification already pending for link {link}") - def _subproc_terminated(link: RNS.Link, return_code: int): + global _loop log = _get_logger("_subproc_terminated") - log.info(f"Subprocess terminated ({return_code} for link {link}") - link.teardown() + log.info(f"Subprocess returned {return_code} for link {link}") + proc = ProcessState.get_for_tag(link.link_id) + if proc is None: + log.debug(f"no proc for link {link}") + return + + def cleanup(): + def inner(): + log.debug(f"cleanup culled link {link}") + if link and link.status != RNS.Link.CLOSED: + try: + link.teardown() + except: + pass + finally: + ProcessState.clear_tag(link.link_id) + _loop.call_later(300, inner) + _loop.call_soon(_subproc_data_ready, link, 0) + _loop.call_soon_threadsafe(cleanup) -def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None: +def _listen_start_proc(link: RNS.Link, term: str, loop: asyncio.AbstractEventLoop) -> ProcessState | None: global _cmd log = _get_logger("_listen_start_proc") try: - link.process = ProcessState(cmd=_cmd, - term=term, - data_available_callback=functools.partial(_subproc_data_ready, link), - terminated_callback=functools.partial(_subproc_terminated, link)) - return link.process + return ProcessState(tag=link.link_id, + cmd=_cmd, + term=term, + mdu=link.MDU, + loop=loop, + data_available_callback=functools.partial(_subproc_data_ready, link), + terminated_callback=functools.partial(_subproc_terminated, link)) except Exception as e: log.error("Failed to launch process: " + str(e)) - link.teardown() + _subproc_terminated(link, 255) return None @@ -421,18 +486,20 @@ def _listen_link_closed(link: RNS.Link): log = _get_logger("_listen_link_closed") # async def cleanup(): log.info("Link " + str(link) + " closed") - proc: ProcessState | None = link.proc if hasattr(link, "process") else None + proc: ProcessState | None = ProcessState.get_for_tag(link.link_id) if proc is None: log.warning(f"No process for link {link}") - try: - proc.process.terminate() - except: - log.error(f"Error closing process for link {link}") - # asyncio.get_running_loop().call_soon(cleanup) + else: + try: + proc.process.terminate() + _retry_timer.complete(link.link_id) + except Exception as e: + log.error(f"Error closing process for link {link}: {e}") + ProcessState.clear_tag(link.link_id) def _initiator_identified(link, identity): - global _allow_all, _cmd + global _allow_all, _cmd, _loop log = _get_logger("_initiator_identified") log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash)) if not _allow_all and not identity.hash in _allowed_identity_hashes: @@ -441,34 +508,34 @@ def _initiator_identified(link, identity): def _listen_request(path, data, request_id, link_id, remote_identity, requested_at): - global _destination, _retry_timer + global _destination, _retry_timer, _loop log = _get_logger("_listen_request") log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}") _retry_timer.complete(link_id) - link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links)) + link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None) if link is None: - log.error(f"invalid request {request_id}, no link found with id {link_id}") - return + raise Exception(f"Invalid request {request_id}, no link found with id {link_id}") process_state: ProcessState | None = None try: term = data[ProcessState.REQUEST_IDX_TERM] - process_state = link.process if hasattr(link, "process") else None + process_state = ProcessState.get_for_tag(link.link_id) if process_state is None: - log.debug(f"process not found for link {link}") - process_state = _listen_start_proc(link, term) + log.debug(f"Process not found for link {link}") + process_state = _listen_start_proc(link, term, _loop) # leave significant headroom for metadata and encoding result = process_state.process_request(data, link.MDU * 3 // 2) + return result + # return ProcessState.default_response() except Exception as e: - result = ProcessState.default_response() + log.error(f"Error procesing request for link {link}: {e}") try: - if process_state is not None: + if process_state is not None and process_state.process.running: process_state.process.terminate() - link.teardown() - except Exception as e: - log.error(f"Error terminating process for link {link}") + except Exception as ee: + log.debug(f"Error terminating process for link {link}: {ee}") - return result + return ProcessState.default_response() async def _spin(until: Callable | None = None, timeout: float | None = None) -> bool: @@ -477,12 +544,13 @@ async def _spin(until: Callable | None = None, timeout: float | None = None) -> timeout += time.time() while (timeout is None or time.time() < timeout) and not until(): - await _pump_int(0.01) + await _check_finished(0.01) if timeout is not None and time.time() > timeout: return False else: return True + _link: RNS.Link | None = None _remote_exec_grace = 2.0 _new_data: asyncio.Event | None = None @@ -493,25 +561,35 @@ def _client_packet_handler(message, packet): global _new_data log = _get_logger("_client_packet_handler") if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG and _new_data is not None: + log.debug("data available") _new_data.set() else: log.error(f"received unhandled packet") +class RemoteExecutionError(Exception): + def __init__(self, msg): + self.msg = msg + + +def _response_handler(request_receipt: RNS.RequestReceipt): + pass + + async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None, service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT): - global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr + global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data log = _get_logger("_execute") dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 if len(destination) != dest_len: - raise ValueError( + raise RemoteExecutionError( "Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format( hex=dest_len, byte=dest_len // 2)) try: destination_hash = bytes.fromhex(destination) except Exception as e: - raise ValueError("Invalid destination entered. Check your input.") + raise RemoteExecutionError("Invalid destination entered. Check your input.") if _reticulum is None: targetloglevel = 2 + verbosity - quietness @@ -524,7 +602,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= RNS.Transport.request_path(destination_hash) log.info(f"Requesting path...") if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), timeout=timeout): - raise Exception("Path not found") + raise RemoteExecutionError("Path not found") if _destination is None: listener_identity = RNS.Identity.recall(destination_hash) @@ -542,7 +620,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= log.info(f"Establishing link...") if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout): - raise Exception("Could not establish link with " + RNS.prettyhexrep(destination_hash)) + raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash)) if not noid and not _link.did_identify: _link.identify(_identity) @@ -570,10 +648,10 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= ) if _link.status == RNS.Link.CLOSED: - raise Exception("Could not request remote execution, link was closed") + raise RemoteExecutionError("Could not request remote execution, link was closed") if request_receipt.status == RNS.RequestReceipt.FAILED: - raise Exception("Could not request remote execution") + raise RemoteExecutionError("Could not request remote execution") await _spin( until=lambda: request_receipt.status != RNS.RequestReceipt.DELIVERED, @@ -581,21 +659,21 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= ) if request_receipt.status == RNS.RequestReceipt.FAILED: - raise Exception("No result was received") + raise RemoteExecutionError("No result was received") if request_receipt.status == RNS.RequestReceipt.FAILED: - raise Exception("Receiving result failed") + raise RemoteExecutionError("Receiving result failed") if request_receipt.response is not None: try: - running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] + running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE] - ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] + ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0 stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT] timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP] # log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else "")) except Exception as e: - raise Exception(f"Received invalid response") from e + raise RemoteExecutionError(f"Received invalid response") from e _tr.raw() if stdout is not None: @@ -606,8 +684,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= sys.stdout.flush() sys.stderr.flush() - if not running and return_code is not None: - return return_code + log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}") + + if ready_bytes > 0: + _new_data.set() + + if (not running or return_code is not None) and (ready_bytes == 0): + log.debug(f"returning running: {running}, return_code: {return_code}") + return return_code or 255 return None @@ -616,24 +700,16 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness service_name: str, timeout: float): global _new_data, _finished, _tr log = _get_logger("_initiate") - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() _new_data = asyncio.Event() - def stdout(data: bytes): - # log.debug(f"stdout {data}") - os.write(sys.stdout.fileno(), data) - - def terminated(rc: int): - # log.debug(f"terminated {rc}") - return_code.set_result(rc) - data_buffer = bytearray() - def sigint_handler(signal, frame): + def sigint_handler(): log.debug("KeyboardInterrupt") data_buffer.extend("\x03".encode("utf-8")) - def sigwinch_handler(signal, frame): + def sigwinch_handler(): # log.debug("WindowChanged") if _new_data is not None: _new_data.set() @@ -646,66 +722,79 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness process.tty_add_reader_callback(sys.stdin.fileno(), stdin) - await _pump_int() - signal.signal(signal.SIGWINCH, sigwinch_handler) + await _check_finished() + # signal.signal(signal.SIGWINCH, sigwinch_handler) + loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler) + first_loop = True while True: - stdin = data_buffer.copy() - data_buffer.clear() - _new_data.clear() + try: + log.debug("top of client loop") + stdin = data_buffer.copy() + data_buffer.clear() + _new_data.clear() + log.debug("before _execute") + return_code = await _execute( + configdir=configdir, + identitypath=identitypath, + verbosity=verbosity, + quietness=quietness, + noid=noid, + destination=destination, + service_name=service_name, + stdin=stdin, + timeout=timeout, + ) + # signal.signal(signal.SIGINT, sigint_handler) + if first_loop: + first_loop = False + loop.remove_signal_handler(signal.SIGINT) + loop.add_signal_handler(signal.SIGINT, sigint_handler) + _new_data.set() - return_code = await _execute( - configdir=configdir, - identitypath=identitypath, - verbosity=verbosity, - quietness=quietness, - noid=noid, - destination=destination, - service_name=service_name, - stdin=stdin, - timeout=timeout, - ) - signal.signal(signal.SIGINT, sigint_handler) - if return_code is not None: - _link.teardown() - return return_code + if return_code is not None: + log.debug(f"received return code {return_code}, exiting") + try: + _link.teardown() + except: + pass + return return_code + except RemoteExecutionError as e: + print(e.msg) + return 255 await process.event_wait(_new_data, 5) -# def _print_help(): -# # retrieve subparsers from parser -# subparsers_actions = [ -# action for action in parser._actions -# if isinstance(action, argparse._SubParsersAction)] -# # there will probably only be one subparser_action, -# # but better safe than sorry -# for subparsers_action in subparsers_actions: -# # get all subparsers and print help -# for choice, subparser in subparsers_action.choices.items(): -# print("Subparser '{}'".format(choice)) -# print(subparser.format_help()) -# return 0 - -import docopt -import json +_T = TypeVar("_T") +def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]): + try: + idx = arr.index(at) + return arr[:idx], arr[idx+1:] + except ValueError: + return arr, [] async def main(): - global _tr, _finished + global _tr, _finished, _loop + import docopt log = _get_logger("main") - _finished = asyncio.get_running_loop().create_future() + _loop = asyncio.get_running_loop() + rnslogging.set_main_loop(_loop) + _finished = asyncio.Event() + _loop.remove_signal_handler(signal.SIGINT) + _loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None)) usage = ''' Usage: - rnsh [--config ] [-i ] [-s ] [-l] -p + rnsh [--config ] [-i ] [-s ] [-l] -p rnsh -l [--config ] [-i ] [-s ] [-v...] [-q...] [-b] - (-n | -a [-a ]...) [...] + (-n | -a [-a ]...) [--] [...] rnsh [--config ] [-i ] [-s ] [-v...] [-q...] [-N] [-m] [-w ] rnsh -h rnsh --version Options: - --config FILE Alternate Reticulum config file to use + --config FILE Alternate Reticulum config directory to use -i FILE --identity FILE Specific identity file to use -s NAME --service NAME Listen on/connect to specific service name if not default -p --print-identity Print identity information and exit @@ -721,7 +810,13 @@ Options: --version Show version -h --help Show this help ''' - args = docopt.docopt(usage, version=f"rnsh {__version__}") + + argv, program_args = _split_array_at(sys.argv, "--") + if len(program_args) > 0: + argv.append(program_args[0]) + program_args = program_args[1:] + + args = docopt.docopt(usage, argv=argv[1:], version=f"rnsh {__version__}") # json.dump(args, sys.stdout) args_service_name = args.get("--service", None) or "default" @@ -736,6 +831,8 @@ Options: args_allowed = args.get("--allowed", None) or [] args_program = args.get("", None) args_program_args = args.get("", None) or [] + args_program_args.insert(0, args_program) + args_program_args.extend(program_args) args_no_id = args.get("--no-id", None) or False args_mirror = args.get("--mirror", None) or False args_timeout = args.get("--timeout", None) or RNS.Transport.PATH_REQUEST_TIMEOUT @@ -753,7 +850,7 @@ Options: # log.info("command " + args.command) await _listen( configdir=args_config, - command=[args_program].extend(args_program_args), + command=args_program_args, identitypath=args_identity, service_name=args_service_name, verbosity=args_verbose, diff --git a/rnsh/rnslogging.py b/rnsh/rnslogging.py index c6a52a0..40b12d1 100644 --- a/rnsh/rnslogging.py +++ b/rnsh/rnslogging.py @@ -3,6 +3,8 @@ from logging import Handler, getLevelName from types import GenericAlias import os import tty +from typing import List, Any +import asyncio import termios import sys import RNS @@ -41,11 +43,7 @@ class RnsHandler(Handler): try: msg = self.format(record) - # tattr = termios.tcgetattr(sys.stdin.fileno()) - # json.dump(tattr, sys.stdout) - # termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr | termios.ONLRET | termios.ONLCR | termios.OPOST) RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno)) - # termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) except RecursionError: # See issue 36272 raise except Exception: @@ -57,29 +55,51 @@ class RnsHandler(Handler): __class_getitem__ = classmethod(GenericAlias) -log_format = '%(name)-40s %(message)s [%(threadName)s]' +log_format = '%(name)-30s %(message)s [%(threadName)s]' logging.basicConfig( - level=logging.INFO, + level=logging.DEBUG, # RNS.log will filter it, but some formatting will still be processed before it gets there #format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s', format=log_format, datefmt='%Y-%m-%d %H:%M:%S', handlers=[RnsHandler()]) -# #hack for temporarily overriding term settings to make debug print right -# _rns_log_orig = RNS.log -# -# def _rns_log(msg, level=3, _override_destination = False): -# tattr = termios.tcgetattr(sys.stdin.fileno()) -# tattr_orig = tattr.copy() -# # tcflag_t c_iflag; /* input modes */ -# # tcflag_t c_oflag; /* output modes */ -# # tcflag_t c_cflag; /* control modes */ -# # tcflag_t c_lflag; /* local modes */ -# # cc_t c_cc[NCCS]; /* special characters */ -# tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST -# termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) -# _rns_log_orig(msg, level, _override_destination) -# termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig) -# -# RNS.log = _rns_log \ No newline at end of file +_loop: asyncio.AbstractEventLoop | None = None +def set_main_loop(loop: asyncio.AbstractEventLoop): + global _loop + _loop = loop + +#hack for temporarily overriding term settings to make debug print right +_rns_log_orig = RNS.log + +def _rns_log(msg, level=3, _override_destination = False): + if not RNS.compact_log_fmt: + msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg + def inner(): + tattr_orig: list[Any] | None = None + try: + tattr = termios.tcgetattr(sys.stdin.fileno()) + tattr_orig = tattr.copy() + # tcflag_t c_iflag; /* input modes */ + # tcflag_t c_oflag; /* output modes */ + # tcflag_t c_cflag; /* control modes */ + # tcflag_t c_lflag; /* local modes */ + # cc_t c_cc[NCCS]; /* special characters */ + tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) + except: + pass + + _rns_log_orig(msg, level, _override_destination) + + if tattr_orig is not None: + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig) + try: + if _loop: + _loop.call_soon_threadsafe(inner) + else: + inner() + except: + inner() + +RNS.log = _rns_log \ No newline at end of file