mirror of
https://github.com/markqvist/rnsh.git
synced 2024-10-01 01:15:37 -04:00
Got the new protocol working.
This commit is contained in:
parent
0ee305795f
commit
8edb4020b1
8
rnsh/helpers.py
Normal file
8
rnsh/helpers.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
def bitwise_or_if(value: int, condition: bool, orval: int):
|
||||||
|
if not condition:
|
||||||
|
return value
|
||||||
|
return value | orval
|
||||||
|
|
||||||
|
|
||||||
|
def check_and(value: int, andval: int) -> bool:
|
||||||
|
return (value & andval) > 0
|
@ -382,8 +382,11 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
|
|||||||
# Make PTY controlling if necessary
|
# Make PTY controlling if necessary
|
||||||
if child_fd is not None:
|
if child_fd is not None:
|
||||||
os.setsid()
|
os.setsid()
|
||||||
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
|
try:
|
||||||
os.close(tmp_fd)
|
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
|
||||||
|
os.close(tmp_fd)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
|
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
|
||||||
|
|
||||||
# Execute the command
|
# Execute the command
|
||||||
@ -445,7 +448,8 @@ class CallbackSubprocess:
|
|||||||
self._child_stdout: int = None
|
self._child_stdout: int = None
|
||||||
self._child_stderr: int = None
|
self._child_stderr: int = None
|
||||||
self._return_code: int = None
|
self._return_code: int = None
|
||||||
self._eof: bool = False
|
self._stdout_eof: bool = False
|
||||||
|
self._stderr_eof: bool = False
|
||||||
self._stdin_is_pipe = stdin_is_pipe
|
self._stdin_is_pipe = stdin_is_pipe
|
||||||
self._stdout_is_pipe = stdout_is_pipe
|
self._stdout_is_pipe = stdout_is_pipe
|
||||||
self._stderr_is_pipe = stderr_is_pipe
|
self._stderr_is_pipe = stderr_is_pipe
|
||||||
@ -455,6 +459,21 @@ class CallbackSubprocess:
|
|||||||
Terminate child process if running
|
Terminate child process if running
|
||||||
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
same = self._child_stdout == self._child_stderr
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
if not self._stdout_eof:
|
||||||
|
tty_unset_reader_callbacks(self._child_stdout)
|
||||||
|
os.close(self._child_stdout)
|
||||||
|
self._child_stdout = None
|
||||||
|
|
||||||
|
if not same:
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
if not self._stderr_eof:
|
||||||
|
tty_unset_reader_callbacks(self._child_stderr)
|
||||||
|
os.close(self._child_stderr)
|
||||||
|
self._child_stdout = None
|
||||||
|
|
||||||
self._log.debug("terminate()")
|
self._log.debug("terminate()")
|
||||||
if not self.running:
|
if not self.running:
|
||||||
return
|
return
|
||||||
@ -477,7 +496,7 @@ class CallbackSubprocess:
|
|||||||
os.waitpid(self._pid, 0)
|
os.waitpid(self._pid, 0)
|
||||||
self._log.debug("wait() finish")
|
self._log.debug("wait() finish")
|
||||||
|
|
||||||
threading.Thread(target=wait).start()
|
threading.Thread(target=wait, daemon=True).start()
|
||||||
|
|
||||||
def close_stdin(self):
|
def close_stdin(self):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
@ -592,26 +611,44 @@ class CallbackSubprocess:
|
|||||||
|
|
||||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||||
|
|
||||||
def reader(fd: int, callback: callable):
|
def stdout():
|
||||||
try:
|
try:
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit):
|
||||||
data = tty_read(fd)
|
data = tty_read_poll(self._child_stdout)
|
||||||
if data is not None and len(data) > 0:
|
if data is not None and len(data) > 0:
|
||||||
callback(data)
|
self._stdout_cb(data)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
self._eof = True
|
self._stdout_eof = True
|
||||||
tty_unset_reader_callbacks(self._child_stdout)
|
tty_unset_reader_callbacks(self._child_stdout)
|
||||||
callback(bytearray())
|
self._stdout_cb(bytearray())
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
os.close(self._child_stdout)
|
||||||
|
|
||||||
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb),
|
def stderr():
|
||||||
self._loop)
|
try:
|
||||||
|
with exception.permit(SystemExit):
|
||||||
|
data = tty_read_poll(self._child_stderr)
|
||||||
|
if data is not None and len(data) > 0:
|
||||||
|
self._stderr_cb(data)
|
||||||
|
except EOFError:
|
||||||
|
self._stderr_eof = True
|
||||||
|
tty_unset_reader_callbacks(self._child_stderr)
|
||||||
|
self._stdout_cb(bytearray())
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
os.close(self._child_stderr)
|
||||||
|
|
||||||
|
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
|
||||||
if self._child_stderr != self._child_stdout:
|
if self._child_stderr != self._child_stdout:
|
||||||
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb),
|
tty_add_reader_callback(self._child_stderr, stderr, self._loop)
|
||||||
self._loop)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eof(self):
|
def stdout_eof(self):
|
||||||
return self._eof or not self.running
|
return self._stdout_eof or not self.running
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stderr_eof(self):
|
||||||
|
return self._stderr_eof or not self.running
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def return_code(self) -> int:
|
def return_code(self) -> int:
|
||||||
|
158
rnsh/protocol.py
158
rnsh/protocol.py
@ -15,6 +15,7 @@ import abc
|
|||||||
import contextlib
|
import contextlib
|
||||||
import struct
|
import struct
|
||||||
import logging as __logging
|
import logging as __logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -22,13 +23,50 @@ module_logger = __logging.getLogger(__name__)
|
|||||||
_TReceipt = TypeVar("_TReceipt")
|
_TReceipt = TypeVar("_TReceipt")
|
||||||
_TLink = TypeVar("_TLink")
|
_TLink = TypeVar("_TLink")
|
||||||
MSG_MAGIC = 0xac
|
MSG_MAGIC = 0xac
|
||||||
PROTOCOL_VERSION=1
|
PROTOCOL_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
def _make_MSGTYPE(val: int):
|
def _make_MSGTYPE(val: int):
|
||||||
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageOutletBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def send(self, raw: bytes) -> _TReceipt:
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def mdu(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def rtt(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def is_usuable(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def timed_out(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __str__(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
|
||||||
class METype(enum.IntEnum):
|
class METype(enum.IntEnum):
|
||||||
ME_NO_MSG_TYPE = 0
|
ME_NO_MSG_TYPE = 0
|
||||||
ME_INVALID_MSG_TYPE = 1
|
ME_INVALID_MSG_TYPE = 1
|
||||||
@ -58,17 +96,17 @@ class Message(abc.ABC):
|
|||||||
self.msgid = uuid.uuid4()
|
self.msgid = uuid.uuid4()
|
||||||
self.raw: bytes | None = None
|
self.raw: bytes | None = None
|
||||||
self.receipt: _TReceipt = None
|
self.receipt: _TReceipt = None
|
||||||
self.link: _TLink = None
|
self.outlet: _TLink = None
|
||||||
self.tracked: bool = False
|
self.tracked: bool = False
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"{self.__class__.__name__} {self.msgid}"
|
return f"{self.__class__.__name__} {self.msgid}"
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abstractmethod
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abstractmethod
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
@ -124,7 +162,8 @@ class ExecuteCommandMesssage(Message):
|
|||||||
MSGTYPE = _make_MSGTYPE(3)
|
MSGTYPE = _make_MSGTYPE(3)
|
||||||
|
|
||||||
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
||||||
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None):
|
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
|
||||||
|
cols: int = None, hpix: int = None, vpix: int = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cmdline = cmdline
|
self.cmdline = cmdline
|
||||||
self.pipe_stdin = pipe_stdin
|
self.pipe_stdin = pipe_stdin
|
||||||
@ -132,16 +171,20 @@ class ExecuteCommandMesssage(Message):
|
|||||||
self.pipe_stderr = pipe_stderr
|
self.pipe_stderr = pipe_stderr
|
||||||
self.tcflags = tcflags
|
self.tcflags = tcflags
|
||||||
self.term = term
|
self.term = term
|
||||||
|
self.rows = rows
|
||||||
|
self.cols = cols
|
||||||
|
self.hpix = hpix
|
||||||
|
self.vpix = vpix
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||||
self.tcflags, self.term))
|
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||||
return self.wrap_MSGTYPE(raw)
|
return self.wrap_MSGTYPE(raw)
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
raw = self.unwrap_MSGTYPE(raw)
|
||||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \
|
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
||||||
= umsgpack.unpackb(raw)
|
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
class StreamDataMessage(Message):
|
class StreamDataMessage(Message):
|
||||||
MSGTYPE = _make_MSGTYPE(4)
|
MSGTYPE = _make_MSGTYPE(4)
|
||||||
@ -156,7 +199,7 @@ class StreamDataMessage(Message):
|
|||||||
self.eof = eof
|
self.eof = eof
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.stream_id, self.eof, self.data))
|
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
||||||
return self.wrap_MSGTYPE(raw)
|
return self.wrap_MSGTYPE(raw)
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
@ -169,7 +212,7 @@ class VersionInfoMessage(Message):
|
|||||||
|
|
||||||
def __init__(self, sw_version: str = None):
|
def __init__(self, sw_version: str = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.sw_version = sw_version
|
self.sw_version = sw_version or rnsh.__version__
|
||||||
self.protocol_version = PROTOCOL_VERSION
|
self.protocol_version = PROTOCOL_VERSION
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
@ -199,6 +242,22 @@ class ErrorMessage(Message):
|
|||||||
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandExitedMessage(Message):
|
||||||
|
MSGTYPE = _make_MSGTYPE(7)
|
||||||
|
|
||||||
|
def __init__(self, return_code: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.return_code = return_code
|
||||||
|
|
||||||
|
def pack(self) -> bytes:
|
||||||
|
raw = umsgpack.packb(self.return_code)
|
||||||
|
return self.wrap_MSGTYPE(raw)
|
||||||
|
|
||||||
|
def unpack(self, raw: bytes):
|
||||||
|
raw = self.unwrap_MSGTYPE(raw)
|
||||||
|
self.return_code = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class Messenger(contextlib.AbstractContextManager):
|
class Messenger(contextlib.AbstractContextManager):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -208,29 +267,16 @@ class Messenger(contextlib.AbstractContextManager):
|
|||||||
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
||||||
return subclass_tuples
|
return subclass_tuples
|
||||||
|
|
||||||
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState],
|
def __init__(self, retry_delay_min: float = 10.0):
|
||||||
link_timeout_callback: Callable[[_TLink], None],
|
|
||||||
link_mdu_getter: Callable[[_TLink], int],
|
|
||||||
link_rtt_getter: Callable[[_TLink], float],
|
|
||||||
link_usable_getter: Callable[[_TLink], bool],
|
|
||||||
packet_sender: Callable[[_TLink, bytes], _TReceipt],
|
|
||||||
retry_delay_min: float = 10.0):
|
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self._receipt_checker = receipt_checker
|
|
||||||
self._link_timeout_callback = link_timeout_callback
|
|
||||||
self._link_mdu_getter = link_mdu_getter
|
|
||||||
self._link_rtt_getter = link_rtt_getter
|
|
||||||
self._link_usable_getter = link_usable_getter
|
|
||||||
self._packet_sender = packet_sender
|
|
||||||
self._sent_messages: list[Message] = []
|
self._sent_messages: list[Message] = []
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._retry_timer = rnsh.retry.RetryThread()
|
self._retry_timer = rnsh.retry.RetryThread()
|
||||||
self._message_factories = dict(self.__class__._get_msg_constructors())
|
self._message_factories = dict(self.__class__._get_msg_constructors())
|
||||||
self._inbound_queue = queue.Queue()
|
|
||||||
self._retry_delay_min = retry_delay_min
|
self._retry_delay_min = retry_delay_min
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> Messenger:
|
||||||
pass
|
return self
|
||||||
|
|
||||||
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
||||||
__traceback: TracebackType | None) -> bool | None:
|
__traceback: TracebackType | None) -> bool | None:
|
||||||
@ -238,39 +284,38 @@ class Messenger(contextlib.AbstractContextManager):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self._run = False
|
|
||||||
self._retry_timer.close()
|
self._retry_timer.close()
|
||||||
|
|
||||||
def inbound(self, raw: bytes):
|
def clear_retries(self, outlet):
|
||||||
|
self._retry_timer.complete(outlet)
|
||||||
|
|
||||||
|
def receive(self, raw: bytes) -> Message:
|
||||||
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
||||||
ctor = self._message_factories.get(mid, None)
|
ctor = self._message_factories.get(mid, None)
|
||||||
if ctor is None:
|
if ctor is None:
|
||||||
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
||||||
message = ctor()
|
message = ctor()
|
||||||
message.unpack(raw)
|
message.unpack(raw)
|
||||||
self._log.debug("Message received: {message}")
|
self._log.debug(f"Message received: {message}")
|
||||||
self._inbound_queue.put(message)
|
return message
|
||||||
|
|
||||||
def get_mdu(self, link: _TLink) -> int:
|
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
|
||||||
return self._link_mdu_getter(link) - 4
|
if not outlet.is_usuable:
|
||||||
|
self._log.debug("is_outlet_ready outlet unusable")
|
||||||
def get_rtt(self, link: _TLink) -> float:
|
|
||||||
return self._link_rtt_getter(link)
|
|
||||||
|
|
||||||
def is_link_ready(self, link: _TLink) -> bool:
|
|
||||||
if not self._link_usable_getter(link):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for message in self._sent_messages:
|
for message in self._sent_messages:
|
||||||
if message.link == link:
|
if message.outlet == outlet and message.tracked and message.receipt \
|
||||||
|
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
|
||||||
|
self._log.debug("is_outlet_ready pending message found")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def send_message(self, link: _TLink, message: Message):
|
def send(self, outlet: MessageOutletBase, message: Message):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if not self.is_link_ready(link):
|
if not self.is_outlet_ready(outlet):
|
||||||
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready")
|
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
|
||||||
|
|
||||||
if message in self._sent_messages:
|
if message in self._sent_messages:
|
||||||
raise MessagingException(METype.ME_ALREADY_SENT)
|
raise MessagingException(METype.ME_ALREADY_SENT)
|
||||||
@ -279,43 +324,36 @@ class Messenger(contextlib.AbstractContextManager):
|
|||||||
|
|
||||||
if not message.raw:
|
if not message.raw:
|
||||||
message.raw = message.pack()
|
message.raw = message.pack()
|
||||||
message.link = link
|
message.outlet = outlet
|
||||||
|
|
||||||
def send(tag: any, tries: int):
|
def send_inner(tag: any, tries: int):
|
||||||
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt)
|
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
|
||||||
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
||||||
try:
|
try:
|
||||||
self._log.debug(f"Sending packet for {message}")
|
self._log.debug(f"Sending packet for {message}")
|
||||||
message.receipt = self._packet_sender(link, message.raw)
|
message.receipt = outlet.send(message.raw)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
self._log.exception(f"Error sending message {message}")
|
self._log.exception(f"Error sending message {message}")
|
||||||
elif state in [MessageState.MSGSTATE_SENT]:
|
elif state in [MessageState.MSGSTATE_SENT]:
|
||||||
self._log.debug(f"Retry skipped, message still pending {message}")
|
self._log.debug(f"Retry skipped, message still pending {message}")
|
||||||
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
||||||
latency = round(time.time() - message.ts, 1)
|
latency = round(time.time() - message.ts, 1)
|
||||||
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._sent_messages.remove(message)
|
self._sent_messages.remove(message)
|
||||||
message.tracked = False
|
message.tracked = False
|
||||||
self._retry_timer.complete(link)
|
self._retry_timer.complete(outlet)
|
||||||
return link
|
return outlet
|
||||||
|
|
||||||
def timeout(tag: any, tries: int):
|
def timeout(tag: any, tries: int):
|
||||||
latency = round(time.time() - message.ts, 1)
|
latency = round(time.time() - message.ts, 1)
|
||||||
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
||||||
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._sent_messages.remove(message)
|
self._sent_messages.remove(message)
|
||||||
message.tracked = False
|
message.tracked = False
|
||||||
self._link_timeout_callback(link)
|
outlet.timed_out()
|
||||||
|
|
||||||
rtt = self._link_rtt_getter(link)
|
|
||||||
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
|
|
||||||
|
|
||||||
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
|
|
||||||
try:
|
|
||||||
return self._inbound_queue.get(block=block, timeout=timeout)
|
|
||||||
except queue.Empty:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
rtt = outlet.rtt
|
||||||
|
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ class RetryThread(AbstractContextManager):
|
|||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._run = True
|
self._run = True
|
||||||
self._finished: asyncio.Future = None
|
self._finished: asyncio.Future = None
|
||||||
self._thread = threading.Thread(name=name, target=self._thread_run)
|
self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def is_alive(self):
|
def is_alive(self):
|
||||||
|
911
rnsh/rnsh.py
911
rnsh/rnsh.py
File diff suppressed because it is too large
Load Diff
423
rnsh/session.py
Normal file
423
rnsh/session.py
Normal file
@ -0,0 +1,423 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import threading
|
||||||
|
import rnsh.exception as exception
|
||||||
|
import asyncio
|
||||||
|
import rnsh.process as process
|
||||||
|
import rnsh.helpers as helpers
|
||||||
|
import rnsh.protocol as protocol
|
||||||
|
import enum
|
||||||
|
from typing import TypeVar, Generic, Callable, List
|
||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from multiprocessing import Manager
|
||||||
|
import os
|
||||||
|
import RNS
|
||||||
|
|
||||||
|
import logging as __logging
|
||||||
|
|
||||||
|
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
|
||||||
|
|
||||||
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TLink = TypeVar("_TLink")
|
||||||
|
|
||||||
|
class SEType(enum.IntEnum):
|
||||||
|
SE_LINK_CLOSED = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SessionException(Exception):
|
||||||
|
def __init__(self, setype: SEType, msg: str, *args):
|
||||||
|
super().__init__(msg, args)
|
||||||
|
self.type = setype
|
||||||
|
|
||||||
|
|
||||||
|
class LSState(enum.IntEnum):
|
||||||
|
LSSTATE_WAIT_IDENT = 1
|
||||||
|
LSSTATE_WAIT_VERS = 2
|
||||||
|
LSSTATE_WAIT_CMD = 3
|
||||||
|
LSSTATE_RUNNING = 4
|
||||||
|
LSSTATE_ERROR = 5
|
||||||
|
LSSTATE_TEARDOWN = 6
|
||||||
|
|
||||||
|
|
||||||
|
_TIdentity = TypeVar("_TIdentity")
|
||||||
|
|
||||||
|
|
||||||
|
class LSOutletBase(protocol.MessageOutletBase):
|
||||||
|
@abstractmethod
|
||||||
|
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unset_link_closed_callback(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def teardown(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
|
||||||
|
class ListenerSession:
|
||||||
|
sessions: List[ListenerSession] = []
|
||||||
|
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
|
||||||
|
allowed_identity_hashes: [any] = []
|
||||||
|
allow_all: bool = False
|
||||||
|
allow_remote_command: bool = False
|
||||||
|
default_command: [str] = []
|
||||||
|
remote_cmd_as_args = False
|
||||||
|
|
||||||
|
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
|
||||||
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
|
self._log.info(f"Session started for {outlet}")
|
||||||
|
self.outlet = outlet
|
||||||
|
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||||
|
self.outlet.set_link_closed_callback(self._link_closed)
|
||||||
|
self.loop = loop
|
||||||
|
self.state: LSState = None
|
||||||
|
self.remote_identity = None
|
||||||
|
self.term: str | None = None
|
||||||
|
self.stdin_is_pipe: bool = False
|
||||||
|
self.stdout_is_pipe: bool = False
|
||||||
|
self.stderr_is_pipe: bool = False
|
||||||
|
self.tcflags: [any] = None
|
||||||
|
self.cmdline: [str] = None
|
||||||
|
self.rows: int = 0
|
||||||
|
self.cols: int = 0
|
||||||
|
self.hpix: int = 0
|
||||||
|
self.vpix: int = 0
|
||||||
|
self.stdout_buf = bytearray()
|
||||||
|
self.stdout_eof_sent = False
|
||||||
|
self.stderr_buf = bytearray()
|
||||||
|
self.stderr_eof_sent = False
|
||||||
|
self.return_code: int | None = None
|
||||||
|
self.return_code_sent = False
|
||||||
|
self.process: process.CallbackSubprocess | None = None
|
||||||
|
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||||
|
self.sessions.append(self)
|
||||||
|
|
||||||
|
def _terminated(self, return_code: int):
|
||||||
|
self.return_code = return_code
|
||||||
|
|
||||||
|
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
|
||||||
|
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
|
||||||
|
self._log.debug(f"Set state: {state.name}, timeout {timeout}")
|
||||||
|
orig_state = self.state
|
||||||
|
self.state = state
|
||||||
|
if timeout_factor is not None:
|
||||||
|
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
|
||||||
|
|
||||||
|
def _call(self, func: callable, delay: float = 0):
|
||||||
|
def call_inner():
|
||||||
|
# self._log.debug("call_inner")
|
||||||
|
if delay == 0:
|
||||||
|
func()
|
||||||
|
else:
|
||||||
|
self.loop.call_later(delay, func)
|
||||||
|
self.loop.call_soon_threadsafe(call_inner)
|
||||||
|
|
||||||
|
def send(self, message: protocol.Message):
|
||||||
|
self.messenger.send(self.outlet, message)
|
||||||
|
|
||||||
|
def _protocol_error(self, name: str):
|
||||||
|
self.terminate(f"Protocol error ({name})")
|
||||||
|
|
||||||
|
def _protocol_timeout_error(self, name: str):
|
||||||
|
self.terminate(f"Protocol timeout error: {name}")
|
||||||
|
|
||||||
|
def terminate(self, error: str = None):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self._log.debug("Terminating session" + (f": {error}" if error else ""))
|
||||||
|
if error and self.state != LSState.LSSTATE_TEARDOWN:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.send(protocol.ErrorMessage(error, True))
|
||||||
|
self.state = LSState.LSSTATE_ERROR
|
||||||
|
self._terminate_process()
|
||||||
|
self._call(self._prune, max(self.outlet.rtt * 3, 5))
|
||||||
|
|
||||||
|
def _prune(self):
|
||||||
|
self.state = LSState.LSSTATE_TEARDOWN
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
self.sessions.remove(self)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.outlet.teardown()
|
||||||
|
|
||||||
|
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
|
||||||
|
timeout = True
|
||||||
|
try:
|
||||||
|
timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
|
||||||
|
except Exception as ex:
|
||||||
|
self._log.exception("Error in protocol timeout", ex)
|
||||||
|
if timeout:
|
||||||
|
self._protocol_timeout_error(name)
|
||||||
|
|
||||||
|
def _link_closed(self, outlet: LSOutletBase):
|
||||||
|
outlet.unset_link_closed_callback()
|
||||||
|
|
||||||
|
if outlet != self.outlet:
|
||||||
|
self._log.debug("Link closed received from incorrect outlet")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._log.debug(f"link_closed {outlet}")
|
||||||
|
self.messenger.clear_retries(self.outlet)
|
||||||
|
self.terminate()
|
||||||
|
|
||||||
|
def _initiator_identified(self, outlet, identity):
|
||||||
|
if outlet != self.outlet:
|
||||||
|
self._log.debug("Identity received from incorrect outlet")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._log.info(f"initiator_identified {identity} on link {outlet}")
|
||||||
|
if self.state != LSState.LSSTATE_WAIT_IDENT:
|
||||||
|
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
|
||||||
|
|
||||||
|
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
|
||||||
|
self.terminate("Identity is not allowed.")
|
||||||
|
|
||||||
|
self.remote_identity = identity
|
||||||
|
self.outlet.set_packet_received_callback(self._packet_received)
|
||||||
|
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def pump_all(cls):
|
||||||
|
for session in cls.sessions:
|
||||||
|
session.pump()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def terminate_all(cls, reason: str):
|
||||||
|
for session in cls.sessions:
|
||||||
|
session.terminate(reason)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def pump(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.state != LSState.LSSTATE_RUNNING:
|
||||||
|
return
|
||||||
|
elif not self.messenger.is_outlet_ready(self.outlet):
|
||||||
|
return
|
||||||
|
elif len(self.stderr_buf) > 0:
|
||||||
|
mdu = self.outlet.mdu - 16
|
||||||
|
data = self.stderr_buf[:mdu]
|
||||||
|
self.stderr_buf = self.stderr_buf[mdu:]
|
||||||
|
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||||
|
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
|
||||||
|
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
|
||||||
|
data, send_eof)
|
||||||
|
self.send(msg)
|
||||||
|
if send_eof:
|
||||||
|
self.stderr_eof_sent = True
|
||||||
|
elif len(self.stdout_buf) > 0:
|
||||||
|
mdu = self.outlet.mdu - 16
|
||||||
|
data = self.stdout_buf[:mdu]
|
||||||
|
self.stdout_buf = self.stdout_buf[mdu:]
|
||||||
|
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||||
|
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
|
||||||
|
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
|
||||||
|
data, send_eof)
|
||||||
|
self.send(msg)
|
||||||
|
elif self.return_code is not None and not self.return_code_sent:
|
||||||
|
msg = protocol.CommandExitedMessage(self.return_code)
|
||||||
|
self.send(msg)
|
||||||
|
self.return_code_sent = True
|
||||||
|
self._call(functools.partial(self._check_protocol_timeout,
|
||||||
|
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
|
||||||
|
max(self.outlet.rtt * 5, 10))
|
||||||
|
except Exception as ex:
|
||||||
|
self._log.exception("Error during pump", ex)
|
||||||
|
|
||||||
|
def _terminate_process(self):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
if self.process and self.process.running:
|
||||||
|
self.process.terminate()
|
||||||
|
|
||||||
|
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
|
||||||
|
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
|
||||||
|
|
||||||
|
self.cmdline = self.default_command
|
||||||
|
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
|
||||||
|
self.terminate("Remote command line not allowed by listener")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
|
||||||
|
self.cmdline.extend(cmdline)
|
||||||
|
elif cmdline and len(cmdline) > 0:
|
||||||
|
self.cmdline = cmdline
|
||||||
|
|
||||||
|
|
||||||
|
self.stdin_is_pipe = pipe_stdin
|
||||||
|
self.stdout_is_pipe = pipe_stdout
|
||||||
|
self.stderr_is_pipe = pipe_stderr
|
||||||
|
self.tcflags = tcflags
|
||||||
|
self.term = term
|
||||||
|
|
||||||
|
def stdout(data: bytes):
|
||||||
|
self.stdout_buf.extend(data)
|
||||||
|
|
||||||
|
def stderr(data: bytes):
|
||||||
|
self.stderr_buf.extend(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.process = process.CallbackSubprocess(argv=self.cmdline,
|
||||||
|
env={"TERM": self.term or os.environ.get("TERM", None),
|
||||||
|
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
|
||||||
|
loop=self.loop,
|
||||||
|
stdout_callback=stdout,
|
||||||
|
stderr_callback=stderr,
|
||||||
|
terminated_callback=self._terminated,
|
||||||
|
stdin_is_pipe=self.stdin_is_pipe,
|
||||||
|
stdout_is_pipe=self.stdout_is_pipe,
|
||||||
|
stderr_is_pipe=self.stderr_is_pipe)
|
||||||
|
self.process.start()
|
||||||
|
self._set_window_size(rows, cols, hpix, vpix)
|
||||||
|
except Exception as ex:
|
||||||
|
self._log.exception(f"Unable to start process for link {self.outlet}", ex)
|
||||||
|
self.terminate("Unable to start process")
|
||||||
|
|
||||||
|
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
|
||||||
|
self.rows = rows
|
||||||
|
self.cols = cols
|
||||||
|
self.hpix = hpix
|
||||||
|
self.vpix = vpix
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.process.set_winsize(rows, cols, hpix, vpix)
|
||||||
|
|
||||||
|
def _received_stdin(self, data: bytes, eof: bool):
|
||||||
|
if data and len(data) > 0:
|
||||||
|
self.process.write(data)
|
||||||
|
if eof:
|
||||||
|
self.process.close_stdin()
|
||||||
|
|
||||||
|
def _handle_message(self, message: protocol.Message):
|
||||||
|
if self.state == LSState.LSSTATE_WAIT_VERS:
|
||||||
|
if not isinstance(message, protocol.VersionInfoMessage):
|
||||||
|
self._protocol_error(self.state.name)
|
||||||
|
return
|
||||||
|
self._log.info(f"version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}")
|
||||||
|
if message.protocol_version != protocol.PROTOCOL_VERSION:
|
||||||
|
self.terminate("Incompatible protocol")
|
||||||
|
return
|
||||||
|
self.send(protocol.VersionInfoMessage())
|
||||||
|
self._set_state(LSState.LSSTATE_WAIT_CMD)
|
||||||
|
return
|
||||||
|
elif self.state == LSState.LSSTATE_WAIT_CMD:
|
||||||
|
if not isinstance(message, protocol.ExecuteCommandMesssage):
|
||||||
|
return self._protocol_error(self.state.name)
|
||||||
|
self._log.info(f"Execute command message on link {self.outlet}: {message.cmdline}")
|
||||||
|
self._set_state(LSState.LSSTATE_RUNNING)
|
||||||
|
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
|
||||||
|
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
|
||||||
|
return
|
||||||
|
elif self.state == LSState.LSSTATE_RUNNING:
|
||||||
|
if isinstance(message, protocol.WindowSizeMessage):
|
||||||
|
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
|
||||||
|
elif isinstance(message, protocol.StreamDataMessage):
|
||||||
|
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
|
||||||
|
self._log.error(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}")
|
||||||
|
return self._protocol_error(self.state.name)
|
||||||
|
self._received_stdin(message.data, message.eof)
|
||||||
|
return
|
||||||
|
elif isinstance(message, protocol.NoopMessage):
|
||||||
|
# echo noop only on listener--used for keepalive/connectivity check
|
||||||
|
self.send(message)
|
||||||
|
return
|
||||||
|
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
|
||||||
|
self._log.error(f"Received packet, but in state {self.state.name}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._protocol_error("unexpected message")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
|
||||||
|
if outlet != self.outlet:
|
||||||
|
self._log.debug("Packet received from incorrect outlet")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = self.messenger.receive(raw)
|
||||||
|
self._handle_message(message)
|
||||||
|
except Exception as ex:
|
||||||
|
self._protocol_error("unusable packet")
|
||||||
|
|
||||||
|
|
||||||
|
class RNSOutlet(LSOutletBase):
|
||||||
|
|
||||||
|
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||||
|
def inner_cb(link, identity: _TIdentity):
|
||||||
|
cb(self, identity)
|
||||||
|
|
||||||
|
self.link.set_remote_identified_callback(inner_cb)
|
||||||
|
|
||||||
|
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||||
|
def inner_cb(link):
|
||||||
|
cb(self)
|
||||||
|
|
||||||
|
self.link.set_link_closed_callback(inner_cb)
|
||||||
|
|
||||||
|
def unset_link_closed_callback(self):
|
||||||
|
self.link.set_link_closed_callback(None)
|
||||||
|
|
||||||
|
def teardown(self):
|
||||||
|
self.link.teardown()
|
||||||
|
|
||||||
|
def send(self, raw: bytes) -> RNS.PacketReceipt:
|
||||||
|
packet = RNS.Packet(self.link, raw)
|
||||||
|
packet.send()
|
||||||
|
return packet.receipt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mdu(self) -> int:
|
||||||
|
return self.link.MDU
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rtt(self) -> float:
|
||||||
|
return self.link.rtt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_usuable(self):
|
||||||
|
return True #self.link.status in [RNS.Link.ACTIVE]
|
||||||
|
|
||||||
|
def get_receipt_state(self, receipt: RNS.PacketReceipt) -> MessageState:
|
||||||
|
status = receipt.get_status()
|
||||||
|
if status == RNS.PacketReceipt.SENT:
|
||||||
|
return protocol.MessageState.MSGSTATE_SENT
|
||||||
|
if status == RNS.PacketReceipt.DELIVERED:
|
||||||
|
return protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
|
if status == RNS.PacketReceipt.FAILED:
|
||||||
|
return protocol.MessageState.MSGSTATE_FAILED
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected receipt state: {status}")
|
||||||
|
|
||||||
|
def timed_out(self):
|
||||||
|
self.link.teardown()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Outlet RNS Link {self.link}"
|
||||||
|
|
||||||
|
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||||
|
def inner_cb(message, packet: RNS.Packet):
|
||||||
|
packet.prove()
|
||||||
|
cb(self, message)
|
||||||
|
|
||||||
|
self.link.set_packet_callback(inner_cb)
|
||||||
|
|
||||||
|
def __init__(self, link: RNS.Link):
|
||||||
|
self.link = link
|
||||||
|
link.lsoutlet = self
|
||||||
|
link.msgoutlet = self
|
||||||
|
@staticmethod
|
||||||
|
def get_outlet(link: RNS.Link):
|
||||||
|
if hasattr(link, "lsoutlet"):
|
||||||
|
return link.lsoutlet
|
||||||
|
|
||||||
|
return RNSOutlet(link)
|
@ -1,6 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from rnsh.protocol import _TReceipt, MessageState
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
import rnsh.protocol
|
import rnsh.protocol
|
||||||
@ -14,64 +18,60 @@ import uuid
|
|||||||
module_logger = logging.getLogger(__name__)
|
module_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Link:
|
class Receipt:
|
||||||
|
def __init__(self, state: rnsh.protocol.MessageState, raw: bytes):
|
||||||
|
self.state = state
|
||||||
|
self.raw = raw
|
||||||
|
|
||||||
|
|
||||||
|
class MessageOutletTest(rnsh.protocol.MessageOutletBase):
|
||||||
def __init__(self, mdu: int, rtt: float):
|
def __init__(self, mdu: int, rtt: float):
|
||||||
self.link_id = uuid.uuid4()
|
self.link_id = uuid.uuid4()
|
||||||
self.timeout_callbacks = 0
|
self.timeout_callbacks = 0
|
||||||
self.mdu = mdu
|
self._mdu = mdu
|
||||||
self.rtt = rtt
|
self._rtt = rtt
|
||||||
self.usable = True
|
self._usable = True
|
||||||
self.receipts = []
|
self.receipts = []
|
||||||
|
self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None
|
||||||
|
|
||||||
def timeout_callback(self):
|
def send(self, raw: bytes) -> Receipt:
|
||||||
|
receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
|
||||||
|
self.receipts.append(receipt)
|
||||||
|
return receipt
|
||||||
|
|
||||||
|
def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]):
|
||||||
|
self.packet_callback = cb
|
||||||
|
|
||||||
|
def receive(self, raw: bytes):
|
||||||
|
if self.packet_callback:
|
||||||
|
self.packet_callback(self, raw)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mdu(self):
|
||||||
|
return self._mdu
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rtt(self):
|
||||||
|
return self._rtt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_usuable(self):
|
||||||
|
return self._usable
|
||||||
|
|
||||||
|
def get_receipt_state(self, receipt: Receipt) -> MessageState:
|
||||||
|
return receipt.state
|
||||||
|
|
||||||
|
def timed_out(self):
|
||||||
self.timeout_callbacks += 1
|
self.timeout_callbacks += 1
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.link_id)
|
return str(self.link_id)
|
||||||
|
|
||||||
|
|
||||||
class Receipt:
|
|
||||||
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
|
|
||||||
self.state = state
|
|
||||||
self.raw = raw
|
|
||||||
self.link = link
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolHarness(contextlib.AbstractContextManager):
|
class ProtocolHarness(contextlib.AbstractContextManager):
|
||||||
def __init__(self, retry_delay_min: float = 1):
|
def __init__(self, retry_delay_min: float = 1):
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker,
|
self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min)
|
||||||
link_timeout_callback=self.link_timeout_callback,
|
|
||||||
link_mdu_getter=self.link_mdu_getter,
|
|
||||||
link_rtt_getter=self.link_rtt_getter,
|
|
||||||
link_usable_getter=self.link_usable_getter,
|
|
||||||
packet_sender=self.packet_sender,
|
|
||||||
retry_delay_min=retry_delay_min)
|
|
||||||
|
|
||||||
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
|
|
||||||
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
|
|
||||||
link.receipts.append(receipt)
|
|
||||||
return receipt
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_mdu_getter(link: Link):
|
|
||||||
return link.mdu
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_rtt_getter(link: Link):
|
|
||||||
return link.rtt
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_usable_getter(link: Link):
|
|
||||||
return link.usable
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
|
|
||||||
return receipt.state
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_timeout_callback(link: Link):
|
|
||||||
link.timeout_callback()
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
self.messenger.shutdown()
|
self.messenger.shutdown()
|
||||||
@ -83,49 +83,58 @@ class ProtocolHarness(contextlib.AbstractContextManager):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def test_mdu():
|
|
||||||
with ProtocolHarness() as h:
|
|
||||||
mdu = 500
|
|
||||||
link = Link(mdu=mdu, rtt=0.25)
|
|
||||||
assert h.messenger.get_mdu(link) == mdu - 4
|
|
||||||
link.mdu = mdu = 600
|
|
||||||
assert h.messenger.get_mdu(link) == mdu - 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtt():
|
|
||||||
with ProtocolHarness() as h:
|
|
||||||
rtt = 0.25
|
|
||||||
link = Link(mdu=500, rtt=rtt)
|
|
||||||
assert h.messenger.get_rtt(link) == rtt
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_one_retry():
|
def test_send_one_retry():
|
||||||
rtt = 0.001
|
rtt = 0.001
|
||||||
retry_interval = rtt * 150
|
retry_interval = rtt * 150
|
||||||
message_content = b'Test'
|
message_content = b'Test'
|
||||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||||
link = Link(mdu=500, rtt=rtt)
|
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||||
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||||
data=message_content, eof=True)
|
data=message_content, eof=True)
|
||||||
assert len(link.receipts) == 0
|
assert len(outlet.receipts) == 0
|
||||||
h.messenger.send_message(link, message)
|
h.messenger.send(outlet, message)
|
||||||
assert message.tracked
|
assert message.tracked
|
||||||
assert message.raw is not None
|
assert message.raw is not None
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
receipt = link.receipts[0]
|
receipt = outlet.receipts[0]
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
assert receipt.raw == message.raw
|
assert receipt.raw == message.raw
|
||||||
time.sleep(retry_interval * 1.5)
|
time.sleep(retry_interval * 1.5)
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
|
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
|
||||||
module_logger.info("set failed")
|
module_logger.info("set failed")
|
||||||
time.sleep(retry_interval)
|
time.sleep(retry_interval)
|
||||||
assert len(link.receipts) == 2
|
assert len(outlet.receipts) == 2
|
||||||
receipt = link.receipts[1]
|
receipt = outlet.receipts[1]
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
time.sleep(retry_interval)
|
time.sleep(retry_interval)
|
||||||
assert len(link.receipts) == 2
|
assert len(outlet.receipts) == 2
|
||||||
|
assert not message.tracked
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_timeout():
|
||||||
|
rtt = 0.001
|
||||||
|
retry_interval = rtt * 150
|
||||||
|
message_content = b'Test'
|
||||||
|
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||||
|
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||||
|
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||||
|
data=message_content, eof=True)
|
||||||
|
assert len(outlet.receipts) == 0
|
||||||
|
h.messenger.send(outlet, message)
|
||||||
|
assert message.tracked
|
||||||
|
assert message.raw is not None
|
||||||
|
assert len(outlet.receipts) == 1
|
||||||
|
receipt = outlet.receipts[0]
|
||||||
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
|
assert receipt.raw == message.raw
|
||||||
|
time.sleep(retry_interval * 1.5)
|
||||||
|
assert outlet.timeout_callbacks == 0
|
||||||
|
time.sleep(retry_interval * 7)
|
||||||
|
assert len(outlet.receipts) == 1
|
||||||
|
assert outlet.timeout_callbacks == 1
|
||||||
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
assert not message.tracked
|
assert not message.tracked
|
||||||
|
|
||||||
|
|
||||||
@ -133,24 +142,32 @@ def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[r
|
|||||||
rtt = 0.001
|
rtt = 0.001
|
||||||
retry_interval = rtt * 150
|
retry_interval = rtt * 150
|
||||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||||
link = Link(mdu=500, rtt=rtt)
|
|
||||||
assert len(link.receipts) == 0
|
decoded: [rnsh.protocol.Message] = []
|
||||||
h.messenger.send_message(link, message)
|
def packet(outlet, buffer):
|
||||||
|
decoded.append(h.messenger.receive(buffer))
|
||||||
|
|
||||||
|
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||||
|
outlet.set_packet_received_callback(packet)
|
||||||
|
assert len(outlet.receipts) == 0
|
||||||
|
h.messenger.send(outlet, message)
|
||||||
assert message.tracked
|
assert message.tracked
|
||||||
assert message.raw is not None
|
assert message.raw is not None
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
receipt = link.receipts[0]
|
receipt = outlet.receipts[0]
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
assert receipt.raw == message.raw
|
assert receipt.raw == message.raw
|
||||||
module_logger.info("set delivered")
|
module_logger.info("set delivered")
|
||||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
time.sleep(retry_interval * 2)
|
time.sleep(retry_interval * 2)
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
assert not message.tracked
|
assert not message.tracked
|
||||||
module_logger.info("injecting rx message")
|
module_logger.info("injecting rx message")
|
||||||
h.messenger.inbound(message.raw)
|
assert len(decoded) == 0
|
||||||
rx_message = h.messenger.poll_inbound(block=False)
|
outlet.receive(message.raw)
|
||||||
|
assert len(decoded) == 1
|
||||||
|
rx_message = decoded[0]
|
||||||
assert rx_message is not None
|
assert rx_message is not None
|
||||||
assert isinstance(rx_message, message.__class__)
|
assert isinstance(rx_message, message.__class__)
|
||||||
assert rx_message.msgid != message.msgid
|
assert rx_message.msgid != message.msgid
|
||||||
@ -238,6 +255,16 @@ def test_send_receive_error():
|
|||||||
eat_own_dog_food(message, check)
|
eat_own_dog_food(message, check)
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_receive_cmdexit():
|
||||||
|
message = rnsh.protocol.CommandExitedMessage(5)
|
||||||
|
|
||||||
|
def check(rx_message: rnsh.protocol.Message):
|
||||||
|
assert isinstance(rx_message, message.__class__)
|
||||||
|
assert rx_message.return_code == message.return_code
|
||||||
|
|
||||||
|
eat_own_dog_food(message, check)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,19 +12,6 @@ import re
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_magic():
|
|
||||||
magic = rnsh.rnsh._PROTOCOL_VERSION_0
|
|
||||||
# magic for version 0 is generated, make sure it comes out as expected
|
|
||||||
assert magic == 0xdeadbeef00000000
|
|
||||||
# verify the checker thinks it's right
|
|
||||||
assert rnsh.rnsh._protocol_check_magic(magic)
|
|
||||||
# scramble the magic
|
|
||||||
magic = magic | 0x00ffff0000000000
|
|
||||||
# make sure it fails now
|
|
||||||
assert not rnsh.rnsh._protocol_check_magic(magic)
|
|
||||||
|
|
||||||
|
|
||||||
def test_version():
|
def test_version():
|
||||||
# version = importlib.metadata.version(rnsh.__version__)
|
# version = importlib.metadata.version(rnsh.__version__)
|
||||||
assert rnsh.__version__ != "0.0.0"
|
assert rnsh.__version__ != "0.0.0"
|
||||||
|
Loading…
Reference in New Issue
Block a user