from __future__ import annotations import logging from rnsh.protocol import _TReceipt, MessageState from typing import Callable logging.getLogger().setLevel(logging.DEBUG) import rnsh.protocol import contextlib import typing import types import time import uuid module_logger = logging.getLogger(__name__) class Receipt: def __init__(self, state: rnsh.protocol.MessageState, raw: bytes): self.state = state self.raw = raw class MessageOutletTest(rnsh.protocol.MessageOutletBase): def __init__(self, mdu: int, rtt: float): self.link_id = uuid.uuid4() self.timeout_callbacks = 0 self._mdu = mdu self._rtt = rtt self._usable = True self.receipts = [] self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None def send(self, raw: bytes) -> Receipt: receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw) self.receipts.append(receipt) return receipt def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]): self.packet_callback = cb def receive(self, raw: bytes): if self.packet_callback: self.packet_callback(self, raw) @property def mdu(self): return self._mdu @property def rtt(self): return self._rtt @property def is_usuable(self): return self._usable def get_receipt_state(self, receipt: Receipt) -> MessageState: return receipt.state def timed_out(self): self.timeout_callbacks += 1 def __str__(self): return str(self.link_id) class ProtocolHarness(contextlib.AbstractContextManager): def __init__(self, retry_delay_min: float = 1): self._log = module_logger.getChild(self.__class__.__name__) self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min) def cleanup(self): self.messenger.shutdown() def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, __traceback: types.TracebackType) -> bool: # self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") self.cleanup() return False def test_send_one_retry(): rtt = 0.001 retry_interval = rtt * 150 message_content = b'Test' with ProtocolHarness(retry_delay_min=retry_interval) as h: outlet = MessageOutletTest(mdu=500, rtt=rtt) message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, data=message_content, eof=True) assert len(outlet.receipts) == 0 h.messenger.send(outlet, message) assert message.tracked assert message.raw is not None assert len(outlet.receipts) == 1 receipt = outlet.receipts[0] assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.raw == message.raw time.sleep(retry_interval * 1.5) assert len(outlet.receipts) == 1 receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED module_logger.info("set failed") time.sleep(retry_interval) assert len(outlet.receipts) == 2 receipt = outlet.receipts[1] assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED time.sleep(retry_interval) assert len(outlet.receipts) == 2 assert not message.tracked def test_send_timeout(): rtt = 0.001 retry_interval = rtt * 150 message_content = b'Test' with ProtocolHarness(retry_delay_min=retry_interval) as h: outlet = MessageOutletTest(mdu=500, rtt=rtt) message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, data=message_content, eof=True) assert len(outlet.receipts) == 0 h.messenger.send(outlet, message) assert message.tracked assert message.raw is not None assert len(outlet.receipts) == 1 receipt = outlet.receipts[0] assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.raw == message.raw time.sleep(retry_interval * 1.5) assert outlet.timeout_callbacks == 0 time.sleep(retry_interval * 7) assert len(outlet.receipts) == 1 assert outlet.timeout_callbacks == 1 assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert not message.tracked def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[rnsh.protocol.Message], None]): rtt = 0.001 retry_interval = rtt * 150 with ProtocolHarness(retry_delay_min=retry_interval) as h: decoded: [rnsh.protocol.Message] = [] def packet(outlet, buffer): decoded.append(h.messenger.receive(buffer)) outlet = MessageOutletTest(mdu=500, rtt=rtt) outlet.set_packet_received_callback(packet) assert len(outlet.receipts) == 0 h.messenger.send(outlet, message) assert message.tracked assert message.raw is not None assert len(outlet.receipts) == 1 receipt = outlet.receipts[0] assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.raw == message.raw module_logger.info("set delivered") receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED time.sleep(retry_interval * 2) assert len(outlet.receipts) == 1 assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED assert not message.tracked module_logger.info("injecting rx message") assert len(decoded) == 0 outlet.receive(message.raw) assert len(decoded) == 1 rx_message = decoded[0] assert rx_message is not None assert isinstance(rx_message, message.__class__) assert rx_message.msgid != message.msgid checker(rx_message) def test_send_receive_streamdata(): message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, data=b'Test', eof=True) def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) assert rx_message.stream_id == message.stream_id assert rx_message.data == message.data assert rx_message.eof == message.eof eat_own_dog_food(message, check) def test_send_receive_noop(): message = rnsh.protocol.NoopMessage() def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) eat_own_dog_food(message, check) def test_send_receive_execute(): message = rnsh.protocol.ExecuteCommandMesssage(cmdline=["test", "one", "two"], pipe_stdin=False, pipe_stdout=True, pipe_stderr=False, tcflags=[12, 34, 56, [78, 90]], term="xtermmmm") def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) assert rx_message.cmdline == message.cmdline assert rx_message.pipe_stdin == message.pipe_stdin assert rx_message.pipe_stdout == message.pipe_stdout assert rx_message.pipe_stderr == message.pipe_stderr assert rx_message.tcflags == message.tcflags assert rx_message.term == message.term eat_own_dog_food(message, check) def test_send_receive_windowsize(): message = rnsh.protocol.WindowSizeMessage(1, 2, 3, 4) def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) assert rx_message.rows == message.rows assert rx_message.cols == message.cols assert rx_message.hpix == message.hpix assert rx_message.vpix == message.vpix eat_own_dog_food(message, check) def test_send_receive_versioninfo(): message = rnsh.protocol.VersionInfoMessage(sw_version="1.2.3") message.protocol_version = 30 def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) assert rx_message.sw_version == message.sw_version assert rx_message.protocol_version == message.protocol_version eat_own_dog_food(message, check) def test_send_receive_error(): message = rnsh.protocol.ErrorMessage(msg="TESTerr", fatal=True, data={"one": 2}) def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) assert rx_message.msg == message.msg assert rx_message.fatal == message.fatal assert rx_message.data == message.data eat_own_dog_food(message, check) def test_send_receive_cmdexit(): message = rnsh.protocol.CommandExitedMessage(5) def check(rx_message: rnsh.protocol.Message): assert isinstance(rx_message, message.__class__) assert rx_message.return_code == message.return_code eat_own_dog_food(message, check)