Added a new message-based, packet-ready protocol layer with test suite for #4

This commit is contained in:
Aaron Heise 2023-02-17 10:48:16 -06:00
parent aa37e4e3da
commit 0ee305795f
3 changed files with 568 additions and 0 deletions

View File

@ -1,5 +1,7 @@
import contextlib
from contextlib import AbstractContextManager
import logging
import sys
class permit(AbstractContextManager):
@ -24,3 +26,5 @@ class permit(AbstractContextManager):
def __exit__(self, exctype, excinst, exctb):
return exctype is not None and not issubclass(exctype, self._exceptions)

321
rnsh/protocol.py Normal file
View File

@ -0,0 +1,321 @@
from __future__ import annotations
import enum
import queue
import threading
import time
import typing
import uuid
from types import TracebackType
from typing import Type, Callable, TypeVar, Tuple
import RNS
from RNS.vendor import umsgpack
import rnsh.retry
import abc
import contextlib
import struct
import logging as __logging
module_logger = __logging.getLogger(__name__)
_TReceipt = TypeVar("_TReceipt")
_TLink = TypeVar("_TLink")
MSG_MAGIC = 0xac
PROTOCOL_VERSION=1
def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class METype(enum.IntEnum):
ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1
ME_NOT_REGISTERED = 2
ME_LINK_NOT_READY = 3
ME_ALREADY_SENT = 4
class MessagingException(Exception):
def __init__(self, type: METype, *args):
super().__init__(args)
self.type = type
class MessageState(enum.IntEnum):
MSGSTATE_NEW = 0
MSGSTATE_SENT = 1
MSGSTATE_DELIVERED = 2
MSGSTATE_FAILED = 3
class Message(abc.ABC):
MSGTYPE = None
def __init__(self):
self.ts = time.time()
self.msgid = uuid.uuid4()
self.raw: bytes | None = None
self.receipt: _TReceipt = None
self.link: _TLink = None
self.tracked: bool = False
def __str__(self):
return f"{self.__class__.__name__} {self.msgid}"
@abc.abstractmethod
def pack(self) -> bytes:
raise NotImplemented()
@abc.abstractmethod
def unpack(self, raw):
raise NotImplemented()
def unwrap_MSGTYPE(self, raw: bytes) -> bytes:
if self.MSGTYPE is None:
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
mid, raw = self.static_unwrap_MSGTYPE(raw)
if mid != self.MSGTYPE:
raise MessagingException(METype.ME_INVALID_MSG_TYPE,
f"invalid msg id, expected {hex(self.MSGTYPE)} got {hex(mid)}")
return raw
def wrap_MSGTYPE(self, raw: bytes) -> bytes:
if self.__class__.MSGTYPE is None:
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
return struct.pack(">H", self.MSGTYPE) + raw
@staticmethod
def static_unwrap_MSGTYPE(raw: bytes) -> (int, bytes):
return struct.unpack(">H", raw[:2])[0], raw[2:]
class NoopMessage(Message):
MSGTYPE = _make_MSGTYPE(0)
def pack(self) -> bytes:
return self.wrap_MSGTYPE(bytes())
def unpack(self, raw):
self.unwrap_MSGTYPE(raw)
class WindowSizeMessage(Message):
MSGTYPE = _make_MSGTYPE(2)
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
super().__init__()
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
def pack(self) -> bytes:
raw = umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class ExecuteCommandMesssage(Message):
MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None):
super().__init__()
self.cmdline = cmdline
self.pipe_stdin = pipe_stdin
self.pipe_stdout = pipe_stdout
self.pipe_stderr = pipe_stderr
self.tcflags = tcflags
self.term = term
def pack(self) -> bytes:
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \
= umsgpack.unpackb(raw)
class StreamDataMessage(Message):
MSGTYPE = _make_MSGTYPE(4)
STREAM_ID_STDIN = 0
STREAM_ID_STDOUT = 1
STREAM_ID_STDERR = 2
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
super().__init__()
self.stream_id = stream_id
self.data = data
self.eof = eof
def pack(self) -> bytes:
raw = umsgpack.packb((self.stream_id, self.eof, self.data))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
class VersionInfoMessage(Message):
MSGTYPE = _make_MSGTYPE(5)
def __init__(self, sw_version: str = None):
super().__init__()
self.sw_version = sw_version
self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes:
raw = umsgpack.packb((self.sw_version, self.protocol_version))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
class ErrorMessage(Message):
MSGTYPE = _make_MSGTYPE(6)
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
super().__init__()
self.msg = msg
self.fatal = fatal
self.data = data
def pack(self) -> bytes:
raw = umsgpack.packb((self.msg, self.fatal, self.data))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class Messenger(contextlib.AbstractContextManager):
@staticmethod
def _get_msg_constructors() -> (int, Type[Message]):
subclass_tuples = []
for subclass in Message.__subclasses__():
subclass_tuples.append((subclass.MSGTYPE, subclass))
return subclass_tuples
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState],
link_timeout_callback: Callable[[_TLink], None],
link_mdu_getter: Callable[[_TLink], int],
link_rtt_getter: Callable[[_TLink], float],
link_usable_getter: Callable[[_TLink], bool],
packet_sender: Callable[[_TLink, bytes], _TReceipt],
retry_delay_min: float = 10.0):
self._log = module_logger.getChild(self.__class__.__name__)
self._receipt_checker = receipt_checker
self._link_timeout_callback = link_timeout_callback
self._link_mdu_getter = link_mdu_getter
self._link_rtt_getter = link_rtt_getter
self._link_usable_getter = link_usable_getter
self._packet_sender = packet_sender
self._sent_messages: list[Message] = []
self._lock = threading.RLock()
self._retry_timer = rnsh.retry.RetryThread()
self._message_factories = dict(self.__class__._get_msg_constructors())
self._inbound_queue = queue.Queue()
self._retry_delay_min = retry_delay_min
def __enter__(self):
pass
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None:
self.shutdown()
return False
def shutdown(self):
self._run = False
self._retry_timer.close()
def inbound(self, raw: bytes):
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
ctor = self._message_factories.get(mid, None)
if ctor is None:
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
message = ctor()
message.unpack(raw)
self._log.debug("Message received: {message}")
self._inbound_queue.put(message)
def get_mdu(self, link: _TLink) -> int:
return self._link_mdu_getter(link) - 4
def get_rtt(self, link: _TLink) -> float:
return self._link_rtt_getter(link)
def is_link_ready(self, link: _TLink) -> bool:
if not self._link_usable_getter(link):
return False
with self._lock:
for message in self._sent_messages:
if message.link == link:
return False
return True
def send_message(self, link: _TLink, message: Message):
with self._lock:
if not self.is_link_ready(link):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready")
if message in self._sent_messages:
raise MessagingException(METype.ME_ALREADY_SENT)
self._sent_messages.append(message)
message.tracked = True
if not message.raw:
message.raw = message.pack()
message.link = link
def send(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt)
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
try:
self._log.debug(f"Sending packet for {message}")
message.receipt = self._packet_sender(link, message.raw)
except Exception as ex:
self._log.exception(f"Error sending message {message}")
elif state in [MessageState.MSGSTATE_SENT]:
self._log.debug(f"Retry skipped, message still pending {message}")
elif state in [MessageState.MSGSTATE_DELIVERED]:
latency = round(time.time() - message.ts, 1)
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
self._retry_timer.complete(link)
return link
def timeout(tag: any, tries: int):
latency = round(time.time() - message.ts, 1)
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
self._link_timeout_callback(link)
rtt = self._link_rtt_getter(link)
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
try:
return self._inbound_queue.get(block=block, timeout=timeout)
except queue.Empty:
return None

243
tests/test_protocol.py Normal file
View File

@ -0,0 +1,243 @@
from __future__ import annotations
import logging
logging.getLogger().setLevel(logging.DEBUG)
import rnsh.protocol
import contextlib
import typing
import types
import time
import uuid
module_logger = logging.getLogger(__name__)
class Link:
def __init__(self, mdu: int, rtt: float):
self.link_id = uuid.uuid4()
self.timeout_callbacks = 0
self.mdu = mdu
self.rtt = rtt
self.usable = True
self.receipts = []
def timeout_callback(self):
self.timeout_callbacks += 1
def __str__(self):
return str(self.link_id)
class Receipt:
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
self.link = link
class ProtocolHarness(contextlib.AbstractContextManager):
def __init__(self, retry_delay_min: float = 1):
self._log = module_logger.getChild(self.__class__.__name__)
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker,
link_timeout_callback=self.link_timeout_callback,
link_mdu_getter=self.link_mdu_getter,
link_rtt_getter=self.link_rtt_getter,
link_usable_getter=self.link_usable_getter,
packet_sender=self.packet_sender,
retry_delay_min=retry_delay_min)
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
link.receipts.append(receipt)
return receipt
@staticmethod
def link_mdu_getter(link: Link):
return link.mdu
@staticmethod
def link_rtt_getter(link: Link):
return link.rtt
@staticmethod
def link_usable_getter(link: Link):
return link.usable
@staticmethod
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
return receipt.state
@staticmethod
def link_timeout_callback(link: Link):
link.timeout_callback()
def cleanup(self):
self.messenger.shutdown()
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
# self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})")
self.cleanup()
return False
def test_mdu():
with ProtocolHarness() as h:
mdu = 500
link = Link(mdu=mdu, rtt=0.25)
assert h.messenger.get_mdu(link) == mdu - 4
link.mdu = mdu = 600
assert h.messenger.get_mdu(link) == mdu - 4
def test_rtt():
with ProtocolHarness() as h:
rtt = 0.25
link = Link(mdu=500, rtt=rtt)
assert h.messenger.get_rtt(link) == rtt
def test_send_one_retry():
rtt = 0.001
retry_interval = rtt * 150
message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True)
assert len(link.receipts) == 0
h.messenger.send_message(link, message)
assert message.tracked
assert message.raw is not None
assert len(link.receipts) == 1
receipt = link.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5)
assert len(link.receipts) == 1
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
module_logger.info("set failed")
time.sleep(retry_interval)
assert len(link.receipts) == 2
receipt = link.receipts[1]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval)
assert len(link.receipts) == 2
assert not message.tracked
def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[rnsh.protocol.Message], None]):
rtt = 0.001
retry_interval = rtt * 150
with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt)
assert len(link.receipts) == 0
h.messenger.send_message(link, message)
assert message.tracked
assert message.raw is not None
assert len(link.receipts) == 1
receipt = link.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
module_logger.info("set delivered")
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval * 2)
assert len(link.receipts) == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
assert not message.tracked
module_logger.info("injecting rx message")
h.messenger.inbound(message.raw)
rx_message = h.messenger.poll_inbound(block=False)
assert rx_message is not None
assert isinstance(rx_message, message.__class__)
assert rx_message.msgid != message.msgid
checker(rx_message)
def test_send_receive_streamdata():
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=b'Test', eof=True)
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.stream_id == message.stream_id
assert rx_message.data == message.data
assert rx_message.eof == message.eof
eat_own_dog_food(message, check)
def test_send_receive_noop():
message = rnsh.protocol.NoopMessage()
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
eat_own_dog_food(message, check)
def test_send_receive_execute():
message = rnsh.protocol.ExecuteCommandMesssage(cmdline=["test", "one", "two"],
pipe_stdin=False,
pipe_stdout=True,
pipe_stderr=False,
tcflags=[12, 34, 56, [78, 90]],
term="xtermmmm")
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.cmdline == message.cmdline
assert rx_message.pipe_stdin == message.pipe_stdin
assert rx_message.pipe_stdout == message.pipe_stdout
assert rx_message.pipe_stderr == message.pipe_stderr
assert rx_message.tcflags == message.tcflags
assert rx_message.term == message.term
eat_own_dog_food(message, check)
def test_send_receive_windowsize():
message = rnsh.protocol.WindowSizeMessage(1, 2, 3, 4)
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.rows == message.rows
assert rx_message.cols == message.cols
assert rx_message.hpix == message.hpix
assert rx_message.vpix == message.vpix
eat_own_dog_food(message, check)
def test_send_receive_versioninfo():
message = rnsh.protocol.VersionInfoMessage(sw_version="1.2.3")
message.protocol_version = 30
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.sw_version == message.sw_version
assert rx_message.protocol_version == message.protocol_version
eat_own_dog_food(message, check)
def test_send_receive_error():
message = rnsh.protocol.ErrorMessage(msg="TESTerr",
fatal=True,
data={"one": 2})
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.msg == message.msg
assert rx_message.fatal == message.fatal
assert rx_message.data == message.data
eat_own_dog_food(message, check)