From 5bca575a4b3480eec433723ad3736cea4b44cbb0 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:48:29 -0600 Subject: [PATCH] Switch to RNS-provided Channel --- pyproject.toml | 2 +- rnsh/initiator.py | 274 +++++++++++++++++++++++----------------------- rnsh/listener.py | 3 +- rnsh/protocol.py | 264 +++++--------------------------------------- rnsh/rnsh.py | 88 ++++++++------- rnsh/session.py | 80 +++----------- 6 files changed, 229 insertions(+), 482 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfd9b0d..e3bdb85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.9" docopt = "^0.6.2" -rns = "^0.4.9" +rns = { git = "https://github.com/acehoss/Reticulum.git", branch = "feature/channel" } #{ path = "../Reticulum/", develop = true } # tomli = "^2.0.1" [tool.poetry.scripts] diff --git a/rnsh/initiator.py b/rnsh/initiator.py index 62d8477..8a82290 100644 --- a/rnsh/initiator.py +++ b/rnsh/initiator.py @@ -143,12 +143,12 @@ class InitiatorState(enum.IntEnum): def _client_link_closed(link): log = _get_logger("_client_link_closed") - _finished.set() + if _finished: + _finished.set() -def _client_packet_handler(message, packet): - log = _get_logger("_client_packet_handler") - packet.prove() +def _client_message_handler(message: RNS.MessageBase): + log = _get_logger("_client_message_handler") _pq.put(message) @@ -213,10 +213,8 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, _link.identify(_identity) _link.did_identify = True - _link.set_packet_callback(_client_packet_handler) - -async def _handle_error(errmsg: protocol.Message): +async def _handle_error(errmsg: RNS.MessageBase): if isinstance(errmsg, protocol.ErrorMessage): with contextlib.suppress(Exception): if _link and _link.status == RNS.Link.ACTIVE: @@ -249,150 +247,148 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: state = InitiatorState.IS_LINKED outlet = session.RNSOutlet(_link) - with protocol.Messenger(retry_delay_min=5) as messenger: + channel = _link.get_channel() + protocol.register_message_types(channel) + channel.add_message_handler(_client_message_handler) - # Next step after linking and identifying: send version - # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5): - # print("Error bringing up link") - # return 253 + # Next step after linking and identifying: send version + # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5): + # print("Error bringing up link") + # return 253 - messenger.send(outlet, protocol.VersionInfoMessage()) + channel.send(protocol.VersionInfoMessage()) + try: + vm = _pq.get(timeout=max(outlet.rtt * 20, 5)) + await _handle_error(vm) + if not isinstance(vm, protocol.VersionInfoMessage): + raise Exception("Invalid message received") + log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}") + state = InitiatorState.IS_RUNNING + except queue.Empty: + print("Protocol error") + return 254 + + winch = False + def sigwinch_handler(): + nonlocal winch + # log.debug("WindowChanged") + winch = True + + stdin_eof = False + def stdin(): + nonlocal stdin_eof try: - vp = _pq.get(timeout=max(outlet.rtt * 20, 5)) - vm = messenger.receive(vp) - await _handle_error(vm) - if not isinstance(vm, protocol.VersionInfoMessage): - raise Exception("Invalid message received") - log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}") - state = InitiatorState.IS_RUNNING - except queue.Empty: - print("Protocol error") - return 254 + data = process.tty_read(sys.stdin.fileno()) + log.debug(f"stdin {data}") + if data is not None: + data_buffer.extend(data) + except EOFError: + if os.isatty(0): + data_buffer.extend(process.CTRL_D) + stdin_eof = True + process.tty_unset_reader_callbacks(sys.stdin.fileno()) - winch = False - def sigwinch_handler(): - nonlocal winch - # log.debug("WindowChanged") - winch = True + process.tty_add_reader_callback(sys.stdin.fileno(), stdin) - stdin_eof = False - def stdin(): - nonlocal stdin_eof - try: - data = process.tty_read(sys.stdin.fileno()) - log.debug(f"stdin {data}") - if data is not None: - data_buffer.extend(data) - except EOFError: - if os.isatty(0): - data_buffer.extend(process.CTRL_D) - stdin_eof = True - process.tty_unset_reader_callbacks(sys.stdin.fileno()) - - process.tty_add_reader_callback(sys.stdin.fileno(), stdin) - - tcattr = None - rows, cols, hpix, vpix = (None, None, None, None) + tcattr = None + rows, cols, hpix, vpix = (None, None, None, None) + try: + tcattr = termios.tcgetattr(0) + rows, cols, hpix, vpix = process.tty_get_winsize(0) + except: try: - tcattr = termios.tcgetattr(0) - rows, cols, hpix, vpix = process.tty_get_winsize(0) + tcattr = termios.tcgetattr(1) + rows, cols, hpix, vpix = process.tty_get_winsize(1) except: try: - tcattr = termios.tcgetattr(1) - rows, cols, hpix, vpix = process.tty_get_winsize(1) + tcattr = termios.tcgetattr(2) + rows, cols, hpix, vpix = process.tty_get_winsize(2) except: - try: - tcattr = termios.tcgetattr(2) - rows, cols, hpix, vpix = process.tty_get_winsize(2) - except: - pass + pass - messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command, - pipe_stdin=not os.isatty(0), - pipe_stdout=not os.isatty(1), - pipe_stderr=not os.isatty(2), - tcflags=tcattr, - term=os.environ.get("TERM", None), - rows=rows, - cols=cols, - hpix=hpix, - vpix=vpix)) + await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1) + channel.send(protocol.ExecuteCommandMesssage(cmdline=command, + pipe_stdin=not os.isatty(0), + pipe_stdout=not os.isatty(1), + pipe_stderr=not os.isatty(2), + tcflags=tcattr, + term=os.environ.get("TERM", None), + rows=rows, + cols=cols, + hpix=hpix, + vpix=vpix)) - loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler) - _finished = asyncio.Event() - loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop)) - loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop)) - mdu = _link.MDU - 16 - sent_eof = False - last_winch = time.time() - sleeper = helpers.SleepRate(0.01) - processed = False - while not await _check_finished() and state in [InitiatorState.IS_RUNNING]: + loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler) + _finished = asyncio.Event() + loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop)) + loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop)) + mdu = _link.MDU - 16 + sent_eof = False + last_winch = time.time() + sleeper = helpers.SleepRate(0.01) + processed = False + while not await _check_finished() and state in [InitiatorState.IS_RUNNING]: + try: try: - try: - packet = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005) - message = messenger.receive(packet) - await _handle_error(message) + message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005) + await _handle_error(message) + processed = True + if isinstance(message, protocol.StreamDataMessage): + if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT: + if message.data and len(message.data) > 0: + ttyRestorer.raw() + log.debug(f"stdout: {message.data}") + os.write(1, message.data) + sys.stdout.flush() + if message.eof: + os.close(1) + if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR: + if message.data and len(message.data) > 0: + ttyRestorer.raw() + log.debug(f"stdout: {message.data}") + os.write(2, message.data) + sys.stderr.flush() + if message.eof: + os.close(2) + elif isinstance(message, protocol.CommandExitedMessage): + log.debug(f"received return code {message.return_code}, exiting") + return message.return_code + elif isinstance(message, protocol.ErrorMessage): + log.error(message.data) + if message.fatal: + _link.teardown() + return 200 + + except queue.Empty: + processed = False + + if channel.is_ready_to_send(): + stdin = data_buffer[:mdu] + data_buffer = data_buffer[mdu:] + eof = not sent_eof and stdin_eof and len(stdin) == 0 + if len(stdin) > 0 or eof: + channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof)) + sent_eof = eof processed = True - if isinstance(message, protocol.StreamDataMessage): - if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT: - if message.data and len(message.data) > 0: - ttyRestorer.raw() - log.debug(f"stdout: {message.data}") - os.write(1, message.data) - sys.stdout.flush() - if message.eof: - os.close(1) - if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR: - if message.data and len(message.data) > 0: - ttyRestorer.raw() - log.debug(f"stdout: {message.data}") - os.write(2, message.data) - sys.stderr.flush() - if message.eof: - os.close(2) - elif isinstance(message, protocol.CommandExitedMessage): - log.debug(f"received return code {message.return_code}, exiting") - with exception.permit(SystemExit, KeyboardInterrupt): - _link.teardown() - return message.return_code - elif isinstance(message, protocol.ErrorMessage): - log.error(message.data) - if message.fatal: - _link.teardown() - return 200 - except queue.Empty: - processed = False + # send window change, but rate limited + if winch and time.time() - last_winch > _link.rtt * 25: + last_winch = time.time() + winch = False + with contextlib.suppress(Exception): + r, c, h, v = process.tty_get_winsize(0) + channel.send(protocol.WindowSizeMessage(r, c, h, v)) + processed = True + except RemoteExecutionError as e: + print(e.msg) + return 255 + except Exception as ex: + print(f"Client exception: {ex}") + if _link and _link.status != RNS.Link.CLOSED: + _link.teardown() + return 127 - if messenger.is_outlet_ready(outlet): - stdin = data_buffer[:mdu] - data_buffer = data_buffer[mdu:] - eof = not sent_eof and stdin_eof and len(stdin) == 0 - if len(stdin) > 0 or eof: - messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, - stdin, eof)) - sent_eof = eof - processed = True - - # send window change, but rate limited - if winch and time.time() - last_winch > _link.rtt * 25: - last_winch = time.time() - winch = False - with contextlib.suppress(Exception): - r, c, h, v = process.tty_get_winsize(0) - messenger.send(outlet, protocol.WindowSizeMessage(r, c, h, v)) - processed = True - except RemoteExecutionError as e: - print(e.msg) - return 255 - except Exception as ex: - print(f"Client exception: {ex}") - if _link and _link.status != RNS.Link.CLOSED: - _link.teardown() - return 127 - - # await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120)) - # await sleeper.sleep_async() - log.debug("after main loop") - return 0 \ No newline at end of file + # await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120)) + # await sleeper.sleep_async() + log.debug("after main loop") + return 0 \ No newline at end of file diff --git a/rnsh/listener.py b/rnsh/listener.py index 796aaa6..07b6a0e 100644 --- a/rnsh/listener.py +++ b/rnsh/listener.py @@ -159,7 +159,7 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!") def link_established(lnk: RNS.Link): - session.ListenerSession(session.RNSOutlet.get_outlet(lnk), loop) + session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop) _destination.set_link_established_callback(link_established) _finished = asyncio.Event() @@ -188,7 +188,6 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo log.warning("Shutting down") await session.ListenerSession.terminate_all("Shutting down") await asyncio.sleep(1) - session.ListenerSession.messenger.shutdown() links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links)) for link in links_still_active: if link.status not in [RNS.Link.CLOSED]: diff --git a/rnsh/protocol.py b/rnsh/protocol.py index 8dc756b..a833ad8 100644 --- a/rnsh/protocol.py +++ b/rnsh/protocol.py @@ -19,9 +19,6 @@ from abc import ABC, abstractmethod module_logger = __logging.getLogger(__name__) - -_TReceipt = TypeVar("_TReceipt") -_TLink = TypeVar("_TLink") MSG_MAGIC = 0xac PROTOCOL_VERSION = 1 @@ -30,120 +27,17 @@ def _make_MSGTYPE(val: int): return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff) -class MessageOutletBase(ABC): - @abstractmethod - def send(self, raw: bytes) -> _TReceipt: - raise NotImplemented() - - @abstractmethod - def resend(self, receipt: _TReceipt) -> _TReceipt: - raise NotImplemented() - - @property - @abstractmethod - def mdu(self): - raise NotImplemented() - - @property - @abstractmethod - def rtt(self): - raise NotImplemented() - - @property - @abstractmethod - def is_usuable(self): - raise NotImplemented() - - @abstractmethod - def get_receipt_state(self, receipt: _TReceipt) -> MessageState: - raise NotImplemented() - - @abstractmethod - def timed_out(self): - raise NotImplemented() - - @abstractmethod - def __str__(self): - raise NotImplemented() - - @abstractmethod - def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]): - raise NotImplemented() - - -class METype(enum.IntEnum): - ME_NO_MSG_TYPE = 0 - ME_INVALID_MSG_TYPE = 1 - ME_NOT_REGISTERED = 2 - ME_LINK_NOT_READY = 3 - ME_ALREADY_SENT = 4 - - -class MessagingException(Exception): - def __init__(self, type: METype, *args): - super().__init__(args) - self.type = type - - -class MessageState(enum.IntEnum): - MSGSTATE_NEW = 0 - MSGSTATE_SENT = 1 - MSGSTATE_DELIVERED = 2 - MSGSTATE_FAILED = 3 - - -class Message(abc.ABC): - MSGTYPE = None - - def __init__(self): - self.ts = time.time() - self.msgid = uuid.uuid4() - self.raw: bytes | None = None - self.receipt: _TReceipt = None - self.outlet: _TLink = None - self.tracked: bool = False - - def __str__(self): - return f"{self.__class__.__name__} {self.msgid}" - - @abstractmethod - def pack(self) -> bytes: - raise NotImplemented() - - @abstractmethod - def unpack(self, raw): - raise NotImplemented() - - def unwrap_MSGTYPE(self, raw: bytes) -> bytes: - if self.MSGTYPE is None: - raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE") - mid, raw = self.static_unwrap_MSGTYPE(raw) - if mid != self.MSGTYPE: - raise MessagingException(METype.ME_INVALID_MSG_TYPE, - f"invalid msg id, expected {hex(self.MSGTYPE)} got {hex(mid)}") - return raw - - def wrap_MSGTYPE(self, raw: bytes) -> bytes: - if self.__class__.MSGTYPE is None: - raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE") - return struct.pack(">H", self.MSGTYPE) + raw - - @staticmethod - def static_unwrap_MSGTYPE(raw: bytes) -> (int, bytes): - return struct.unpack(">H", raw[:2])[0], raw[2:] - - -class NoopMessage(Message): +class NoopMessage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(0) def pack(self) -> bytes: - return self.wrap_MSGTYPE(bytes()) + return bytes() def unpack(self, raw): - self.unwrap_MSGTYPE(raw) + pass -class WindowSizeMessage(Message): +class WindowSizeMessage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(2) def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None): @@ -154,15 +48,13 @@ class WindowSizeMessage(Message): self.vpix = vpix def pack(self) -> bytes: - raw = umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix)) - return self.wrap_MSGTYPE(raw) + return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix)) def unpack(self, raw): - raw = self.unwrap_MSGTYPE(raw) self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) -class ExecuteCommandMesssage(Message): +class ExecuteCommandMesssage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(3) def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False, @@ -181,20 +73,20 @@ class ExecuteCommandMesssage(Message): self.vpix = vpix def pack(self) -> bytes: - raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, - self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix)) - return self.wrap_MSGTYPE(raw) + return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, + self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix)) def unpack(self, raw): - raw = self.unwrap_MSGTYPE(raw) self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \ self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) -class StreamDataMessage(Message): + +class StreamDataMessage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(4) STREAM_ID_STDIN = 0 STREAM_ID_STDOUT = 1 STREAM_ID_STDERR = 2 + OVERHEAD = 0 def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False): super().__init__() @@ -203,15 +95,19 @@ class StreamDataMessage(Message): self.eof = eof def pack(self) -> bytes: - raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data))) - return self.wrap_MSGTYPE(raw) + return umsgpack.packb((self.stream_id, self.eof, bytes(self.data))) def unpack(self, raw): - raw = self.unwrap_MSGTYPE(raw) self.stream_id, self.eof, self.data = umsgpack.unpackb(raw) -class VersionInfoMessage(Message): +_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8") +StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \ + - len(_link_sized_bytes) +_link_sized_bytes = None + + +class VersionInfoMessage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(5) def __init__(self, sw_version: str = None): @@ -220,15 +116,13 @@ class VersionInfoMessage(Message): self.protocol_version = PROTOCOL_VERSION def pack(self) -> bytes: - raw = umsgpack.packb((self.sw_version, self.protocol_version)) - return self.wrap_MSGTYPE(raw) + return umsgpack.packb((self.sw_version, self.protocol_version)) def unpack(self, raw): - raw = self.unwrap_MSGTYPE(raw) self.sw_version, self.protocol_version = umsgpack.unpackb(raw) -class ErrorMessage(Message): +class ErrorMessage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(6) def __init__(self, msg: str = None, fatal: bool = False, data: dict = None): @@ -238,15 +132,13 @@ class ErrorMessage(Message): self.data = data def pack(self) -> bytes: - raw = umsgpack.packb((self.msg, self.fatal, self.data)) - return self.wrap_MSGTYPE(raw) + return umsgpack.packb((self.msg, self.fatal, self.data)) def unpack(self, raw: bytes): - raw = self.unwrap_MSGTYPE(raw) self.msg, self.fatal, self.data = umsgpack.unpackb(raw) -class CommandExitedMessage(Message): +class CommandExitedMessage(RNS.MessageBase): MSGTYPE = _make_MSGTYPE(7) def __init__(self, return_code: int = None): @@ -254,114 +146,16 @@ class CommandExitedMessage(Message): self.return_code = return_code def pack(self) -> bytes: - raw = umsgpack.packb(self.return_code) - return self.wrap_MSGTYPE(raw) + return umsgpack.packb(self.return_code) def unpack(self, raw: bytes): - raw = self.unwrap_MSGTYPE(raw) self.return_code = umsgpack.unpackb(raw) -class Messenger(contextlib.AbstractContextManager): +message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage, + CommandExitedMessage, ErrorMessage] - @staticmethod - def _get_msg_constructors() -> (int, Type[Message]): - subclass_tuples = [] - for subclass in Message.__subclasses__(): - subclass_tuples.append((subclass.MSGTYPE, subclass)) - return subclass_tuples - - def __init__(self, retry_delay_min: float = 10.0): - self._log = module_logger.getChild(self.__class__.__name__) - self._sent_messages: list[Message] = [] - self._lock = threading.RLock() - self._retry_timer = rnsh.retry.RetryThread() - self._message_factories = dict(self.__class__._get_msg_constructors()) - self._retry_delay_min = retry_delay_min - - def __enter__(self) -> Messenger: - return self - - def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, - __traceback: TracebackType | None) -> bool | None: - self.shutdown() - return False - - def shutdown(self): - self._retry_timer.close() - - def clear_retries(self, outlet): - self._retry_timer.complete(outlet) - - def receive(self, raw: bytes) -> Message: - (mid, contents) = Message.static_unwrap_MSGTYPE(raw) - ctor = self._message_factories.get(mid, None) - if ctor is None: - raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}") - message = ctor() - message.unpack(raw) - self._log.debug(f"Message received: {message}") - return message - - def is_outlet_ready(self, outlet: MessageOutletBase) -> bool: - if not outlet.is_usuable: - self._log.debug("is_outlet_ready outlet unusable") - return False - - with self._lock: - for message in self._sent_messages: - if message.outlet == outlet and message.tracked and message.receipt \ - and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT: - self._log.debug("is_outlet_ready pending message found") - return False - return True - - def send(self, outlet: MessageOutletBase, message: Message): - with self._lock: - if not self.is_outlet_ready(outlet): - raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready") - - if message in self._sent_messages: - raise MessagingException(METype.ME_ALREADY_SENT) - self._sent_messages.append(message) - message.tracked = True - - if not message.raw: - message.raw = message.pack() - message.outlet = outlet - - def send_inner(tag: any, tries: int): - state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt) - if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]: - try: - if message.receipt: - self._log.debug(f"Resending packet for {message}") - message.receipt = outlet.resend(message.receipt) - else: - self._log.debug(f"Sending packet for {message}") - message.receipt = outlet.send(message.raw) - except Exception as ex: - self._log.exception(f"Error sending message {message}") - elif state in [MessageState.MSGSTATE_SENT]: - self._log.debug(f"Retry skipped, message still pending {message}") - elif state in [MessageState.MSGSTATE_DELIVERED]: - latency = round(time.time() - message.ts, 1) - self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds") - with self._lock: - self._sent_messages.remove(message) - message.tracked = False - self._retry_timer.complete(outlet) - return outlet - - def timeout(tag: any, tries: int): - latency = round(time.time() - message.ts, 1) - msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout" - self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds") - with self._lock: - self._sent_messages.remove(message) - message.tracked = False - outlet.timed_out() - - rtt = outlet.rtt - self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout) +def register_message_types(channel: RNS.Channel.Channel): + for message_type in message_types: + channel.register_message_type(message_type) \ No newline at end of file diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index 5fd6a71..5f11040 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -102,53 +102,58 @@ def print_identity(configdir, identitypath, service_name, include_destination: b exit(0) +verbose_set = False + + async def _rnsh_cli_main(): - #with contextlib.suppress(KeyboardInterrupt, SystemExit): - import docopt - log = _get_logger("main") - _loop = asyncio.get_running_loop() - rnslogging.set_main_loop(_loop) - args = rnsh.args.Args(sys.argv) + global verbose_set + log = _get_logger("main") + _loop = asyncio.get_running_loop() + rnslogging.set_main_loop(_loop) + args = rnsh.args.Args(sys.argv) + verbose_set = args.verbose > 0 - if args.print_identity: - print_identity(args.config, args.identity, args.service_name, args.listen) - return 0 + if args.print_identity: + print_identity(args.config, args.identity, args.service_name, args.listen) + return 0 - if args.listen: - # log.info("command " + args.command) - await listener.listen(configdir=args.config, - command=args.command_line, - identitypath=args.identity, - service_name=args.service_name, - verbosity=args.verbose, - quietness=args.quiet, - allowed=args.allowed, - disable_auth=args.no_auth, - announce_period=args.announce, - no_remote_command=args.no_remote_cmd, - remote_cmd_as_args=args.remote_cmd_as_args) - return 0 + if args.listen: + # log.info("command " + args.command) + await listener.listen(configdir=args.config, + command=args.command_line, + identitypath=args.identity, + service_name=args.service_name, + verbosity=args.verbose, + quietness=args.quiet, + allowed=args.allowed, + disable_auth=args.no_auth, + announce_period=args.announce, + no_remote_command=args.no_remote_cmd, + remote_cmd_as_args=args.remote_cmd_as_args) + return 0 - if args.destination is not None: - return_code = await initiator.initiate(configdir=args.config, - identitypath=args.identity, - verbosity=args.verbose, - quietness=args.quiet, - noid=args.no_id, - destination=args.destination, - timeout=args.timeout, - command=args.command_line - ) - return return_code if args.mirror else 0 - else: - print("") - print(rnsh.args.usage) - print("") - return 1 + if args.destination is not None: + return_code = await initiator.initiate(configdir=args.config, + identitypath=args.identity, + verbosity=args.verbose, + quietness=args.quiet, + noid=args.no_id, + destination=args.destination, + timeout=args.timeout, + command=args.command_line + ) + return return_code if args.mirror else 0 + else: + print("") + print(rnsh.args.usage) + print("") + return 1 def rnsh_cli(): + global verbose_set return_code = 1 + exc = None try: return_code = asyncio.run(_rnsh_cli_main()) except SystemExit: @@ -156,8 +161,11 @@ def rnsh_cli(): except KeyboardInterrupt: pass except Exception as ex: - print(f"Unhandled exception: {ex}") + print(f"Unhandled exception: {ex}") + exc = ex process.tty_unset_reader_callbacks(0) + if verbose_set and exc: + raise exc sys.exit(return_code if return_code is not None else 255) diff --git a/rnsh/session.py b/rnsh/session.py index e551b50..b79e155 100644 --- a/rnsh/session.py +++ b/rnsh/session.py @@ -16,8 +16,6 @@ import RNS import logging as __logging -from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState - module_logger = __logging.getLogger(__name__) _TLink = TypeVar("_TLink") @@ -44,7 +42,7 @@ class LSState(enum.IntEnum): _TIdentity = TypeVar("_TIdentity") -class LSOutletBase(protocol.MessageOutletBase): +class LSOutletBase(ABC): @abstractmethod def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): raise NotImplemented() @@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase): def unset_link_closed_callback(self): raise NotImplemented() + @property @abstractmethod - def teardown(self): + def rtt(self): raise NotImplemented() @abstractmethod - def __init__(self): + def teardown(self): raise NotImplemented() class ListenerSession: sessions: List[ListenerSession] = [] - messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5) allowed_identity_hashes: [any] = [] allow_all: bool = False allow_remote_command: bool = False default_command: [str] = [] remote_cmd_as_args = False - def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop): + def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop): self._log = module_logger.getChild(self.__class__.__name__) self._log.info(f"Session started for {outlet}") self.outlet = outlet + self.channel = channel self.outlet.set_initiator_identified_callback(self._initiator_identified) self.outlet.set_link_closed_callback(self._link_closed) self.loop = loop @@ -106,7 +105,8 @@ class ListenerSession: else: self._set_state(LSState.LSSTATE_WAIT_IDENT) self.sessions.append(self) - self.outlet.set_packet_received_callback(self._packet_received) + protocol.register_message_types(self.channel) + self.channel.add_message_handler(self._handle_message) def _terminated(self, return_code: int): self.return_code = return_code @@ -128,8 +128,8 @@ class ListenerSession: self.loop.call_later(delay, func) self.loop.call_soon_threadsafe(call_inner) - def send(self, message: protocol.Message): - self.messenger.send(self.outlet, message) + def send(self, message: RNS.MessageBase): + self.channel.send(message) def _protocol_error(self, name: str): self.terminate(f"Protocol error ({name})") @@ -171,7 +171,6 @@ class ListenerSession: return self._log.debug(f"link_closed {outlet}") - self.messenger.clear_retries(self.outlet) self.terminate() def _initiator_identified(self, outlet, identity): @@ -208,10 +207,10 @@ class ListenerSession: try: if self.state != LSState.LSSTATE_RUNNING: return False - elif not self.messenger.is_outlet_ready(self.outlet): + elif not self.channel.is_ready_to_send(): return False elif len(self.stderr_buf) > 0: - mdu = self.outlet.mdu - 16 + mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD data = self.stderr_buf[:mdu] self.stderr_buf = self.stderr_buf[mdu:] send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent @@ -223,7 +222,7 @@ class ListenerSession: self.stderr_eof_sent = True return True elif len(self.stdout_buf) > 0: - mdu = self.outlet.mdu - 16 + mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD data = self.stdout_buf[:mdu] self.stdout_buf = self.stdout_buf[mdu:] send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent @@ -309,7 +308,7 @@ class ListenerSession: if eof: self.process.close_stdin() - def _handle_message(self, message: protocol.Message): + def _handle_message(self, message: RNS.MessageBase): if self.state == LSState.LSSTATE_WAIT_IDENT: self._protocol_error("Identification required") return @@ -352,17 +351,6 @@ class ListenerSession: self._protocol_error("unexpected message") return - def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes): - if outlet != self.outlet: - self._log.debug("Packet received from incorrect outlet") - return - - try: - message = self.messenger.receive(raw) - self._handle_message(message) - except Exception as ex: - self._protocol_error(f"error receiving packet: {ex}") - class RNSOutlet(LSOutletBase): @@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase): def teardown(self): self.link.teardown() - def send(self, raw: bytes) -> RNS.Packet: - packet = RNS.Packet(self.link, raw) - packet.send() - return packet - - def resend(self, packet: RNS.Packet) -> RNS.Packet: - packet.resend() - return packet - - @property - def mdu(self) -> int: - return self.link.MDU - @property def rtt(self) -> float: return self.link.rtt - @property - def is_usuable(self): - return True #self.link.status in [RNS.Link.ACTIVE] - - def get_receipt_state(self, packet: RNS.Packet) -> MessageState: - status = packet.receipt.get_status() - if status == RNS.PacketReceipt.SENT: - return protocol.MessageState.MSGSTATE_SENT - if status == RNS.PacketReceipt.DELIVERED: - return protocol.MessageState.MSGSTATE_DELIVERED - if status == RNS.PacketReceipt.FAILED: - return protocol.MessageState.MSGSTATE_FAILED - else: - raise Exception(f"Unexpected receipt state: {status}") - - def timed_out(self): - self.link.teardown() - def __str__(self): return f"Outlet RNS Link {self.link}" - def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]): - def inner_cb(message, packet: RNS.Packet): - packet.prove() - cb(self, message) - - self.link.set_packet_callback(inner_cb) - def __init__(self, link: RNS.Link): self.link = link link.lsoutlet = self - link.msgoutlet = self + @staticmethod def get_outlet(link: RNS.Link): if hasattr(link, "lsoutlet"):