Got the new protocol working.

This commit is contained in:
Aaron Heise 2023-02-18 00:09:28 -06:00
parent 0ee305795f
commit 8edb4020b1
8 changed files with 876 additions and 887 deletions

8
rnsh/helpers.py Normal file
View File

@ -0,0 +1,8 @@
def bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def check_and(value: int, andval: int) -> bool:
return (value & andval) > 0

View File

@ -382,8 +382,11 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
# Make PTY controlling if necessary # Make PTY controlling if necessary
if child_fd is not None: if child_fd is not None:
os.setsid() os.setsid()
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR) try:
os.close(tmp_fd) tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
os.close(tmp_fd)
except:
pass
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0) # fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
# Execute the command # Execute the command
@ -445,7 +448,8 @@ class CallbackSubprocess:
self._child_stdout: int = None self._child_stdout: int = None
self._child_stderr: int = None self._child_stderr: int = None
self._return_code: int = None self._return_code: int = None
self._eof: bool = False self._stdout_eof: bool = False
self._stderr_eof: bool = False
self._stdin_is_pipe = stdin_is_pipe self._stdin_is_pipe = stdin_is_pipe
self._stdout_is_pipe = stdout_is_pipe self._stdout_is_pipe = stdout_is_pipe
self._stderr_is_pipe = stderr_is_pipe self._stderr_is_pipe = stderr_is_pipe
@ -455,6 +459,21 @@ class CallbackSubprocess:
Terminate child process if running Terminate child process if running
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL :param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
""" """
same = self._child_stdout == self._child_stderr
with contextlib.suppress(OSError):
if not self._stdout_eof:
tty_unset_reader_callbacks(self._child_stdout)
os.close(self._child_stdout)
self._child_stdout = None
if not same:
with contextlib.suppress(OSError):
if not self._stderr_eof:
tty_unset_reader_callbacks(self._child_stderr)
os.close(self._child_stderr)
self._child_stdout = None
self._log.debug("terminate()") self._log.debug("terminate()")
if not self.running: if not self.running:
return return
@ -477,7 +496,7 @@ class CallbackSubprocess:
os.waitpid(self._pid, 0) os.waitpid(self._pid, 0)
self._log.debug("wait() finish") self._log.debug("wait() finish")
threading.Thread(target=wait).start() threading.Thread(target=wait, daemon=True).start()
def close_stdin(self): def close_stdin(self):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@ -592,26 +611,44 @@ class CallbackSubprocess:
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable): def stdout():
try: try:
with exception.permit(SystemExit): with exception.permit(SystemExit):
data = tty_read(fd) data = tty_read_poll(self._child_stdout)
if data is not None and len(data) > 0: if data is not None and len(data) > 0:
callback(data) self._stdout_cb(data)
except EOFError: except EOFError:
self._eof = True self._stdout_eof = True
tty_unset_reader_callbacks(self._child_stdout) tty_unset_reader_callbacks(self._child_stdout)
callback(bytearray()) self._stdout_cb(bytearray())
with contextlib.suppress(OSError):
os.close(self._child_stdout)
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb), def stderr():
self._loop) try:
with exception.permit(SystemExit):
data = tty_read_poll(self._child_stderr)
if data is not None and len(data) > 0:
self._stderr_cb(data)
except EOFError:
self._stderr_eof = True
tty_unset_reader_callbacks(self._child_stderr)
self._stdout_cb(bytearray())
with contextlib.suppress(OSError):
os.close(self._child_stderr)
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
if self._child_stderr != self._child_stdout: if self._child_stderr != self._child_stdout:
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb), tty_add_reader_callback(self._child_stderr, stderr, self._loop)
self._loop)
@property @property
def eof(self): def stdout_eof(self):
return self._eof or not self.running return self._stdout_eof or not self.running
@property
def stderr_eof(self):
return self._stderr_eof or not self.running
@property @property
def return_code(self) -> int: def return_code(self) -> int:

View File

@ -15,6 +15,7 @@ import abc
import contextlib import contextlib
import struct import struct
import logging as __logging import logging as __logging
from abc import ABC, abstractmethod
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -22,13 +23,50 @@ module_logger = __logging.getLogger(__name__)
_TReceipt = TypeVar("_TReceipt") _TReceipt = TypeVar("_TReceipt")
_TLink = TypeVar("_TLink") _TLink = TypeVar("_TLink")
MSG_MAGIC = 0xac MSG_MAGIC = 0xac
PROTOCOL_VERSION=1 PROTOCOL_VERSION = 1
def _make_MSGTYPE(val: int): def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff) return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class MessageOutletBase(ABC):
@abstractmethod
def send(self, raw: bytes) -> _TReceipt:
raise NotImplemented()
@property
@abstractmethod
def mdu(self):
raise NotImplemented()
@property
@abstractmethod
def rtt(self):
raise NotImplemented()
@property
@abstractmethod
def is_usuable(self):
raise NotImplemented()
@abstractmethod
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
raise NotImplemented()
@abstractmethod
def timed_out(self):
raise NotImplemented()
@abstractmethod
def __str__(self):
raise NotImplemented()
@abstractmethod
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
raise NotImplemented()
class METype(enum.IntEnum): class METype(enum.IntEnum):
ME_NO_MSG_TYPE = 0 ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1 ME_INVALID_MSG_TYPE = 1
@ -58,17 +96,17 @@ class Message(abc.ABC):
self.msgid = uuid.uuid4() self.msgid = uuid.uuid4()
self.raw: bytes | None = None self.raw: bytes | None = None
self.receipt: _TReceipt = None self.receipt: _TReceipt = None
self.link: _TLink = None self.outlet: _TLink = None
self.tracked: bool = False self.tracked: bool = False
def __str__(self): def __str__(self):
return f"{self.__class__.__name__} {self.msgid}" return f"{self.__class__.__name__} {self.msgid}"
@abc.abstractmethod @abstractmethod
def pack(self) -> bytes: def pack(self) -> bytes:
raise NotImplemented() raise NotImplemented()
@abc.abstractmethod @abstractmethod
def unpack(self, raw): def unpack(self, raw):
raise NotImplemented() raise NotImplemented()
@ -124,7 +162,8 @@ class ExecuteCommandMesssage(Message):
MSGTYPE = _make_MSGTYPE(3) MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False, def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None): pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
cols: int = None, hpix: int = None, vpix: int = None):
super().__init__() super().__init__()
self.cmdline = cmdline self.cmdline = cmdline
self.pipe_stdin = pipe_stdin self.pipe_stdin = pipe_stdin
@ -132,16 +171,20 @@ class ExecuteCommandMesssage(Message):
self.pipe_stderr = pipe_stderr self.pipe_stderr = pipe_stderr
self.tcflags = tcflags self.tcflags = tcflags
self.term = term self.term = term
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term)) self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw) return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw) raw = self.unwrap_MSGTYPE(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \ self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
= umsgpack.unpackb(raw) self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class StreamDataMessage(Message): class StreamDataMessage(Message):
MSGTYPE = _make_MSGTYPE(4) MSGTYPE = _make_MSGTYPE(4)
@ -156,7 +199,7 @@ class StreamDataMessage(Message):
self.eof = eof self.eof = eof
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.stream_id, self.eof, self.data)) raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
return self.wrap_MSGTYPE(raw) return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
@ -169,7 +212,7 @@ class VersionInfoMessage(Message):
def __init__(self, sw_version: str = None): def __init__(self, sw_version: str = None):
super().__init__() super().__init__()
self.sw_version = sw_version self.sw_version = sw_version or rnsh.__version__
self.protocol_version = PROTOCOL_VERSION self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes: def pack(self) -> bytes:
@ -199,6 +242,22 @@ class ErrorMessage(Message):
self.msg, self.fatal, self.data = umsgpack.unpackb(raw) self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class CommandExitedMessage(Message):
MSGTYPE = _make_MSGTYPE(7)
def __init__(self, return_code: int = None):
super().__init__()
self.return_code = return_code
def pack(self) -> bytes:
raw = umsgpack.packb(self.return_code)
return self.wrap_MSGTYPE(raw)
def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.return_code = umsgpack.unpackb(raw)
class Messenger(contextlib.AbstractContextManager): class Messenger(contextlib.AbstractContextManager):
@staticmethod @staticmethod
@ -208,29 +267,16 @@ class Messenger(contextlib.AbstractContextManager):
subclass_tuples.append((subclass.MSGTYPE, subclass)) subclass_tuples.append((subclass.MSGTYPE, subclass))
return subclass_tuples return subclass_tuples
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState], def __init__(self, retry_delay_min: float = 10.0):
link_timeout_callback: Callable[[_TLink], None],
link_mdu_getter: Callable[[_TLink], int],
link_rtt_getter: Callable[[_TLink], float],
link_usable_getter: Callable[[_TLink], bool],
packet_sender: Callable[[_TLink, bytes], _TReceipt],
retry_delay_min: float = 10.0):
self._log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self._receipt_checker = receipt_checker
self._link_timeout_callback = link_timeout_callback
self._link_mdu_getter = link_mdu_getter
self._link_rtt_getter = link_rtt_getter
self._link_usable_getter = link_usable_getter
self._packet_sender = packet_sender
self._sent_messages: list[Message] = [] self._sent_messages: list[Message] = []
self._lock = threading.RLock() self._lock = threading.RLock()
self._retry_timer = rnsh.retry.RetryThread() self._retry_timer = rnsh.retry.RetryThread()
self._message_factories = dict(self.__class__._get_msg_constructors()) self._message_factories = dict(self.__class__._get_msg_constructors())
self._inbound_queue = queue.Queue()
self._retry_delay_min = retry_delay_min self._retry_delay_min = retry_delay_min
def __enter__(self): def __enter__(self) -> Messenger:
pass return self
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None: __traceback: TracebackType | None) -> bool | None:
@ -238,39 +284,38 @@ class Messenger(contextlib.AbstractContextManager):
return False return False
def shutdown(self): def shutdown(self):
self._run = False
self._retry_timer.close() self._retry_timer.close()
def inbound(self, raw: bytes): def clear_retries(self, outlet):
self._retry_timer.complete(outlet)
def receive(self, raw: bytes) -> Message:
(mid, contents) = Message.static_unwrap_MSGTYPE(raw) (mid, contents) = Message.static_unwrap_MSGTYPE(raw)
ctor = self._message_factories.get(mid, None) ctor = self._message_factories.get(mid, None)
if ctor is None: if ctor is None:
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}") raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
message = ctor() message = ctor()
message.unpack(raw) message.unpack(raw)
self._log.debug("Message received: {message}") self._log.debug(f"Message received: {message}")
self._inbound_queue.put(message) return message
def get_mdu(self, link: _TLink) -> int: def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
return self._link_mdu_getter(link) - 4 if not outlet.is_usuable:
self._log.debug("is_outlet_ready outlet unusable")
def get_rtt(self, link: _TLink) -> float:
return self._link_rtt_getter(link)
def is_link_ready(self, link: _TLink) -> bool:
if not self._link_usable_getter(link):
return False return False
with self._lock: with self._lock:
for message in self._sent_messages: for message in self._sent_messages:
if message.link == link: if message.outlet == outlet and message.tracked and message.receipt \
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
self._log.debug("is_outlet_ready pending message found")
return False return False
return True return True
def send_message(self, link: _TLink, message: Message): def send(self, outlet: MessageOutletBase, message: Message):
with self._lock: with self._lock:
if not self.is_link_ready(link): if not self.is_outlet_ready(outlet):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready") raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
if message in self._sent_messages: if message in self._sent_messages:
raise MessagingException(METype.ME_ALREADY_SENT) raise MessagingException(METype.ME_ALREADY_SENT)
@ -279,43 +324,36 @@ class Messenger(contextlib.AbstractContextManager):
if not message.raw: if not message.raw:
message.raw = message.pack() message.raw = message.pack()
message.link = link message.outlet = outlet
def send(tag: any, tries: int): def send_inner(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt) state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]: if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
try: try:
self._log.debug(f"Sending packet for {message}") self._log.debug(f"Sending packet for {message}")
message.receipt = self._packet_sender(link, message.raw) message.receipt = outlet.send(message.raw)
except Exception as ex: except Exception as ex:
self._log.exception(f"Error sending message {message}") self._log.exception(f"Error sending message {message}")
elif state in [MessageState.MSGSTATE_SENT]: elif state in [MessageState.MSGSTATE_SENT]:
self._log.debug(f"Retry skipped, message still pending {message}") self._log.debug(f"Retry skipped, message still pending {message}")
elif state in [MessageState.MSGSTATE_DELIVERED]: elif state in [MessageState.MSGSTATE_DELIVERED]:
latency = round(time.time() - message.ts, 1) latency = round(time.time() - message.ts, 1)
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds") self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
with self._lock: with self._lock:
self._sent_messages.remove(message) self._sent_messages.remove(message)
message.tracked = False message.tracked = False
self._retry_timer.complete(link) self._retry_timer.complete(outlet)
return link return outlet
def timeout(tag: any, tries: int): def timeout(tag: any, tries: int):
latency = round(time.time() - message.ts, 1) latency = round(time.time() - message.ts, 1)
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout" msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds") self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
with self._lock: with self._lock:
self._sent_messages.remove(message) self._sent_messages.remove(message)
message.tracked = False message.tracked = False
self._link_timeout_callback(link) outlet.timed_out()
rtt = self._link_rtt_getter(link)
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
try:
return self._inbound_queue.get(block=block, timeout=timeout)
except queue.Empty:
return None
rtt = outlet.rtt
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)

View File

@ -79,7 +79,7 @@ class RetryThread(AbstractContextManager):
self._lock = threading.RLock() self._lock = threading.RLock()
self._run = True self._run = True
self._finished: asyncio.Future = None self._finished: asyncio.Future = None
self._thread = threading.Thread(name=name, target=self._thread_run) self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
self._thread.start() self._thread.start()
def is_alive(self): def is_alive(self):

File diff suppressed because it is too large Load Diff

423
rnsh/session.py Normal file
View File

@ -0,0 +1,423 @@
from __future__ import annotations
import contextlib
import functools
import threading
import rnsh.exception as exception
import asyncio
import rnsh.process as process
import rnsh.helpers as helpers
import rnsh.protocol as protocol
import enum
from typing import TypeVar, Generic, Callable, List
from abc import abstractmethod, ABC
from multiprocessing import Manager
import os
import RNS
import logging as __logging
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
module_logger = __logging.getLogger(__name__)
_TLink = TypeVar("_TLink")
class SEType(enum.IntEnum):
SE_LINK_CLOSED = 0
class SessionException(Exception):
def __init__(self, setype: SEType, msg: str, *args):
super().__init__(msg, args)
self.type = setype
class LSState(enum.IntEnum):
LSSTATE_WAIT_IDENT = 1
LSSTATE_WAIT_VERS = 2
LSSTATE_WAIT_CMD = 3
LSSTATE_RUNNING = 4
LSSTATE_ERROR = 5
LSSTATE_TEARDOWN = 6
_TIdentity = TypeVar("_TIdentity")
class LSOutletBase(protocol.MessageOutletBase):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
raise NotImplemented()
@abstractmethod
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
raise NotImplemented()
@abstractmethod
def unset_link_closed_callback(self):
raise NotImplemented()
@abstractmethod
def teardown(self):
raise NotImplemented()
@abstractmethod
def __init__(self):
raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
allowed_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._log.info(f"Session started for {outlet}")
self.outlet = outlet
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
self.state: LSState = None
self.remote_identity = None
self.term: str | None = None
self.stdin_is_pipe: bool = False
self.stdout_is_pipe: bool = False
self.stderr_is_pipe: bool = False
self.tcflags: [any] = None
self.cmdline: [str] = None
self.rows: int = 0
self.cols: int = 0
self.hpix: int = 0
self.vpix: int = 0
self.stdout_buf = bytearray()
self.stdout_eof_sent = False
self.stderr_buf = bytearray()
self.stderr_eof_sent = False
self.return_code: int | None = None
self.return_code_sent = False
self.process: process.CallbackSubprocess | None = None
self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
def _terminated(self, return_code: int):
self.return_code = return_code
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
self._log.debug(f"Set state: {state.name}, timeout {timeout}")
orig_state = self.state
self.state = state
if timeout_factor is not None:
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
def _call(self, func: callable, delay: float = 0):
def call_inner():
# self._log.debug("call_inner")
if delay == 0:
func()
else:
self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: protocol.Message):
self.messenger.send(self.outlet, message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
def _protocol_timeout_error(self, name: str):
self.terminate(f"Protocol timeout error: {name}")
def terminate(self, error: str = None):
with contextlib.suppress(Exception):
self._log.debug("Terminating session" + (f": {error}" if error else ""))
if error and self.state != LSState.LSSTATE_TEARDOWN:
with contextlib.suppress(Exception):
self.send(protocol.ErrorMessage(error, True))
self.state = LSState.LSSTATE_ERROR
self._terminate_process()
self._call(self._prune, max(self.outlet.rtt * 3, 5))
def _prune(self):
self.state = LSState.LSSTATE_TEARDOWN
with contextlib.suppress(ValueError):
self.sessions.remove(self)
with contextlib.suppress(Exception):
self.outlet.teardown()
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
timeout = True
try:
timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
except Exception as ex:
self._log.exception("Error in protocol timeout", ex)
if timeout:
self._protocol_timeout_error(name)
def _link_closed(self, outlet: LSOutletBase):
outlet.unset_link_closed_callback()
if outlet != self.outlet:
self._log.debug("Link closed received from incorrect outlet")
return
self._log.debug(f"link_closed {outlet}")
self.messenger.clear_retries(self.outlet)
self.terminate()
def _initiator_identified(self, outlet, identity):
if outlet != self.outlet:
self._log.debug("Identity received from incorrect outlet")
return
self._log.info(f"initiator_identified {identity} on link {outlet}")
if self.state != LSState.LSSTATE_WAIT_IDENT:
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
self.terminate("Identity is not allowed.")
self.remote_identity = identity
self.outlet.set_packet_received_callback(self._packet_received)
self._set_state(LSState.LSSTATE_WAIT_VERS)
@classmethod
async def pump_all(cls):
for session in cls.sessions:
session.pump()
await asyncio.sleep(0)
@classmethod
async def terminate_all(cls, reason: str):
for session in cls.sessions:
session.terminate(reason)
await asyncio.sleep(0)
def pump(self):
try:
if self.state != LSState.LSSTATE_RUNNING:
return
elif not self.messenger.is_outlet_ready(self.outlet):
return
elif len(self.stderr_buf) > 0:
mdu = self.outlet.mdu - 16
data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
data, send_eof)
self.send(msg)
if send_eof:
self.stderr_eof_sent = True
elif len(self.stdout_buf) > 0:
mdu = self.outlet.mdu - 16
data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
data, send_eof)
self.send(msg)
elif self.return_code is not None and not self.return_code_sent:
msg = protocol.CommandExitedMessage(self.return_code)
self.send(msg)
self.return_code_sent = True
self._call(functools.partial(self._check_protocol_timeout,
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
max(self.outlet.rtt * 5, 10))
except Exception as ex:
self._log.exception("Error during pump", ex)
def _terminate_process(self):
with contextlib.suppress(Exception):
if self.process and self.process.running:
self.process.terminate()
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
self.cmdline = self.default_command
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
self.terminate("Remote command line not allowed by listener")
return
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
self.cmdline.extend(cmdline)
elif cmdline and len(cmdline) > 0:
self.cmdline = cmdline
self.stdin_is_pipe = pipe_stdin
self.stdout_is_pipe = pipe_stdout
self.stderr_is_pipe = pipe_stderr
self.tcflags = tcflags
self.term = term
def stdout(data: bytes):
self.stdout_buf.extend(data)
def stderr(data: bytes):
self.stderr_buf.extend(data)
try:
self.process = process.CallbackSubprocess(argv=self.cmdline,
env={"TERM": self.term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
loop=self.loop,
stdout_callback=stdout,
stderr_callback=stderr,
terminated_callback=self._terminated,
stdin_is_pipe=self.stdin_is_pipe,
stdout_is_pipe=self.stdout_is_pipe,
stderr_is_pipe=self.stderr_is_pipe)
self.process.start()
self._set_window_size(rows, cols, hpix, vpix)
except Exception as ex:
self._log.exception(f"Unable to start process for link {self.outlet}", ex)
self.terminate("Unable to start process")
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
with contextlib.suppress(Exception):
self.process.set_winsize(rows, cols, hpix, vpix)
def _received_stdin(self, data: bytes, eof: bool):
if data and len(data) > 0:
self.process.write(data)
if eof:
self.process.close_stdin()
def _handle_message(self, message: protocol.Message):
if self.state == LSState.LSSTATE_WAIT_VERS:
if not isinstance(message, protocol.VersionInfoMessage):
self._protocol_error(self.state.name)
return
self._log.info(f"version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}")
if message.protocol_version != protocol.PROTOCOL_VERSION:
self.terminate("Incompatible protocol")
return
self.send(protocol.VersionInfoMessage())
self._set_state(LSState.LSSTATE_WAIT_CMD)
return
elif self.state == LSState.LSSTATE_WAIT_CMD:
if not isinstance(message, protocol.ExecuteCommandMesssage):
return self._protocol_error(self.state.name)
self._log.info(f"Execute command message on link {self.outlet}: {message.cmdline}")
self._set_state(LSState.LSSTATE_RUNNING)
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
return
elif self.state == LSState.LSSTATE_RUNNING:
if isinstance(message, protocol.WindowSizeMessage):
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
elif isinstance(message, protocol.StreamDataMessage):
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
self._log.error(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}")
return self._protocol_error(self.state.name)
self._received_stdin(message.data, message.eof)
return
elif isinstance(message, protocol.NoopMessage):
# echo noop only on listener--used for keepalive/connectivity check
self.send(message)
return
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
self._log.error(f"Received packet, but in state {self.state.name}")
return
else:
self._protocol_error("unexpected message")
return
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
if outlet != self.outlet:
self._log.debug("Packet received from incorrect outlet")
return
try:
message = self.messenger.receive(raw)
self._handle_message(message)
except Exception as ex:
self._protocol_error("unusable packet")
class RNSOutlet(LSOutletBase):
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
def inner_cb(link, identity: _TIdentity):
cb(self, identity)
self.link.set_remote_identified_callback(inner_cb)
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
def inner_cb(link):
cb(self)
self.link.set_link_closed_callback(inner_cb)
def unset_link_closed_callback(self):
self.link.set_link_closed_callback(None)
def teardown(self):
self.link.teardown()
def send(self, raw: bytes) -> RNS.PacketReceipt:
packet = RNS.Packet(self.link, raw)
packet.send()
return packet.receipt
@property
def mdu(self) -> int:
return self.link.MDU
@property
def rtt(self) -> float:
return self.link.rtt
@property
def is_usuable(self):
return True #self.link.status in [RNS.Link.ACTIVE]
def get_receipt_state(self, receipt: RNS.PacketReceipt) -> MessageState:
status = receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return protocol.MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return protocol.MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return protocol.MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self):
return f"Outlet RNS Link {self.link}"
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
def inner_cb(message, packet: RNS.Packet):
packet.prove()
cb(self, message)
self.link.set_packet_callback(inner_cb)
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
link.msgoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):
return link.lsoutlet
return RNSOutlet(link)

View File

@ -1,6 +1,10 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from rnsh.protocol import _TReceipt, MessageState
from typing import Callable
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
import rnsh.protocol import rnsh.protocol
@ -14,64 +18,60 @@ import uuid
module_logger = logging.getLogger(__name__) module_logger = logging.getLogger(__name__)
class Link: class Receipt:
def __init__(self, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
class MessageOutletTest(rnsh.protocol.MessageOutletBase):
def __init__(self, mdu: int, rtt: float): def __init__(self, mdu: int, rtt: float):
self.link_id = uuid.uuid4() self.link_id = uuid.uuid4()
self.timeout_callbacks = 0 self.timeout_callbacks = 0
self.mdu = mdu self._mdu = mdu
self.rtt = rtt self._rtt = rtt
self.usable = True self._usable = True
self.receipts = [] self.receipts = []
self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None
def timeout_callback(self): def send(self, raw: bytes) -> Receipt:
receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
self.receipts.append(receipt)
return receipt
def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]):
self.packet_callback = cb
def receive(self, raw: bytes):
if self.packet_callback:
self.packet_callback(self, raw)
@property
def mdu(self):
return self._mdu
@property
def rtt(self):
return self._rtt
@property
def is_usuable(self):
return self._usable
def get_receipt_state(self, receipt: Receipt) -> MessageState:
return receipt.state
def timed_out(self):
self.timeout_callbacks += 1 self.timeout_callbacks += 1
def __str__(self): def __str__(self):
return str(self.link_id) return str(self.link_id)
class Receipt:
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
self.link = link
class ProtocolHarness(contextlib.AbstractContextManager): class ProtocolHarness(contextlib.AbstractContextManager):
def __init__(self, retry_delay_min: float = 1): def __init__(self, retry_delay_min: float = 1):
self._log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker, self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min)
link_timeout_callback=self.link_timeout_callback,
link_mdu_getter=self.link_mdu_getter,
link_rtt_getter=self.link_rtt_getter,
link_usable_getter=self.link_usable_getter,
packet_sender=self.packet_sender,
retry_delay_min=retry_delay_min)
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
link.receipts.append(receipt)
return receipt
@staticmethod
def link_mdu_getter(link: Link):
return link.mdu
@staticmethod
def link_rtt_getter(link: Link):
return link.rtt
@staticmethod
def link_usable_getter(link: Link):
return link.usable
@staticmethod
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
return receipt.state
@staticmethod
def link_timeout_callback(link: Link):
link.timeout_callback()
def cleanup(self): def cleanup(self):
self.messenger.shutdown() self.messenger.shutdown()
@ -83,49 +83,58 @@ class ProtocolHarness(contextlib.AbstractContextManager):
return False return False
def test_mdu():
with ProtocolHarness() as h:
mdu = 500
link = Link(mdu=mdu, rtt=0.25)
assert h.messenger.get_mdu(link) == mdu - 4
link.mdu = mdu = 600
assert h.messenger.get_mdu(link) == mdu - 4
def test_rtt():
with ProtocolHarness() as h:
rtt = 0.25
link = Link(mdu=500, rtt=rtt)
assert h.messenger.get_rtt(link) == rtt
def test_send_one_retry(): def test_send_one_retry():
rtt = 0.001 rtt = 0.001
retry_interval = rtt * 150 retry_interval = rtt * 150
message_content = b'Test' message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h: with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt) outlet = MessageOutletTest(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True) data=message_content, eof=True)
assert len(link.receipts) == 0 assert len(outlet.receipts) == 0
h.messenger.send_message(link, message) h.messenger.send(outlet, message)
assert message.tracked assert message.tracked
assert message.raw is not None assert message.raw is not None
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
receipt = link.receipts[0] receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5) time.sleep(retry_interval * 1.5)
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
module_logger.info("set failed") module_logger.info("set failed")
time.sleep(retry_interval) time.sleep(retry_interval)
assert len(link.receipts) == 2 assert len(outlet.receipts) == 2
receipt = link.receipts[1] receipt = outlet.receipts[1]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval) time.sleep(retry_interval)
assert len(link.receipts) == 2 assert len(outlet.receipts) == 2
assert not message.tracked
def test_send_timeout():
rtt = 0.001
retry_interval = rtt * 150
message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h:
outlet = MessageOutletTest(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked
assert message.raw is not None
assert len(outlet.receipts) == 1
receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5)
assert outlet.timeout_callbacks == 0
time.sleep(retry_interval * 7)
assert len(outlet.receipts) == 1
assert outlet.timeout_callbacks == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert not message.tracked assert not message.tracked
@ -133,24 +142,32 @@ def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[r
rtt = 0.001 rtt = 0.001
retry_interval = rtt * 150 retry_interval = rtt * 150
with ProtocolHarness(retry_delay_min=retry_interval) as h: with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt)
assert len(link.receipts) == 0 decoded: [rnsh.protocol.Message] = []
h.messenger.send_message(link, message) def packet(outlet, buffer):
decoded.append(h.messenger.receive(buffer))
outlet = MessageOutletTest(mdu=500, rtt=rtt)
outlet.set_packet_received_callback(packet)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked assert message.tracked
assert message.raw is not None assert message.raw is not None
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
receipt = link.receipts[0] receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw assert receipt.raw == message.raw
module_logger.info("set delivered") module_logger.info("set delivered")
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval * 2) time.sleep(retry_interval * 2)
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
assert not message.tracked assert not message.tracked
module_logger.info("injecting rx message") module_logger.info("injecting rx message")
h.messenger.inbound(message.raw) assert len(decoded) == 0
rx_message = h.messenger.poll_inbound(block=False) outlet.receive(message.raw)
assert len(decoded) == 1
rx_message = decoded[0]
assert rx_message is not None assert rx_message is not None
assert isinstance(rx_message, message.__class__) assert isinstance(rx_message, message.__class__)
assert rx_message.msgid != message.msgid assert rx_message.msgid != message.msgid
@ -238,6 +255,16 @@ def test_send_receive_error():
eat_own_dog_food(message, check) eat_own_dog_food(message, check)
def test_send_receive_cmdexit():
message = rnsh.protocol.CommandExitedMessage(5)
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.return_code == message.return_code
eat_own_dog_food(message, check)

View File

@ -12,19 +12,6 @@ import re
import os import os
def test_check_magic():
magic = rnsh.rnsh._PROTOCOL_VERSION_0
# magic for version 0 is generated, make sure it comes out as expected
assert magic == 0xdeadbeef00000000
# verify the checker thinks it's right
assert rnsh.rnsh._protocol_check_magic(magic)
# scramble the magic
magic = magic | 0x00ffff0000000000
# make sure it fails now
assert not rnsh.rnsh._protocol_check_magic(magic)
def test_version(): def test_version():
# version = importlib.metadata.version(rnsh.__version__) # version = importlib.metadata.version(rnsh.__version__)
assert rnsh.__version__ != "0.0.0" assert rnsh.__version__ != "0.0.0"