Got the new protocol working.

This commit is contained in:
Aaron Heise 2023-02-18 00:09:28 -06:00
parent 0ee305795f
commit 8edb4020b1
8 changed files with 876 additions and 887 deletions

8
rnsh/helpers.py Normal file
View File

@ -0,0 +1,8 @@
def bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def check_and(value: int, andval: int) -> bool:
return (value & andval) > 0

View File

@ -382,8 +382,11 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
# Make PTY controlling if necessary # Make PTY controlling if necessary
if child_fd is not None: if child_fd is not None:
os.setsid() os.setsid()
try:
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR) tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
os.close(tmp_fd) os.close(tmp_fd)
except:
pass
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0) # fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
# Execute the command # Execute the command
@ -445,7 +448,8 @@ class CallbackSubprocess:
self._child_stdout: int = None self._child_stdout: int = None
self._child_stderr: int = None self._child_stderr: int = None
self._return_code: int = None self._return_code: int = None
self._eof: bool = False self._stdout_eof: bool = False
self._stderr_eof: bool = False
self._stdin_is_pipe = stdin_is_pipe self._stdin_is_pipe = stdin_is_pipe
self._stdout_is_pipe = stdout_is_pipe self._stdout_is_pipe = stdout_is_pipe
self._stderr_is_pipe = stderr_is_pipe self._stderr_is_pipe = stderr_is_pipe
@ -455,6 +459,21 @@ class CallbackSubprocess:
Terminate child process if running Terminate child process if running
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL :param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
""" """
same = self._child_stdout == self._child_stderr
with contextlib.suppress(OSError):
if not self._stdout_eof:
tty_unset_reader_callbacks(self._child_stdout)
os.close(self._child_stdout)
self._child_stdout = None
if not same:
with contextlib.suppress(OSError):
if not self._stderr_eof:
tty_unset_reader_callbacks(self._child_stderr)
os.close(self._child_stderr)
self._child_stdout = None
self._log.debug("terminate()") self._log.debug("terminate()")
if not self.running: if not self.running:
return return
@ -477,7 +496,7 @@ class CallbackSubprocess:
os.waitpid(self._pid, 0) os.waitpid(self._pid, 0)
self._log.debug("wait() finish") self._log.debug("wait() finish")
threading.Thread(target=wait).start() threading.Thread(target=wait, daemon=True).start()
def close_stdin(self): def close_stdin(self):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@ -592,26 +611,44 @@ class CallbackSubprocess:
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable): def stdout():
try: try:
with exception.permit(SystemExit): with exception.permit(SystemExit):
data = tty_read(fd) data = tty_read_poll(self._child_stdout)
if data is not None and len(data) > 0: if data is not None and len(data) > 0:
callback(data) self._stdout_cb(data)
except EOFError: except EOFError:
self._eof = True self._stdout_eof = True
tty_unset_reader_callbacks(self._child_stdout) tty_unset_reader_callbacks(self._child_stdout)
callback(bytearray()) self._stdout_cb(bytearray())
with contextlib.suppress(OSError):
os.close(self._child_stdout)
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb), def stderr():
self._loop) try:
with exception.permit(SystemExit):
data = tty_read_poll(self._child_stderr)
if data is not None and len(data) > 0:
self._stderr_cb(data)
except EOFError:
self._stderr_eof = True
tty_unset_reader_callbacks(self._child_stderr)
self._stdout_cb(bytearray())
with contextlib.suppress(OSError):
os.close(self._child_stderr)
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
if self._child_stderr != self._child_stdout: if self._child_stderr != self._child_stdout:
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb), tty_add_reader_callback(self._child_stderr, stderr, self._loop)
self._loop)
@property @property
def eof(self): def stdout_eof(self):
return self._eof or not self.running return self._stdout_eof or not self.running
@property
def stderr_eof(self):
return self._stderr_eof or not self.running
@property @property
def return_code(self) -> int: def return_code(self) -> int:

View File

@ -15,6 +15,7 @@ import abc
import contextlib import contextlib
import struct import struct
import logging as __logging import logging as __logging
from abc import ABC, abstractmethod
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -29,6 +30,43 @@ def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff) return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class MessageOutletBase(ABC):
@abstractmethod
def send(self, raw: bytes) -> _TReceipt:
raise NotImplemented()
@property
@abstractmethod
def mdu(self):
raise NotImplemented()
@property
@abstractmethod
def rtt(self):
raise NotImplemented()
@property
@abstractmethod
def is_usuable(self):
raise NotImplemented()
@abstractmethod
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
raise NotImplemented()
@abstractmethod
def timed_out(self):
raise NotImplemented()
@abstractmethod
def __str__(self):
raise NotImplemented()
@abstractmethod
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
raise NotImplemented()
class METype(enum.IntEnum): class METype(enum.IntEnum):
ME_NO_MSG_TYPE = 0 ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1 ME_INVALID_MSG_TYPE = 1
@ -58,17 +96,17 @@ class Message(abc.ABC):
self.msgid = uuid.uuid4() self.msgid = uuid.uuid4()
self.raw: bytes | None = None self.raw: bytes | None = None
self.receipt: _TReceipt = None self.receipt: _TReceipt = None
self.link: _TLink = None self.outlet: _TLink = None
self.tracked: bool = False self.tracked: bool = False
def __str__(self): def __str__(self):
return f"{self.__class__.__name__} {self.msgid}" return f"{self.__class__.__name__} {self.msgid}"
@abc.abstractmethod @abstractmethod
def pack(self) -> bytes: def pack(self) -> bytes:
raise NotImplemented() raise NotImplemented()
@abc.abstractmethod @abstractmethod
def unpack(self, raw): def unpack(self, raw):
raise NotImplemented() raise NotImplemented()
@ -124,7 +162,8 @@ class ExecuteCommandMesssage(Message):
MSGTYPE = _make_MSGTYPE(3) MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False, def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None): pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
cols: int = None, hpix: int = None, vpix: int = None):
super().__init__() super().__init__()
self.cmdline = cmdline self.cmdline = cmdline
self.pipe_stdin = pipe_stdin self.pipe_stdin = pipe_stdin
@ -132,16 +171,20 @@ class ExecuteCommandMesssage(Message):
self.pipe_stderr = pipe_stderr self.pipe_stderr = pipe_stderr
self.tcflags = tcflags self.tcflags = tcflags
self.term = term self.term = term
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term)) self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw) return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw) raw = self.unwrap_MSGTYPE(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \ self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
= umsgpack.unpackb(raw) self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class StreamDataMessage(Message): class StreamDataMessage(Message):
MSGTYPE = _make_MSGTYPE(4) MSGTYPE = _make_MSGTYPE(4)
@ -156,7 +199,7 @@ class StreamDataMessage(Message):
self.eof = eof self.eof = eof
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.stream_id, self.eof, self.data)) raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
return self.wrap_MSGTYPE(raw) return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
@ -169,7 +212,7 @@ class VersionInfoMessage(Message):
def __init__(self, sw_version: str = None): def __init__(self, sw_version: str = None):
super().__init__() super().__init__()
self.sw_version = sw_version self.sw_version = sw_version or rnsh.__version__
self.protocol_version = PROTOCOL_VERSION self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes: def pack(self) -> bytes:
@ -199,6 +242,22 @@ class ErrorMessage(Message):
self.msg, self.fatal, self.data = umsgpack.unpackb(raw) self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class CommandExitedMessage(Message):
MSGTYPE = _make_MSGTYPE(7)
def __init__(self, return_code: int = None):
super().__init__()
self.return_code = return_code
def pack(self) -> bytes:
raw = umsgpack.packb(self.return_code)
return self.wrap_MSGTYPE(raw)
def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.return_code = umsgpack.unpackb(raw)
class Messenger(contextlib.AbstractContextManager): class Messenger(contextlib.AbstractContextManager):
@staticmethod @staticmethod
@ -208,29 +267,16 @@ class Messenger(contextlib.AbstractContextManager):
subclass_tuples.append((subclass.MSGTYPE, subclass)) subclass_tuples.append((subclass.MSGTYPE, subclass))
return subclass_tuples return subclass_tuples
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState], def __init__(self, retry_delay_min: float = 10.0):
link_timeout_callback: Callable[[_TLink], None],
link_mdu_getter: Callable[[_TLink], int],
link_rtt_getter: Callable[[_TLink], float],
link_usable_getter: Callable[[_TLink], bool],
packet_sender: Callable[[_TLink, bytes], _TReceipt],
retry_delay_min: float = 10.0):
self._log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self._receipt_checker = receipt_checker
self._link_timeout_callback = link_timeout_callback
self._link_mdu_getter = link_mdu_getter
self._link_rtt_getter = link_rtt_getter
self._link_usable_getter = link_usable_getter
self._packet_sender = packet_sender
self._sent_messages: list[Message] = [] self._sent_messages: list[Message] = []
self._lock = threading.RLock() self._lock = threading.RLock()
self._retry_timer = rnsh.retry.RetryThread() self._retry_timer = rnsh.retry.RetryThread()
self._message_factories = dict(self.__class__._get_msg_constructors()) self._message_factories = dict(self.__class__._get_msg_constructors())
self._inbound_queue = queue.Queue()
self._retry_delay_min = retry_delay_min self._retry_delay_min = retry_delay_min
def __enter__(self): def __enter__(self) -> Messenger:
pass return self
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None: __traceback: TracebackType | None) -> bool | None:
@ -238,39 +284,38 @@ class Messenger(contextlib.AbstractContextManager):
return False return False
def shutdown(self): def shutdown(self):
self._run = False
self._retry_timer.close() self._retry_timer.close()
def inbound(self, raw: bytes): def clear_retries(self, outlet):
self._retry_timer.complete(outlet)
def receive(self, raw: bytes) -> Message:
(mid, contents) = Message.static_unwrap_MSGTYPE(raw) (mid, contents) = Message.static_unwrap_MSGTYPE(raw)
ctor = self._message_factories.get(mid, None) ctor = self._message_factories.get(mid, None)
if ctor is None: if ctor is None:
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}") raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
message = ctor() message = ctor()
message.unpack(raw) message.unpack(raw)
self._log.debug("Message received: {message}") self._log.debug(f"Message received: {message}")
self._inbound_queue.put(message) return message
def get_mdu(self, link: _TLink) -> int: def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
return self._link_mdu_getter(link) - 4 if not outlet.is_usuable:
self._log.debug("is_outlet_ready outlet unusable")
def get_rtt(self, link: _TLink) -> float:
return self._link_rtt_getter(link)
def is_link_ready(self, link: _TLink) -> bool:
if not self._link_usable_getter(link):
return False return False
with self._lock: with self._lock:
for message in self._sent_messages: for message in self._sent_messages:
if message.link == link: if message.outlet == outlet and message.tracked and message.receipt \
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
self._log.debug("is_outlet_ready pending message found")
return False return False
return True return True
def send_message(self, link: _TLink, message: Message): def send(self, outlet: MessageOutletBase, message: Message):
with self._lock: with self._lock:
if not self.is_link_ready(link): if not self.is_outlet_ready(outlet):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready") raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
if message in self._sent_messages: if message in self._sent_messages:
raise MessagingException(METype.ME_ALREADY_SENT) raise MessagingException(METype.ME_ALREADY_SENT)
@ -279,43 +324,36 @@ class Messenger(contextlib.AbstractContextManager):
if not message.raw: if not message.raw:
message.raw = message.pack() message.raw = message.pack()
message.link = link message.outlet = outlet
def send(tag: any, tries: int): def send_inner(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt) state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]: if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
try: try:
self._log.debug(f"Sending packet for {message}") self._log.debug(f"Sending packet for {message}")
message.receipt = self._packet_sender(link, message.raw) message.receipt = outlet.send(message.raw)
except Exception as ex: except Exception as ex:
self._log.exception(f"Error sending message {message}") self._log.exception(f"Error sending message {message}")
elif state in [MessageState.MSGSTATE_SENT]: elif state in [MessageState.MSGSTATE_SENT]:
self._log.debug(f"Retry skipped, message still pending {message}") self._log.debug(f"Retry skipped, message still pending {message}")
elif state in [MessageState.MSGSTATE_DELIVERED]: elif state in [MessageState.MSGSTATE_DELIVERED]:
latency = round(time.time() - message.ts, 1) latency = round(time.time() - message.ts, 1)
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds") self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
with self._lock: with self._lock:
self._sent_messages.remove(message) self._sent_messages.remove(message)
message.tracked = False message.tracked = False
self._retry_timer.complete(link) self._retry_timer.complete(outlet)
return link return outlet
def timeout(tag: any, tries: int): def timeout(tag: any, tries: int):
latency = round(time.time() - message.ts, 1) latency = round(time.time() - message.ts, 1)
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout" msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds") self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
with self._lock: with self._lock:
self._sent_messages.remove(message) self._sent_messages.remove(message)
message.tracked = False message.tracked = False
self._link_timeout_callback(link) outlet.timed_out()
rtt = self._link_rtt_getter(link)
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
try:
return self._inbound_queue.get(block=block, timeout=timeout)
except queue.Empty:
return None
rtt = outlet.rtt
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)

View File

@ -79,7 +79,7 @@ class RetryThread(AbstractContextManager):
self._lock = threading.RLock() self._lock = threading.RLock()
self._run = True self._run = True
self._finished: asyncio.Future = None self._finished: asyncio.Future = None
self._thread = threading.Thread(name=name, target=self._thread_run) self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
self._thread.start() self._thread.start()
def is_alive(self): def is_alive(self):

View File

@ -26,10 +26,12 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import enum
import functools import functools
import importlib.metadata import importlib.metadata
import logging as __logging import logging as __logging
import os import os
import queue
import shlex import shlex
import signal import signal
import sys import sys
@ -44,10 +46,12 @@ import rnsh.process as process
import rnsh.retry as retry import rnsh.retry as retry
import rnsh.rnslogging as rnslogging import rnsh.rnslogging as rnslogging
import rnsh.hacks as hacks import rnsh.hacks as hacks
import rnsh.session as session
import re import re
import contextlib import contextlib
import rnsh.args import rnsh.args
import pwd import pwd
import rnsh.protocol as protocol
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -141,13 +145,18 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.info(f"Using command {shlex.join(_cmd)}") log.info(f"Using command {shlex.join(_cmd)}")
_no_remote_command = no_remote_command _no_remote_command = no_remote_command
session.ListenerSession.allow_remote_command = not no_remote_command
_remote_cmd_as_args = remote_cmd_as_args _remote_cmd_as_args = remote_cmd_as_args
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \ if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
and (_no_remote_command or _remote_cmd_as_args): and (_no_remote_command or _remote_cmd_as_args):
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.") raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
session.ListenerSession.default_command = _cmd
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
if disable_auth: if disable_auth:
_allow_all = True _allow_all = True
session.ListenerSession.allow_all = True
else: else:
if allowed is not None: if allowed is not None:
for a in allowed: for a in allowed:
@ -161,6 +170,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
try: try:
destination_hash = bytes.fromhex(a) destination_hash = bytes.fromhex(a)
_allowed_identity_hashes.append(destination_hash) _allowed_identity_hashes.append(destination_hash)
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
except Exception: except Exception:
raise ValueError("Invalid destination entered. Check your input.") raise ValueError("Invalid destination entered. Check your input.")
except Exception as e: except Exception as e:
@ -170,23 +180,9 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
if len(_allowed_identity_hashes) < 1 and not disable_auth: if len(_allowed_identity_hashes) < 1 and not disable_auth:
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!") log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
_destination.set_link_established_callback(_listen_link_established) def link_established(lnk: RNS.Link):
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), _loop)
if not _allow_all: _destination.set_link_established_callback(link_established)
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes
)
else:
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_ALL,
)
if await _check_finished(): if await _check_finished():
return return
@ -199,531 +195,22 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
last = time.time() last = time.time()
try: try:
while not await _check_finished(1.0): while not await _check_finished():
if announce_period and 0 < announce_period < time.time() - last: if announce_period and 0 < announce_period < time.time() - last:
last = time.time() last = time.time()
_destination.announce() _destination.announce()
await session.ListenerSession.pump_all()
await asyncio.sleep(0.01)
finally: finally:
log.warning("Shutting down") log.warning("Shutting down")
for link in list(_destination.links): await session.ListenerSession.terminate_all("Shutting down")
with exception.permit(SystemExit, KeyboardInterrupt): await asyncio.sleep(1)
proc = Session.get_for_tag(link.link_id) session.ListenerSession.messenger.shutdown()
if proc is not None and proc.process.running:
proc.process.terminate()
await asyncio.sleep(0)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links)) links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active: for link in links_still_active:
if link.status != RNS.Link.CLOSED: if link.status not in [RNS.Link.CLOSED]:
link.teardown() link.teardown()
await asyncio.sleep(0) await asyncio.sleep(0.01)
_PROTOCOL_MAGIC = 0xdeadbeef
def _protocol_make_version(version: int):
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
_PROTOCOL_VERSION_2 = _protocol_make_version(2)
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
def _protocol_check_magic(value: int):
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
if version >= _PROTOCOL_VERSION_2:
return link_mdu - 64 # TODO: tune
else:
return link_mdu // 2
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
if version >= _PROTOCOL_VERSION_2:
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
else:
return link_mdu // 2
def _bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def _check_and(value: int, andval: int) -> bool:
return (value & andval) > 0
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock()
@classmethod
def get_for_tag(cls, tag: any) -> Session | None:
with cls._lock:
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
@classmethod
def put_for_tag(cls, tag: any, ps: Session):
with cls._lock:
cls.clear_tag(tag)
cls._processes.append((tag, ps))
@classmethod
def clear_tag(cls, tag: any):
with cls._lock:
with exception.permit(SystemExit):
cls._processes.remove(tag)
def __init__(self,
tag: any,
cmd: [str],
data_available_callback: callable,
terminated_callback: callable,
session_flags: int,
term: str | None,
remote_identity: str | None,
loop: asyncio.AbstractEventLoop = None):
self._log = _get_logger(self.__class__.__name__)
self._loop = loop if loop is not None else asyncio.get_running_loop()
self._process = process.CallbackSubprocess(argv=cmd,
env={"TERM": term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": remote_identity or ""},
loop=loop,
stdout_callback=self._stdout_data,
stderr_callback=self._stderr_data,
terminated_callback=terminated_callback,
stdin_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDIN),
stdout_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDOUT),
stderr_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDERR))
self._log.debug(f"Starting {cmd}")
self._stdout_buffer = bytearray()
self._stderr_buffer = bytearray()
self._lock = threading.RLock()
self._data_available_cb = data_available_callback
self._terminated_cb = terminated_callback
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] = None
self._session_flags: int = session_flags
Session.put_for_tag(tag, self)
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
return self._pending_receipt
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
with self._lock:
val = self._pending_receipt
self._pending_receipt = None
return val
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
with self._lock:
self._pending_receipt = receipt
@property
def process(self) -> process.CallbackSubprocess:
return self._process
@property
def return_code(self) -> int | None:
return self.process.return_code
@property
def lock(self) -> threading.RLock:
return self._lock
def read_stdout(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._stdout_buffer)
take = self._stdout_buffer[:count]
self._stdout_buffer = self._stdout_buffer[count:]
self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
return take
def _stdout_data(self, data: bytes):
with self.lock:
self._stdout_buffer.extend(data)
total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
def read_stderr(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._stderr_buffer)
take = self._stderr_buffer[:count]
self._stderr_buffer = self._stderr_buffer[count:]
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
return take
def _stderr_data(self, data: bytes):
with self.lock:
self._stderr_buffer.extend(data)
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
TERMSTATE_IDX_TERM = 0
TERMSTATE_IDX_TIOS = 1
TERMSTATE_IDX_ROWS = 2
TERMSTATE_IDX_COLS = 3
TERMSTATE_IDX_HPIX = 4
TERMSTATE_IDX_VPIX = 5
def _update_winsz(self):
try:
self.process.set_winsize(self._term_state[1],
self._term_state[2],
self._term_state[3],
self._term_state[4])
except Exception as e:
self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_VERS = 0
REQUEST_IDX_STDIN = 1
REQUEST_IDX_TERM = 2
REQUEST_IDX_TIOS = 3
REQUEST_IDX_ROWS = 4
REQUEST_IDX_COLS = 5
REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7
REQUEST_IDX_CMD = 8
REQUEST_IDX_FLAGS = 9
REQUEST_IDX_BYTES_AVAILABLE = 10
REQUEST_FLAGS_PIPE_STDIN = 0x01
REQUEST_FLAGS_PIPE_STDOUT = 0x02
REQUEST_FLAGS_PIPE_STDERR = 0x04
REQUEST_FLAGS_EOF_STDIN = 0x08
@staticmethod
def default_request() -> [any]:
global _tr
request: list[any] = [
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
None, # 8 Command to run
0, # 9 Flags
0, # 10 Bytes Available
].copy()
if os.isatty(0):
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
with contextlib.suppress(OSError):
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
Session.REQUEST_FLAGS_PIPE_STDIN)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
Session.REQUEST_FLAGS_PIPE_STDOUT)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
Session.REQUEST_FLAGS_PIPE_STDERR)
return request
def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
# cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
flags = data[Session.REQUEST_IDX_FLAGS]
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
response = Session.default_response()
first_term_state = self._term_state is None
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
if self.process.running:
if term_state != self._term_state:
self._term_state = term_state
if term_state is not None:
self._update_winsz()
if first_term_state is not None:
# TODO: use a more specific error
with contextlib.suppress(Exception):
self.process.tcsetattr(termios.TCSADRAIN, term_state[0])
if stdin is not None and len(stdin) > 0:
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
if stdin_eof and bytes_available == 0:
module_logger.debug("Closing stdin")
with contextlib.suppress(Exception):
self.process.close_stdin()
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock:
#prioritizing stderr
stderr = self.read_stderr(read_size)
stdout = self.read_stdout(read_size - len(stderr))
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
if stderr is not None and len(stderr) > 0:
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
if stdout is not None and len(stdout) > 0:
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
return response
RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_FLAGS = 1
RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_STDERR = 4
RESPONSE_IDX_STDOUT = 5
RESPONSE_IDX_TMSTAMP = 6
RESPONSE_FLAGS_RUNNING = 0x01
RESPONSE_FLAGS_EOF_STDOUT = 0x02
RESPONSE_FLAGS_EOF_STDERR = 0x04
@staticmethod
def default_response() -> [any]:
response: list[any] = [
_PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stderr
None, # 5: Stdout
None, # 6: Timestamp
].copy()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response
@classmethod
def error_response(cls, msg: str) -> [any]:
response = cls.default_response()
msg_bytes = f"{msg}\r\n".encode("utf-8")
response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer
log = _get_logger("_subproc_data_ready")
session: Session = Session.get_for_tag(link.link_id)
def send(timeout: bool, tag: any, tries: int) -> any:
# log.debug("send")
def inner():
# log.debug("inner")
try:
if link.status != RNS.Link.ACTIVE:
_retry_timer.complete(link.link_id)
session.pending_receipt_take()
return
pr = session.pending_receipt_take()
log.debug(f"send inner pr: {pr}")
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
if not timeout:
_retry_timer.complete(tag)
log.debug(f"Notification completed with status {pr.status} on link {link}")
return
else:
if not timeout:
log.debug(
f"Notifying client try {tries} (retcode: {session.return_code} " +
f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
session.pending_receipt_put(pr)
else:
log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id)
link.teardown()
except Exception as e:
log.error("Error notifying client: " + str(e))
_loop.call_soon_threadsafe(inner)
return link.link_id
with session.lock:
if not _retry_timer.has_tag(link.link_id):
_retry_timer.begin(try_limit=15,
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
try_callback=functools.partial(send, False),
timeout_callback=functools.partial(send, True))
else:
log.debug(f"Notification already pending for link {link}")
def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop
log = _get_logger("_subproc_terminated")
log.info(f"Subprocess returned {return_code} for link {link}")
proc = Session.get_for_tag(link.link_id)
if proc is None:
log.debug(f"no proc for link {link}")
return
def cleanup():
def inner():
log.debug(f"cleanup culled link {link}")
if link and link.status != RNS.Link.CLOSED:
with exception.permit(SystemExit):
try:
link.teardown()
finally:
Session.clear_tag(link.link_id)
_loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0)
_loop.call_soon_threadsafe(cleanup)
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
log = _get_logger("_listen_start_proc")
try:
return Session(tag=link.link_id,
cmd=cmd,
session_flags=session_flags,
term=term,
remote_identity=remote_identity,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
except Exception as e:
log.error("Failed to launch process: " + str(e))
_subproc_terminated(link, 255)
return None
def _listen_link_established(link):
global _allow_all
log = _get_logger("_listen_link_established")
link.set_remote_identified_callback(_initiator_identified)
link.set_link_closed_callback(_listen_link_closed)
log.info("Link " + str(link) + " established")
def _listen_link_closed(link: RNS.Link):
log = _get_logger("_listen_link_closed")
# async def cleanup():
log.info("Link " + str(link) + " closed")
proc: Session | None = Session.get_for_tag(link.link_id)
if proc is None:
log.warning(f"No process for link {link}")
else:
try:
proc.process.terminate()
_retry_timer.complete(link.link_id)
except Exception as e:
log.error(f"Error closing process for link {link}: {e}")
Session.clear_tag(link.link_id)
def _initiator_identified(link, identity):
global _allow_all, _cmd, _loop
log = _get_logger("_initiator_identified")
log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash))
if not _allow_all and identity.hash not in _allowed_identity_hashes:
log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer, _loop, _cmd, _no_remote_command
log = _get_logger("_listen_request")
log.debug(
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None:
log.error(f"Invalid request {request_id}, no link found with id {link_id}")
return None
remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version):
log.error("Request magic incorrect")
link.teardown()
return None
if not remote_version <= _PROTOCOL_VERSION_3:
return Session.error_response("Listener<->initiator version mismatch")
cmd = _cmd.copy()
remote_command = data[Session.REQUEST_IDX_CMD]
if remote_command is not None and len(remote_command) > 0:
if _no_remote_command:
return Session.error_response("Listener does not permit initiator to provide command.")
elif _remote_cmd_as_args:
cmd.extend(remote_command)
else:
cmd = remote_command
if not _no_remote_command and (cmd is None or len(cmd) == 0):
return Session.error_response("No command supplied and no default command available.")
session: Session | None = None
try:
term = data[Session.REQUEST_IDX_TERM]
# sanitize
if term is not None:
term = re.sub('[^A-Za-z-0-9\-\_]','', term)
session = Session.get_for_tag(link.link_id)
if session is None:
log.debug(f"Process not found for link {link}")
session = _listen_start_proc(link=link,
term=term,
session_flags=data[Session.REQUEST_IDX_FLAGS],
cmd=cmd,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
if session is None:
return Session.error_response("Unable to start subprocess")
# leave significant headroom for metadata and encoding
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
return result
# return ProcessState.default_response()
except Exception as e:
log.error(f"Error procesing request for link {link}: {e}")
try:
if session is not None and session.process.running:
session.process.terminate()
except Exception as ee:
log.debug(f"Error terminating process for link {link}: {ee}")
return Session.default_response()
async def _spin(until: callable = None, timeout: float | None = None) -> bool: async def _spin(until: callable = None, timeout: float | None = None) -> bool:
@ -731,7 +218,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
timeout += time.time() timeout += time.time()
while (timeout is None or time.time() < timeout) and not until(): while (timeout is None or time.time() < timeout) and not until():
if await _check_finished(0.001): if await _check_finished(0.01):
raise asyncio.CancelledError() raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout: if timeout is not None and time.time() > timeout:
return False return False
@ -744,15 +231,28 @@ _remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None _new_data: asyncio.Event | None = None
_tr: process.TTYRestorer | None = None _tr: process.TTYRestorer | None = None
_pq = queue.Queue()
class InitiatorState(enum.IntEnum):
IS_INITIAL = 0
IS_LINKED = 1
IS_WAIT_VERS = 2
IS_RUNNING = 3
IS_TERMINATE = 4
IS_TEARDOWN = 5
def _client_link_closed(link):
log = _get_logger("_client_link_closed")
_finished.set()
def _client_packet_handler(message, packet): def _client_packet_handler(message, packet):
global _new_data global _new_data
log = _get_logger("_client_packet_handler") log = _get_logger("_client_packet_handler")
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG and _new_data is not None: packet.prove()
log.debug("data available") _pq.put(message)
_new_data.set()
else:
log.error(f"received unhandled packet")
class RemoteExecutionError(Exception): class RemoteExecutionError(Exception):
@ -760,15 +260,10 @@ class RemoteExecutionError(Exception):
self.msg = msg self.msg = msg
def _response_handler(request_receipt: RNS.RequestReceipt): async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
pass service_name="default", timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
log = _get_logger("_execute") log = _get_logger("_initiate_link")
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(destination) != dest_len: if len(destination) != dest_len:
@ -809,6 +304,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link = RNS.Link(_destination) _link = RNS.Link(_destination)
_link.did_identify = False _link.did_identify = False
_link.set_link_closed_callback(_client_link_closed)
log.info(f"Establishing link...") log.info(f"Establishing link...")
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout): if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash)) raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
@ -820,110 +317,6 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler) _link.set_packet_callback(_client_packet_handler)
request = Session.default_request()
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
# log.debug(f"Sending {stdin} to listener")
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
request[Session.REQUEST_IDX_CMD] = cmd
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
Session.REQUEST_FLAGS_EOF_STDIN)
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
log.debug("Sending request")
request_receipt = _link.request(
path="data",
data=request,
timeout=timeout
)
timeout += 0.5
log.debug("Waiting for delivery")
await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and
request_receipt.status != RNS.RequestReceipt.SENT),
timeout=timeout
)
if _link.status == RNS.Link.CLOSED:
raise RemoteExecutionError("Could not request remote execution, link was closed")
if request_receipt.status == RNS.RequestReceipt.FAILED:
raise RemoteExecutionError("Could not request remote execution")
await _spin(
until=lambda: request_receipt.status != RNS.RequestReceipt.DELIVERED,
timeout=timeout
)
if request_receipt.status == RNS.RequestReceipt.FAILED:
raise RemoteExecutionError("No result was received")
if request_receipt.status == RNS.RequestReceipt.FAILED:
raise RemoteExecutionError("Receiving result failed")
if request_receipt.response is not None:
try:
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_3:
raise RemoteExecutionError("Protocol version mismatch")
flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
# if stdout is not None:
# stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except RemoteExecutionError:
raise
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
if stdout is not None:
_tr.raw()
log.debug(f"stdout: {stdout}")
os.write(1, stdout)
sys.stdout.flush()
if stderr is not None:
_tr.raw()
log.debug(f"stderr: {stderr}")
os.write(2, stderr)
sys.stderr.flush()
if stderr_eof and ready_bytes == 0:
log.debug("Closing stderr")
os.close(2)
if stdout_eof and ready_bytes == 0:
log.debug("Closing stdout")
os.close(1)
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0:
_new_data.set()
if (not running or return_code is not None) and (ready_bytes == 0):
log.debug(f"returning running: {running}, return_code: {return_code}")
return return_code or 255
return None
_pre_input = bytearray()
async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str, async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
service_name: str, timeout: float, command: [str] | None = None): service_name: str, timeout: float, command: [str] | None = None):
@ -931,13 +324,52 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
log = _get_logger("_initiate") log = _get_logger("_initiate")
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
_new_data = asyncio.Event() _new_data = asyncio.Event()
state = InitiatorState.IS_INITIAL
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray() data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
await _initiate_link(
configdir=configdir,
identitypath=identitypath,
verbosity=verbosity,
quietness=quietness,
noid=noid,
destination=destination,
service_name=service_name,
timeout=timeout,
)
if not _link or _link.status != RNS.Link.ACTIVE:
_finished.set()
return 255
state = InitiatorState.IS_LINKED
outlet = session.RNSOutlet(_link)
with protocol.Messenger(retry_delay_min=5) as messenger:
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# print("Error bringing up link")
# return 253
messenger.send(outlet, protocol.VersionInfoMessage())
try:
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
vm = messenger.receive(vp)
if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received")
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
state = InitiatorState.IS_RUNNING
except queue.Empty:
print("Protocol error")
return 254
winch = False
def sigwinch_handler(): def sigwinch_handler():
nonlocal winch
# log.debug("WindowChanged") # log.debug("WindowChanged")
if _new_data is not None: winch = True
_new_data.set() # if _new_data is not None:
# _new_data.set()
stdin_eof = False stdin_eof = False
def stdin(): def stdin():
@ -957,61 +389,98 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
process.tty_add_reader_callback(sys.stdin.fileno(), stdin) process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
await _check_finished() await _check_finished()
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
# leave a lot of overhead tcattr = None
mdu = 64 rows, cols, hpix, vpix = (None, None, None, None)
rtt = 5
first_loop = True
cmdline = " ".join(command or [])
while not await _check_finished():
try: try:
log.debug("top of client loop") tcattr = termios.tcgetattr(0)
stdin = data_buffer[:mdu] rows, cols, hpix, vpix = process.tty_get_winsize(0)
data_buffer = data_buffer[mdu:] except:
_new_data.clear() try:
log.debug("before _execute") tcattr = termios.tcgetattr(1)
return_code = await _execute( rows, cols, hpix, vpix = process.tty_get_winsize(1)
configdir=configdir, except:
identitypath=identitypath, try:
verbosity=verbosity, tcattr = termios.tcgetattr(2)
quietness=quietness, rows, cols, hpix, vpix = process.tty_get_winsize(2)
noid=noid, except:
destination=destination, pass
service_name=service_name,
stdin=stdin,
timeout=timeout,
cmd=command,
stdin_eof=stdin_eof,
bytes_available=len(data_buffer)
)
if first_loop: messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
first_loop = False pipe_stdin=not os.isatty(0),
mdu = _protocol_request_chars_take(_link.MDU, pipe_stdout=not os.isatty(1),
_PROTOCOL_VERSION_DEFAULT, pipe_stderr=not os.isatty(2),
os.environ.get("TERM", ""), tcflags=tcattr,
cmdline) term=os.environ.get("TERM", None),
_new_data.set() rows=rows,
cols=cols,
hpix=hpix,
vpix=vpix))
if _link:
rtt = _link.rtt
if return_code is not None:
log.debug(f"received return code {return_code}, exiting")
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
mdu = _link.MDU - 16
sent_eof = False
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
try:
try:
packet = _pq.get_nowait()
message = messenger.receive(packet)
if isinstance(message, protocol.StreamDataMessage):
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
if message.data and len(message.data) > 0:
_tr.raw()
log.debug(f"stdout: {message.data}")
os.write(1, message.data)
sys.stdout.flush()
if message.eof:
os.close(1)
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
if message.data and len(message.data) > 0:
_tr.raw()
log.debug(f"stdout: {message.data}")
os.write(2, message.data)
sys.stderr.flush()
if message.eof:
os.close(2)
elif isinstance(message, protocol.CommandExitedMessage):
log.debug(f"received return code {message.return_code}, exiting")
with exception.permit(SystemExit, KeyboardInterrupt): with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown() _link.teardown()
return message.return_code
return return_code elif isinstance(message, protocol.ErrorMessage):
except asyncio.CancelledError: log.error(message.data)
if _link and _link.status != RNS.Link.CLOSED: if message.fatal:
_link.teardown() _link.teardown()
return 0 return 200
except queue.Empty:
pass
if messenger.is_outlet_ready(outlet):
stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:]
eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof:
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
stdin, eof))
sent_eof = eof
except RemoteExecutionError as e: except RemoteExecutionError as e:
print(e.msg) print(e.msg)
return 255 return 255
except Exception as ex:
print(f"Client exception: {ex}")
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 127
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120)) # await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
await asyncio.sleep(0.01)
log.debug("after main loop")
return 0 return 0

423
rnsh/session.py Normal file
View File

@ -0,0 +1,423 @@
from __future__ import annotations
import contextlib
import functools
import threading
import rnsh.exception as exception
import asyncio
import rnsh.process as process
import rnsh.helpers as helpers
import rnsh.protocol as protocol
import enum
from typing import TypeVar, Generic, Callable, List
from abc import abstractmethod, ABC
from multiprocessing import Manager
import os
import RNS
import logging as __logging
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
module_logger = __logging.getLogger(__name__)
_TLink = TypeVar("_TLink")
class SEType(enum.IntEnum):
SE_LINK_CLOSED = 0
class SessionException(Exception):
def __init__(self, setype: SEType, msg: str, *args):
super().__init__(msg, args)
self.type = setype
class LSState(enum.IntEnum):
LSSTATE_WAIT_IDENT = 1
LSSTATE_WAIT_VERS = 2
LSSTATE_WAIT_CMD = 3
LSSTATE_RUNNING = 4
LSSTATE_ERROR = 5
LSSTATE_TEARDOWN = 6
_TIdentity = TypeVar("_TIdentity")
class LSOutletBase(protocol.MessageOutletBase):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
raise NotImplemented()
@abstractmethod
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
raise NotImplemented()
@abstractmethod
def unset_link_closed_callback(self):
raise NotImplemented()
@abstractmethod
def teardown(self):
raise NotImplemented()
@abstractmethod
def __init__(self):
raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
allowed_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._log.info(f"Session started for {outlet}")
self.outlet = outlet
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
self.state: LSState = None
self.remote_identity = None
self.term: str | None = None
self.stdin_is_pipe: bool = False
self.stdout_is_pipe: bool = False
self.stderr_is_pipe: bool = False
self.tcflags: [any] = None
self.cmdline: [str] = None
self.rows: int = 0
self.cols: int = 0
self.hpix: int = 0
self.vpix: int = 0
self.stdout_buf = bytearray()
self.stdout_eof_sent = False
self.stderr_buf = bytearray()
self.stderr_eof_sent = False
self.return_code: int | None = None
self.return_code_sent = False
self.process: process.CallbackSubprocess | None = None
self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
def _terminated(self, return_code: int):
self.return_code = return_code
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
self._log.debug(f"Set state: {state.name}, timeout {timeout}")
orig_state = self.state
self.state = state
if timeout_factor is not None:
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
def _call(self, func: callable, delay: float = 0):
def call_inner():
# self._log.debug("call_inner")
if delay == 0:
func()
else:
self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: protocol.Message):
self.messenger.send(self.outlet, message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
def _protocol_timeout_error(self, name: str):
self.terminate(f"Protocol timeout error: {name}")
def terminate(self, error: str = None):
with contextlib.suppress(Exception):
self._log.debug("Terminating session" + (f": {error}" if error else ""))
if error and self.state != LSState.LSSTATE_TEARDOWN:
with contextlib.suppress(Exception):
self.send(protocol.ErrorMessage(error, True))
self.state = LSState.LSSTATE_ERROR
self._terminate_process()
self._call(self._prune, max(self.outlet.rtt * 3, 5))
def _prune(self):
self.state = LSState.LSSTATE_TEARDOWN
with contextlib.suppress(ValueError):
self.sessions.remove(self)
with contextlib.suppress(Exception):
self.outlet.teardown()
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
timeout = True
try:
timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
except Exception as ex:
self._log.exception("Error in protocol timeout", ex)
if timeout:
self._protocol_timeout_error(name)
def _link_closed(self, outlet: LSOutletBase):
outlet.unset_link_closed_callback()
if outlet != self.outlet:
self._log.debug("Link closed received from incorrect outlet")
return
self._log.debug(f"link_closed {outlet}")
self.messenger.clear_retries(self.outlet)
self.terminate()
def _initiator_identified(self, outlet, identity):
if outlet != self.outlet:
self._log.debug("Identity received from incorrect outlet")
return
self._log.info(f"initiator_identified {identity} on link {outlet}")
if self.state != LSState.LSSTATE_WAIT_IDENT:
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
self.terminate("Identity is not allowed.")
self.remote_identity = identity
self.outlet.set_packet_received_callback(self._packet_received)
self._set_state(LSState.LSSTATE_WAIT_VERS)
@classmethod
async def pump_all(cls):
for session in cls.sessions:
session.pump()
await asyncio.sleep(0)
@classmethod
async def terminate_all(cls, reason: str):
for session in cls.sessions:
session.terminate(reason)
await asyncio.sleep(0)
def pump(self):
try:
if self.state != LSState.LSSTATE_RUNNING:
return
elif not self.messenger.is_outlet_ready(self.outlet):
return
elif len(self.stderr_buf) > 0:
mdu = self.outlet.mdu - 16
data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
data, send_eof)
self.send(msg)
if send_eof:
self.stderr_eof_sent = True
elif len(self.stdout_buf) > 0:
mdu = self.outlet.mdu - 16
data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
data, send_eof)
self.send(msg)
elif self.return_code is not None and not self.return_code_sent:
msg = protocol.CommandExitedMessage(self.return_code)
self.send(msg)
self.return_code_sent = True
self._call(functools.partial(self._check_protocol_timeout,
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
max(self.outlet.rtt * 5, 10))
except Exception as ex:
self._log.exception("Error during pump", ex)
def _terminate_process(self):
with contextlib.suppress(Exception):
if self.process and self.process.running:
self.process.terminate()
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
self.cmdline = self.default_command
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
self.terminate("Remote command line not allowed by listener")
return
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
self.cmdline.extend(cmdline)
elif cmdline and len(cmdline) > 0:
self.cmdline = cmdline
self.stdin_is_pipe = pipe_stdin
self.stdout_is_pipe = pipe_stdout
self.stderr_is_pipe = pipe_stderr
self.tcflags = tcflags
self.term = term
def stdout(data: bytes):
self.stdout_buf.extend(data)
def stderr(data: bytes):
self.stderr_buf.extend(data)
try:
self.process = process.CallbackSubprocess(argv=self.cmdline,
env={"TERM": self.term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
loop=self.loop,
stdout_callback=stdout,
stderr_callback=stderr,
terminated_callback=self._terminated,
stdin_is_pipe=self.stdin_is_pipe,
stdout_is_pipe=self.stdout_is_pipe,
stderr_is_pipe=self.stderr_is_pipe)
self.process.start()
self._set_window_size(rows, cols, hpix, vpix)
except Exception as ex:
self._log.exception(f"Unable to start process for link {self.outlet}", ex)
self.terminate("Unable to start process")
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
with contextlib.suppress(Exception):
self.process.set_winsize(rows, cols, hpix, vpix)
def _received_stdin(self, data: bytes, eof: bool):
if data and len(data) > 0:
self.process.write(data)
if eof:
self.process.close_stdin()
def _handle_message(self, message: protocol.Message):
if self.state == LSState.LSSTATE_WAIT_VERS:
if not isinstance(message, protocol.VersionInfoMessage):
self._protocol_error(self.state.name)
return
self._log.info(f"version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}")
if message.protocol_version != protocol.PROTOCOL_VERSION:
self.terminate("Incompatible protocol")
return
self.send(protocol.VersionInfoMessage())
self._set_state(LSState.LSSTATE_WAIT_CMD)
return
elif self.state == LSState.LSSTATE_WAIT_CMD:
if not isinstance(message, protocol.ExecuteCommandMesssage):
return self._protocol_error(self.state.name)
self._log.info(f"Execute command message on link {self.outlet}: {message.cmdline}")
self._set_state(LSState.LSSTATE_RUNNING)
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
return
elif self.state == LSState.LSSTATE_RUNNING:
if isinstance(message, protocol.WindowSizeMessage):
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
elif isinstance(message, protocol.StreamDataMessage):
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
self._log.error(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}")
return self._protocol_error(self.state.name)
self._received_stdin(message.data, message.eof)
return
elif isinstance(message, protocol.NoopMessage):
# echo noop only on listener--used for keepalive/connectivity check
self.send(message)
return
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
self._log.error(f"Received packet, but in state {self.state.name}")
return
else:
self._protocol_error("unexpected message")
return
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
if outlet != self.outlet:
self._log.debug("Packet received from incorrect outlet")
return
try:
message = self.messenger.receive(raw)
self._handle_message(message)
except Exception as ex:
self._protocol_error("unusable packet")
class RNSOutlet(LSOutletBase):
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
def inner_cb(link, identity: _TIdentity):
cb(self, identity)
self.link.set_remote_identified_callback(inner_cb)
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
def inner_cb(link):
cb(self)
self.link.set_link_closed_callback(inner_cb)
def unset_link_closed_callback(self):
self.link.set_link_closed_callback(None)
def teardown(self):
self.link.teardown()
def send(self, raw: bytes) -> RNS.PacketReceipt:
packet = RNS.Packet(self.link, raw)
packet.send()
return packet.receipt
@property
def mdu(self) -> int:
return self.link.MDU
@property
def rtt(self) -> float:
return self.link.rtt
@property
def is_usuable(self):
return True #self.link.status in [RNS.Link.ACTIVE]
def get_receipt_state(self, receipt: RNS.PacketReceipt) -> MessageState:
status = receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return protocol.MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return protocol.MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return protocol.MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self):
return f"Outlet RNS Link {self.link}"
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
def inner_cb(message, packet: RNS.Packet):
packet.prove()
cb(self, message)
self.link.set_packet_callback(inner_cb)
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
link.msgoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):
return link.lsoutlet
return RNSOutlet(link)

View File

@ -1,6 +1,10 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from rnsh.protocol import _TReceipt, MessageState
from typing import Callable
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
import rnsh.protocol import rnsh.protocol
@ -14,64 +18,60 @@ import uuid
module_logger = logging.getLogger(__name__) module_logger = logging.getLogger(__name__)
class Link: class Receipt:
def __init__(self, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
class MessageOutletTest(rnsh.protocol.MessageOutletBase):
def __init__(self, mdu: int, rtt: float): def __init__(self, mdu: int, rtt: float):
self.link_id = uuid.uuid4() self.link_id = uuid.uuid4()
self.timeout_callbacks = 0 self.timeout_callbacks = 0
self.mdu = mdu self._mdu = mdu
self.rtt = rtt self._rtt = rtt
self.usable = True self._usable = True
self.receipts = [] self.receipts = []
self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None
def timeout_callback(self): def send(self, raw: bytes) -> Receipt:
receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
self.receipts.append(receipt)
return receipt
def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]):
self.packet_callback = cb
def receive(self, raw: bytes):
if self.packet_callback:
self.packet_callback(self, raw)
@property
def mdu(self):
return self._mdu
@property
def rtt(self):
return self._rtt
@property
def is_usuable(self):
return self._usable
def get_receipt_state(self, receipt: Receipt) -> MessageState:
return receipt.state
def timed_out(self):
self.timeout_callbacks += 1 self.timeout_callbacks += 1
def __str__(self): def __str__(self):
return str(self.link_id) return str(self.link_id)
class Receipt:
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
self.link = link
class ProtocolHarness(contextlib.AbstractContextManager): class ProtocolHarness(contextlib.AbstractContextManager):
def __init__(self, retry_delay_min: float = 1): def __init__(self, retry_delay_min: float = 1):
self._log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker, self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min)
link_timeout_callback=self.link_timeout_callback,
link_mdu_getter=self.link_mdu_getter,
link_rtt_getter=self.link_rtt_getter,
link_usable_getter=self.link_usable_getter,
packet_sender=self.packet_sender,
retry_delay_min=retry_delay_min)
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
link.receipts.append(receipt)
return receipt
@staticmethod
def link_mdu_getter(link: Link):
return link.mdu
@staticmethod
def link_rtt_getter(link: Link):
return link.rtt
@staticmethod
def link_usable_getter(link: Link):
return link.usable
@staticmethod
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
return receipt.state
@staticmethod
def link_timeout_callback(link: Link):
link.timeout_callback()
def cleanup(self): def cleanup(self):
self.messenger.shutdown() self.messenger.shutdown()
@ -83,49 +83,58 @@ class ProtocolHarness(contextlib.AbstractContextManager):
return False return False
def test_mdu():
with ProtocolHarness() as h:
mdu = 500
link = Link(mdu=mdu, rtt=0.25)
assert h.messenger.get_mdu(link) == mdu - 4
link.mdu = mdu = 600
assert h.messenger.get_mdu(link) == mdu - 4
def test_rtt():
with ProtocolHarness() as h:
rtt = 0.25
link = Link(mdu=500, rtt=rtt)
assert h.messenger.get_rtt(link) == rtt
def test_send_one_retry(): def test_send_one_retry():
rtt = 0.001 rtt = 0.001
retry_interval = rtt * 150 retry_interval = rtt * 150
message_content = b'Test' message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h: with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt) outlet = MessageOutletTest(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True) data=message_content, eof=True)
assert len(link.receipts) == 0 assert len(outlet.receipts) == 0
h.messenger.send_message(link, message) h.messenger.send(outlet, message)
assert message.tracked assert message.tracked
assert message.raw is not None assert message.raw is not None
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
receipt = link.receipts[0] receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5) time.sleep(retry_interval * 1.5)
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
module_logger.info("set failed") module_logger.info("set failed")
time.sleep(retry_interval) time.sleep(retry_interval)
assert len(link.receipts) == 2 assert len(outlet.receipts) == 2
receipt = link.receipts[1] receipt = outlet.receipts[1]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval) time.sleep(retry_interval)
assert len(link.receipts) == 2 assert len(outlet.receipts) == 2
assert not message.tracked
def test_send_timeout():
rtt = 0.001
retry_interval = rtt * 150
message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h:
outlet = MessageOutletTest(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked
assert message.raw is not None
assert len(outlet.receipts) == 1
receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5)
assert outlet.timeout_callbacks == 0
time.sleep(retry_interval * 7)
assert len(outlet.receipts) == 1
assert outlet.timeout_callbacks == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert not message.tracked assert not message.tracked
@ -133,24 +142,32 @@ def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[r
rtt = 0.001 rtt = 0.001
retry_interval = rtt * 150 retry_interval = rtt * 150
with ProtocolHarness(retry_delay_min=retry_interval) as h: with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt)
assert len(link.receipts) == 0 decoded: [rnsh.protocol.Message] = []
h.messenger.send_message(link, message) def packet(outlet, buffer):
decoded.append(h.messenger.receive(buffer))
outlet = MessageOutletTest(mdu=500, rtt=rtt)
outlet.set_packet_received_callback(packet)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked assert message.tracked
assert message.raw is not None assert message.raw is not None
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
receipt = link.receipts[0] receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw assert receipt.raw == message.raw
module_logger.info("set delivered") module_logger.info("set delivered")
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval * 2) time.sleep(retry_interval * 2)
assert len(link.receipts) == 1 assert len(outlet.receipts) == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
assert not message.tracked assert not message.tracked
module_logger.info("injecting rx message") module_logger.info("injecting rx message")
h.messenger.inbound(message.raw) assert len(decoded) == 0
rx_message = h.messenger.poll_inbound(block=False) outlet.receive(message.raw)
assert len(decoded) == 1
rx_message = decoded[0]
assert rx_message is not None assert rx_message is not None
assert isinstance(rx_message, message.__class__) assert isinstance(rx_message, message.__class__)
assert rx_message.msgid != message.msgid assert rx_message.msgid != message.msgid
@ -238,6 +255,16 @@ def test_send_receive_error():
eat_own_dog_food(message, check) eat_own_dog_food(message, check)
def test_send_receive_cmdexit():
message = rnsh.protocol.CommandExitedMessage(5)
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.return_code == message.return_code
eat_own_dog_food(message, check)

View File

@ -12,19 +12,6 @@ import re
import os import os
def test_check_magic():
magic = rnsh.rnsh._PROTOCOL_VERSION_0
# magic for version 0 is generated, make sure it comes out as expected
assert magic == 0xdeadbeef00000000
# verify the checker thinks it's right
assert rnsh.rnsh._protocol_check_magic(magic)
# scramble the magic
magic = magic | 0x00ffff0000000000
# make sure it fails now
assert not rnsh.rnsh._protocol_check_magic(magic)
def test_version(): def test_version():
# version = importlib.metadata.version(rnsh.__version__) # version = importlib.metadata.version(rnsh.__version__)
assert rnsh.__version__ != "0.0.0" assert rnsh.__version__ != "0.0.0"