diff --git a/pyproject.toml b/pyproject.toml index ebc86b7..a64d610 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,9 +8,9 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.9" -rns = {git = "https://github.com/markqvist/Reticulum.git", rev = "3706769"} docopt = "^0.6.2" psutil = "^5.9.4" +rns = "^0.4.8" [tool.poetry.scripts] rnsh = 'rnsh.rnsh:rnsh_cli' diff --git a/rnsh/hacks.py b/rnsh/hacks.py new file mode 100644 index 0000000..20aeced --- /dev/null +++ b/rnsh/hacks.py @@ -0,0 +1,67 @@ +import asyncio +import threading +import RNS +import logging +module_logger = logging.getLogger(__name__) + + +class SelfPruningAssociation: + def __init__(self, ttl: float, loop: asyncio.AbstractEventLoop): + self._log = module_logger.getChild(self.__class__.__name__) + self._ttl = ttl + self._associations = list() + self._lock = threading.RLock() + self._loop = loop + # self._log.debug("__init__") + + def _schedule_prune(self, pair): + # self._log.debug(f"schedule prune {pair}") + + def schedule_prune_inner(p): + # self._log.debug(f"schedule inner {p}") + self._loop.call_later(self._ttl, self._prune, p) + + self._loop.call_soon_threadsafe(schedule_prune_inner, pair) + + def _prune(self, pair: any): + # self._log.debug(f"prune {pair}") + with self._lock: + self._associations.remove(pair) + + def get(self, key: any) -> any: + # self._log.debug(f"get {key}") + with self._lock: + pair = next(filter(lambda x: x[0] == key, self._associations), None) + if not pair: + return None + return pair[1] + + def put(self, key: any, value: any): + # self._log.debug(f"put {key},{value}") + with self._lock: + pair = (key, value) + self._associations.append(pair) + self._schedule_prune(pair) + + +_request_to_link: SelfPruningAssociation = None +_link_handle_request_orig = RNS.Link.handle_request + + +def _link_handle_request(self, request_id, unpacked_request): + global _request_to_link + _request_to_link.put(request_id, self.link_id) + _link_handle_request_orig(self, request_id, unpacked_request) + + +def request_request_id_hack(new_api_reponse_generator, loop): + global _request_to_link + if _request_to_link is None: + RNS.Link.handle_request = _link_handle_request + _request_to_link = SelfPruningAssociation(1.0, loop) + + def listen_request(path, data, request_id, remote_identity, requested_at): + link_id = _request_to_link.get(request_id) + return new_api_reponse_generator(path, data, request_id, link_id, remote_identity, requested_at) + + return listen_request diff --git a/rnsh/retry.py b/rnsh/retry.py index dba8141..cf91a81 100644 --- a/rnsh/retry.py +++ b/rnsh/retry.py @@ -27,6 +27,9 @@ import time import rnsh.exception as exception import logging as __logging from typing import Callable +from contextlib import AbstractContextManager +import types +import typing module_logger = __logging.getLogger(__name__) @@ -67,7 +70,7 @@ class RetryStatus: self.retry_callback(self.tag, self.tries) -class RetryThread: +class RetryThread(AbstractContextManager): def __init__(self, loop_period: float = 0.25, name: str = "retry thread"): self._log = module_logger.getChild(self.__class__.__name__) self._loop_period = loop_period @@ -176,3 +179,8 @@ class RetryThread: status.completed = True self._log.debug(f"completed {status.tag}") self._statuses.clear() + + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + self.close() + return False diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index f53262e..539ca23 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -35,12 +35,12 @@ import termios import threading import time from typing import Callable, TypeVar - import RNS import rnsh.exception as exception import rnsh.process as process import rnsh.retry as retry import rnsh.rnslogging as rnslogging +import rnsh.hacks as hacks from rnsh.__version import __version__ module_logger = __logging.getLogger(__name__) @@ -59,7 +59,7 @@ _allowed_identity_hashes = [] _cmd: [str] = None DATA_AVAIL_MSG = "data available" _finished: asyncio.Event | None = None -_retry_timer = retry.RetryThread() +_retry_timer: retry.RetryThread | None = None _destination: RNS.Destination | None = None _loop: asyncio.AbstractEventLoop | None = None @@ -107,8 +107,10 @@ def _print_identity(configdir, identitypath, service_name, include_destination: exit(0) +# hack_goes_here + async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0, - allowed=None, disable_auth=None, disable_announce=False): + allowed=None, disable_auth=None, announce_period=900): global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination log = _get_logger("_listen") _cmd = command @@ -147,14 +149,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default", if not _allow_all: _destination.register_request_handler( path="data", - response_generator=_listen_request, + response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()), allow=RNS.Destination.ALLOW_LIST, allowed_list=_allowed_identity_hashes ) else: _destination.register_request_handler( path="data", - response_generator=_listen_request, + response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()), allow=RNS.Destination.ALLOW_ALL, ) @@ -162,14 +164,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default", log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash)) - if not disable_announce: + if announce_period is not None: _destination.announce() last = time.time() try: while True: - if not disable_announce and time.time() - last > 900: # TODO: make parameter + if announce_period and 0 < announce_period < time.time() - last: last = time.time() _destination.announce() await _check_finished(1.0) @@ -177,7 +179,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default", log.warning("Shutting down") for link in list(_destination.links): with exception.permit(SystemExit): - proc = ProcessState.get_for_tag(link.link_id) + proc = Session.get_for_tag(link.link_id) if proc is not None and proc.process.running: proc.process.terminate() await asyncio.sleep(1) @@ -187,17 +189,35 @@ async def _listen(configdir, command, identitypath=None, service_name="default", link.teardown() -class ProcessState: - _processes: [(any, ProcessState)] = [] +_PROTOCOL_MAGIC = 0xdeadbeef + + +def _protocol_make_version(version: int): + return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version) + + +_PROTOCOL_VERSION_0 = _protocol_make_version(0) + + +def _protocol_split_version(version: int): + return (version >> 32) & 0xffffffff, version & 0xffffffff + + +def _protocol_check_magic(value: int): + return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC + + +class Session: + _processes: [(any, Session)] = [] _lock = threading.RLock() @classmethod - def get_for_tag(cls, tag: any) -> ProcessState | None: + def get_for_tag(cls, tag: any) -> Session | None: with cls._lock: return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None) @classmethod - def put_for_tag(cls, tag: any, ps: ProcessState): + def put_for_tag(cls, tag: any, ps: Session): with cls._lock: cls.clear_tag(tag) cls._processes.append((tag, ps)) @@ -234,7 +254,7 @@ class ProcessState: self._pending_receipt: RNS.PacketReceipt | None = None self._process.start() self._term_state: [int] = None - ProcessState.put_for_tag(tag, self) + Session.put_for_tag(tag, self) @property def mdu(self) -> int: @@ -302,37 +322,40 @@ class ProcessState: except Exception as e: self._log.debug(f"failed to update winsz: {e}") - REQUEST_IDX_STDIN = 0 - REQUEST_IDX_TERM = 1 - REQUEST_IDX_TIOS = 2 - REQUEST_IDX_ROWS = 3 - REQUEST_IDX_COLS = 4 - REQUEST_IDX_HPIX = 5 - REQUEST_IDX_VPIX = 6 + REQUEST_IDX_VERS = 0 + REQUEST_IDX_STDIN = 1 + REQUEST_IDX_TERM = 2 + REQUEST_IDX_TIOS = 3 + REQUEST_IDX_ROWS = 4 + REQUEST_IDX_COLS = 5 + REQUEST_IDX_HPIX = 6 + REQUEST_IDX_VPIX = 7 @staticmethod def default_request(stdin_fd: int | None) -> [any]: + global _PROTOCOL_VERSION_0 request: list[any] = [ - None, # 0 Stdin - None, # 1 TERM variable - None, # 2 termios attributes or something - None, # 3 terminal rows - None, # 4 terminal cols - None, # 5 terminal horizontal pixels - None, # 6 terminal vertical pixels + _PROTOCOL_VERSION_0, # 0 Protocol Version + None, # 1 Stdin + None, # 2 TERM variable + None, # 3 termios attributes or something + None, # 4 terminal rows + None, # 5 terminal cols + None, # 6 terminal horizontal pixels + None, # 7 terminal vertical pixels ].copy() if stdin_fd is not None: - request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None) - request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd) - request[ProcessState.REQUEST_IDX_ROWS], \ - request[ProcessState.REQUEST_IDX_COLS], \ - request[ProcessState.REQUEST_IDX_HPIX], \ - request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd) + request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None) + request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd) + request[Session.REQUEST_IDX_ROWS], \ + request[Session.REQUEST_IDX_COLS], \ + request[Session.REQUEST_IDX_HPIX], \ + request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd) return request def process_request(self, data: [any], read_size: int) -> [any]: - stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin + stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin # term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable # tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr # rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows @@ -340,10 +363,10 @@ class ProcessState: # hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels # vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels # term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1] - response = ProcessState.default_response() - term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1] + response = Session.default_response() + term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1] - response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running + response[Session.RESPONSE_IDX_RUNNING] = self.process.running if self.process.running: if term_state != self._term_state: self._term_state = term_state @@ -351,39 +374,42 @@ class ProcessState: if stdin is not None and len(stdin) > 0: stdin = base64.b64decode(stdin) self.process.write(stdin) - response[ProcessState.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code + response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code with self.lock: stdout = self.read(read_size) - response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer) + response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer) if stdout is not None and len(stdout) > 0: - response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8") + response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8") return response - RESPONSE_IDX_RUNNING = 0 - RESPONSE_IDX_RETCODE = 1 - RESPONSE_IDX_RDYBYTE = 2 - RESPONSE_IDX_STDOUT = 3 - RESPONSE_IDX_TMSTAMP = 4 + RESPONSE_IDX_VERSION = 0 + RESPONSE_IDX_RUNNING = 1 + RESPONSE_IDX_RETCODE = 2 + RESPONSE_IDX_RDYBYTE = 3 + RESPONSE_IDX_STDOUT = 4 + RESPONSE_IDX_TMSTAMP = 5 @staticmethod def default_response() -> [any]: + global _PROTOCOL_VERSION_0 response: list[any] = [ - False, # 0: Process running - None, # 1: Return value - 0, # 2: Number of outstanding bytes - None, # 3: Stdout/Stderr - None, # 4: Timestamp + _PROTOCOL_VERSION_0, # 0: Protocol version + False, # 1: Process running + None, # 2: Return value + 0, # 3: Number of outstanding bytes + None, # 4: Stdout/Stderr + None, # 5: Timestamp ].copy() - response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time() + response[Session.RESPONSE_IDX_TMSTAMP] = time.time() return response def _subproc_data_ready(link: RNS.Link, chars_available: int): global _retry_timer log = _get_logger("_subproc_data_ready") - process_state: ProcessState = ProcessState.get_for_tag(link.link_id) + session: Session = Session.get_for_tag(link.link_id) def send(timeout: bool, tag: any, tries: int) -> any: # log.debug("send") @@ -392,10 +418,10 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int): try: if link.status != RNS.Link.ACTIVE: _retry_timer.complete(link.link_id) - process_state.pending_receipt_take() + session.pending_receipt_take() return - pr = process_state.pending_receipt_take() + pr = session.pending_receipt_take() log.debug(f"send inner pr: {pr}") if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED: if not timeout: @@ -405,12 +431,12 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int): else: if not timeout: log.info( - f"Notifying client try {tries} (retcode: {process_state.return_code} " + + f"Notifying client try {tries} (retcode: {session.return_code} " + f"chars avail: {chars_available})") packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8")) packet.send() pr = packet.receipt - process_state.pending_receipt_put(pr) + session.pending_receipt_put(pr) else: log.error(f"Retry count exceeded, terminating link {link}") _retry_timer.complete(link.link_id) @@ -421,7 +447,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int): _loop.call_soon_threadsafe(inner) return link.link_id - with process_state.lock: + with session.lock: if not _retry_timer.has_tag(link.link_id): _retry_timer.begin(try_limit=15, wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1), @@ -436,7 +462,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int): global _loop log = _get_logger("_subproc_terminated") log.info(f"Subprocess returned {return_code} for link {link}") - proc = ProcessState.get_for_tag(link.link_id) + proc = Session.get_for_tag(link.link_id) if proc is None: log.debug(f"no proc for link {link}") return @@ -449,7 +475,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int): try: link.teardown() finally: - ProcessState.clear_tag(link.link_id) + Session.clear_tag(link.link_id) _loop.call_later(300, inner) _loop.call_soon(_subproc_data_ready, link, 0) @@ -458,18 +484,18 @@ def _subproc_terminated(link: RNS.Link, return_code: int): def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, - loop: asyncio.AbstractEventLoop) -> ProcessState | None: + loop: asyncio.AbstractEventLoop) -> Session | None: global _cmd log = _get_logger("_listen_start_proc") try: - return ProcessState(tag=link.link_id, - cmd=_cmd, - term=term, - remote_identity=remote_identity, - mdu=link.MDU, - loop=loop, - data_available_callback=functools.partial(_subproc_data_ready, link), - terminated_callback=functools.partial(_subproc_terminated, link)) + return Session(tag=link.link_id, + cmd=_cmd, + term=term, + remote_identity=remote_identity, + mdu=link.MDU, + loop=loop, + data_available_callback=functools.partial(_subproc_data_ready, link), + terminated_callback=functools.partial(_subproc_terminated, link)) except Exception as e: log.error("Failed to launch process: " + str(e)) _subproc_terminated(link, 255) @@ -488,7 +514,7 @@ def _listen_link_closed(link: RNS.Link): log = _get_logger("_listen_link_closed") # async def cleanup(): log.info("Link " + str(link) + " closed") - proc: ProcessState | None = ProcessState.get_for_tag(link.link_id) + proc: Session | None = Session.get_for_tag(link.link_id) if proc is None: log.warning(f"No process for link {link}") else: @@ -497,7 +523,7 @@ def _listen_link_closed(link: RNS.Link): _retry_timer.complete(link.link_id) except Exception as e: log.error(f"Error closing process for link {link}: {e}") - ProcessState.clear_tag(link.link_id) + Session.clear_tag(link.link_id) def _initiator_identified(link, identity): @@ -513,34 +539,48 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_ global _destination, _retry_timer, _loop log = _get_logger("_listen_request") log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}") + if not hasattr(data, "__len__") or len(data) < 1: + raise Exception("Request data invalid") _retry_timer.complete(link_id) link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None) if link is None: raise Exception(f"Invalid request {request_id}, no link found with id {link_id}") - process_state: ProcessState | None = None + + remote_version = data[Session.REQUEST_IDX_VERS] + if not _protocol_check_magic(remote_version): + raise Exception("Request magic incorrect") + + if not remote_version == _PROTOCOL_VERSION_0: + response = Session.default_response() + response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8")) + response[Session.RESPONSE_IDX_RETCODE] = 255 + response[Session.RESPONSE_IDX_RDYBYTE] = 0 + return response + + session: Session | None = None try: - term = data[ProcessState.REQUEST_IDX_TERM] - process_state = ProcessState.get_for_tag(link.link_id) - if process_state is None: + term = data[Session.REQUEST_IDX_TERM] + session = Session.get_for_tag(link.link_id) + if session is None: log.debug(f"Process not found for link {link}") - process_state = _listen_start_proc(link=link, + session = _listen_start_proc(link=link, term=term, remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""), loop=_loop) # leave significant headroom for metadata and encoding - result = process_state.process_request(data, link.MDU * 4 // 3) + result = session.process_request(data, link.MDU * 4 // 3) return result # return ProcessState.default_response() except Exception as e: log.error(f"Error procesing request for link {link}: {e}") try: - if process_state is not None and process_state.process.running: - process_state.process.terminate() + if session is not None and session.process.running: + session.process.terminate() except Exception as ee: log.debug(f"Error terminating process for link {link}: {ee}") - return ProcessState.default_response() + return Session.default_response() async def _spin(until: callable = None, timeout: float | None = None) -> bool: @@ -558,7 +598,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool: _link: RNS.Link | None = None _remote_exec_grace = 2.0 _new_data: asyncio.Event | None = None -_tr = process.TTYRestorer(sys.stdin.fileno()) +_tr: process.TTYRestorer | None = None def _client_packet_handler(message, packet): @@ -632,8 +672,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= _link.set_packet_callback(_client_packet_handler) - request = ProcessState.default_request(sys.stdin.fileno()) - request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None) + request = Session.default_request(sys.stdin.fileno()) + request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None) # TODO: Tune timeout = timeout + _link.rtt * 4 + _remote_exec_grace @@ -671,18 +711,27 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= if request_receipt.response is not None: try: - running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True - return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE] - ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0 - stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT] - timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP] + version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0 + if not _protocol_check_magic(version): + raise RemoteExecutionError("Protocol error") + elif version != _PROTOCOL_VERSION_0: + raise RemoteExecutionError("Protocol version mismatch") + + running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True + return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE] + ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0 + stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT] + if stdout is not None: + stdout = base64.b64decode(stdout) + timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP] # log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else "")) + except RemoteExecutionError: + raise except Exception as e: raise RemoteExecutionError(f"Received invalid response") from e _tr.raw() if stdout is not None: - stdout = base64.b64decode(stdout) # log.debug(f"stdout: {stdout}") os.write(sys.stdout.fileno(), stdout) @@ -784,7 +833,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]): return arr, [] -async def main(): +async def _rnsh_cli_main(): global _tr, _finished, _loop import docopt log = _get_logger("main") @@ -797,8 +846,8 @@ async def main(): Usage: rnsh [--config ] [-i ] [-s ] [-l] -p rnsh -l [--config ] [-i ] [-s ] - [-v...] [-q...] [-b] (-n | -a [-a ]...) - [--] [...] + [-v...] [-q...] [-b ] (-n | -a [-a ] ...) + [--] [ ...] rnsh [--config ] [-i ] [-s ] [-v...] [-q...] [-N] [-m] [-w ] rnsh -h @@ -810,7 +859,8 @@ Options: -s NAME --service NAME Listen on/connect to specific service name if not default -p --print-identity Print identity information and exit -l --listen Listen (server) mode - -b --no-announce Do not announce service + -b --announce PERIOD Announce on startup and every PERIOD seconds + Specify 0 for PERIOD to announce on startup only. -a HASH --allowed HASH Specify identities allowed to connect -n --no-auth Disable authentication -N --no-id Disable identify on connect @@ -837,7 +887,13 @@ Options: args_print_identity = args.get("--print-identity", None) or False args_verbose = args.get("--verbose", None) or 0 args_quiet = args.get("--quiet", None) or 0 - args_no_announce = args.get("--no-announce", None) or False + args_announce = args.get("--announce", None) + try: + if args_announce: + args_announce = int(args_announce) + except ValueError: + print("Invalid value for --announce") + return 1 args_no_auth = args.get("--no-auth", None) or False args_allowed = args.get("--allowed", None) or [] args_program = args.get("", None) @@ -868,7 +924,7 @@ Options: quietness=args_quiet, allowed=args_allowed, disable_auth=args_no_auth, - disable_announce=args_no_announce, + announce_period=args_announce, ) if args_destination is not None and args_service_name is not None: @@ -893,14 +949,14 @@ Options: def rnsh_cli(): - try: - return_code = asyncio.run(main()) - finally: - with exception.permit(SystemExit): - process.tty_unset_reader_callbacks(sys.stdin.fileno()) + global _tr, _retry_timer + with process.TTYRestorer(sys.stdin.fileno()) as _tr: + with retry.RetryThread() as _retry_timer: + return_code = asyncio.run(_rnsh_cli_main()) + + with exception.permit(SystemExit): + process.tty_unset_reader_callbacks(sys.stdin.fileno()) - _tr.restore() - _retry_timer.close() sys.exit(return_code or 255) diff --git a/tests/test_hacks.py b/tests/test_hacks.py new file mode 100644 index 0000000..29d04d5 --- /dev/null +++ b/tests/test_hacks.py @@ -0,0 +1,44 @@ +import asyncio + +import rnsh.hacks as hacks +import pytest +import time +import logging +logging.getLogger().setLevel(logging.DEBUG) + + +class FakeLink: + def __init__(self, link_id): + self.link_id = link_id + + +@pytest.mark.asyncio +async def test_pruning(): + def listen_request(path, data, request_id, link_id, remote_identity, requested_at): + assert path == 1 + assert data == 2 + assert request_id == 3 + assert remote_identity == 4 + assert requested_at == 5 + assert link_id == 6 + return 7 + + lhr_called = 0 + link = FakeLink(6) + + def link_handle_request(self, request_id, unpacked_request): + nonlocal lhr_called + lhr_called += 1 + + old_func = hacks.request_request_id_hack(listen_request, asyncio.get_running_loop()) + hacks._link_handle_request_orig = link_handle_request + hacks._link_handle_request(link, 3, None) + link_id = hacks._request_to_link.get(3) + assert link_id == link.link_id + result = old_func(1, 2, 3, 4, 5) + assert result == 7 + link_id = hacks._request_to_link.get(3) + assert link_id == 6 + await asyncio.sleep(1.5) + link_id = hacks._request_to_link.get(3) + assert link_id is None diff --git a/tests/test_rnsh.py b/tests/test_rnsh.py new file mode 100644 index 0000000..9b528d5 --- /dev/null +++ b/tests/test_rnsh.py @@ -0,0 +1,15 @@ +import logging +import rnsh.rnsh +logging.getLogger().setLevel(logging.DEBUG) + + +def test_check_magic(): + magic = rnsh.rnsh._PROTOCOL_VERSION_0 + # magic for version 0 is generated, make sure it comes out as expected + assert magic == 0xdeadbeef00000000 + # verify the checker thinks it's right + assert rnsh.rnsh._protocol_check_magic(magic) + # scramble the magic + magic = magic | 0x00ffff0000000000 + # make sure it fails now + assert not rnsh.rnsh._protocol_check_magic(magic)