mirror of
https://github.com/markqvist/rnsh.git
synced 2025-06-24 14:00:35 -04:00
Got the new protocol working.
This commit is contained in:
parent
0ee305795f
commit
8edb4020b1
8 changed files with 876 additions and 887 deletions
8
rnsh/helpers.py
Normal file
8
rnsh/helpers.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
def bitwise_or_if(value: int, condition: bool, orval: int):
|
||||
if not condition:
|
||||
return value
|
||||
return value | orval
|
||||
|
||||
|
||||
def check_and(value: int, andval: int) -> bool:
|
||||
return (value & andval) > 0
|
|
@ -382,8 +382,11 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
|
|||
# Make PTY controlling if necessary
|
||||
if child_fd is not None:
|
||||
os.setsid()
|
||||
try:
|
||||
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
|
||||
os.close(tmp_fd)
|
||||
except:
|
||||
pass
|
||||
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
|
||||
|
||||
# Execute the command
|
||||
|
@ -445,7 +448,8 @@ class CallbackSubprocess:
|
|||
self._child_stdout: int = None
|
||||
self._child_stderr: int = None
|
||||
self._return_code: int = None
|
||||
self._eof: bool = False
|
||||
self._stdout_eof: bool = False
|
||||
self._stderr_eof: bool = False
|
||||
self._stdin_is_pipe = stdin_is_pipe
|
||||
self._stdout_is_pipe = stdout_is_pipe
|
||||
self._stderr_is_pipe = stderr_is_pipe
|
||||
|
@ -455,6 +459,21 @@ class CallbackSubprocess:
|
|||
Terminate child process if running
|
||||
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
||||
"""
|
||||
|
||||
same = self._child_stdout == self._child_stderr
|
||||
with contextlib.suppress(OSError):
|
||||
if not self._stdout_eof:
|
||||
tty_unset_reader_callbacks(self._child_stdout)
|
||||
os.close(self._child_stdout)
|
||||
self._child_stdout = None
|
||||
|
||||
if not same:
|
||||
with contextlib.suppress(OSError):
|
||||
if not self._stderr_eof:
|
||||
tty_unset_reader_callbacks(self._child_stderr)
|
||||
os.close(self._child_stderr)
|
||||
self._child_stdout = None
|
||||
|
||||
self._log.debug("terminate()")
|
||||
if not self.running:
|
||||
return
|
||||
|
@ -477,7 +496,7 @@ class CallbackSubprocess:
|
|||
os.waitpid(self._pid, 0)
|
||||
self._log.debug("wait() finish")
|
||||
|
||||
threading.Thread(target=wait).start()
|
||||
threading.Thread(target=wait, daemon=True).start()
|
||||
|
||||
def close_stdin(self):
|
||||
with contextlib.suppress(Exception):
|
||||
|
@ -592,26 +611,44 @@ class CallbackSubprocess:
|
|||
|
||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||
|
||||
def reader(fd: int, callback: callable):
|
||||
def stdout():
|
||||
try:
|
||||
with exception.permit(SystemExit):
|
||||
data = tty_read(fd)
|
||||
data = tty_read_poll(self._child_stdout)
|
||||
if data is not None and len(data) > 0:
|
||||
callback(data)
|
||||
self._stdout_cb(data)
|
||||
except EOFError:
|
||||
self._eof = True
|
||||
self._stdout_eof = True
|
||||
tty_unset_reader_callbacks(self._child_stdout)
|
||||
callback(bytearray())
|
||||
self._stdout_cb(bytearray())
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._child_stdout)
|
||||
|
||||
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb),
|
||||
self._loop)
|
||||
def stderr():
|
||||
try:
|
||||
with exception.permit(SystemExit):
|
||||
data = tty_read_poll(self._child_stderr)
|
||||
if data is not None and len(data) > 0:
|
||||
self._stderr_cb(data)
|
||||
except EOFError:
|
||||
self._stderr_eof = True
|
||||
tty_unset_reader_callbacks(self._child_stderr)
|
||||
self._stdout_cb(bytearray())
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._child_stderr)
|
||||
|
||||
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
|
||||
if self._child_stderr != self._child_stdout:
|
||||
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb),
|
||||
self._loop)
|
||||
tty_add_reader_callback(self._child_stderr, stderr, self._loop)
|
||||
|
||||
@property
|
||||
def eof(self):
|
||||
return self._eof or not self.running
|
||||
def stdout_eof(self):
|
||||
return self._stdout_eof or not self.running
|
||||
|
||||
@property
|
||||
def stderr_eof(self):
|
||||
return self._stderr_eof or not self.running
|
||||
|
||||
|
||||
@property
|
||||
def return_code(self) -> int:
|
||||
|
|
158
rnsh/protocol.py
158
rnsh/protocol.py
|
@ -15,6 +15,7 @@ import abc
|
|||
import contextlib
|
||||
import struct
|
||||
import logging as __logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
|
@ -22,13 +23,50 @@ module_logger = __logging.getLogger(__name__)
|
|||
_TReceipt = TypeVar("_TReceipt")
|
||||
_TLink = TypeVar("_TLink")
|
||||
MSG_MAGIC = 0xac
|
||||
PROTOCOL_VERSION=1
|
||||
PROTOCOL_VERSION = 1
|
||||
|
||||
|
||||
def _make_MSGTYPE(val: int):
|
||||
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
||||
|
||||
|
||||
class MessageOutletBase(ABC):
|
||||
@abstractmethod
|
||||
def send(self, raw: bytes) -> _TReceipt:
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def mdu(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def rtt(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_usuable(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def timed_out(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class METype(enum.IntEnum):
|
||||
ME_NO_MSG_TYPE = 0
|
||||
ME_INVALID_MSG_TYPE = 1
|
||||
|
@ -58,17 +96,17 @@ class Message(abc.ABC):
|
|||
self.msgid = uuid.uuid4()
|
||||
self.raw: bytes | None = None
|
||||
self.receipt: _TReceipt = None
|
||||
self.link: _TLink = None
|
||||
self.outlet: _TLink = None
|
||||
self.tracked: bool = False
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__} {self.msgid}"
|
||||
|
||||
@abc.abstractmethod
|
||||
@abstractmethod
|
||||
def pack(self) -> bytes:
|
||||
raise NotImplemented()
|
||||
|
||||
@abc.abstractmethod
|
||||
@abstractmethod
|
||||
def unpack(self, raw):
|
||||
raise NotImplemented()
|
||||
|
||||
|
@ -124,7 +162,8 @@ class ExecuteCommandMesssage(Message):
|
|||
MSGTYPE = _make_MSGTYPE(3)
|
||||
|
||||
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
||||
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None):
|
||||
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
|
||||
cols: int = None, hpix: int = None, vpix: int = None):
|
||||
super().__init__()
|
||||
self.cmdline = cmdline
|
||||
self.pipe_stdin = pipe_stdin
|
||||
|
@ -132,16 +171,20 @@ class ExecuteCommandMesssage(Message):
|
|||
self.pipe_stderr = pipe_stderr
|
||||
self.tcflags = tcflags
|
||||
self.term = term
|
||||
self.rows = rows
|
||||
self.cols = cols
|
||||
self.hpix = hpix
|
||||
self.vpix = vpix
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||
self.tcflags, self.term))
|
||||
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
|
||||
def unpack(self, raw):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \
|
||||
= umsgpack.unpackb(raw)
|
||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
||||
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||
|
||||
class StreamDataMessage(Message):
|
||||
MSGTYPE = _make_MSGTYPE(4)
|
||||
|
@ -156,7 +199,7 @@ class StreamDataMessage(Message):
|
|||
self.eof = eof
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.stream_id, self.eof, self.data))
|
||||
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
|
||||
def unpack(self, raw):
|
||||
|
@ -169,7 +212,7 @@ class VersionInfoMessage(Message):
|
|||
|
||||
def __init__(self, sw_version: str = None):
|
||||
super().__init__()
|
||||
self.sw_version = sw_version
|
||||
self.sw_version = sw_version or rnsh.__version__
|
||||
self.protocol_version = PROTOCOL_VERSION
|
||||
|
||||
def pack(self) -> bytes:
|
||||
|
@ -199,6 +242,22 @@ class ErrorMessage(Message):
|
|||
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class CommandExitedMessage(Message):
|
||||
MSGTYPE = _make_MSGTYPE(7)
|
||||
|
||||
def __init__(self, return_code: int = None):
|
||||
super().__init__()
|
||||
self.return_code = return_code
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb(self.return_code)
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
|
||||
def unpack(self, raw: bytes):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.return_code = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class Messenger(contextlib.AbstractContextManager):
|
||||
|
||||
@staticmethod
|
||||
|
@ -208,29 +267,16 @@ class Messenger(contextlib.AbstractContextManager):
|
|||
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
||||
return subclass_tuples
|
||||
|
||||
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState],
|
||||
link_timeout_callback: Callable[[_TLink], None],
|
||||
link_mdu_getter: Callable[[_TLink], int],
|
||||
link_rtt_getter: Callable[[_TLink], float],
|
||||
link_usable_getter: Callable[[_TLink], bool],
|
||||
packet_sender: Callable[[_TLink, bytes], _TReceipt],
|
||||
retry_delay_min: float = 10.0):
|
||||
def __init__(self, retry_delay_min: float = 10.0):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._receipt_checker = receipt_checker
|
||||
self._link_timeout_callback = link_timeout_callback
|
||||
self._link_mdu_getter = link_mdu_getter
|
||||
self._link_rtt_getter = link_rtt_getter
|
||||
self._link_usable_getter = link_usable_getter
|
||||
self._packet_sender = packet_sender
|
||||
self._sent_messages: list[Message] = []
|
||||
self._lock = threading.RLock()
|
||||
self._retry_timer = rnsh.retry.RetryThread()
|
||||
self._message_factories = dict(self.__class__._get_msg_constructors())
|
||||
self._inbound_queue = queue.Queue()
|
||||
self._retry_delay_min = retry_delay_min
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
def __enter__(self) -> Messenger:
|
||||
return self
|
||||
|
||||
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
||||
__traceback: TracebackType | None) -> bool | None:
|
||||
|
@ -238,39 +284,38 @@ class Messenger(contextlib.AbstractContextManager):
|
|||
return False
|
||||
|
||||
def shutdown(self):
|
||||
self._run = False
|
||||
self._retry_timer.close()
|
||||
|
||||
def inbound(self, raw: bytes):
|
||||
def clear_retries(self, outlet):
|
||||
self._retry_timer.complete(outlet)
|
||||
|
||||
def receive(self, raw: bytes) -> Message:
|
||||
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
||||
ctor = self._message_factories.get(mid, None)
|
||||
if ctor is None:
|
||||
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
||||
message = ctor()
|
||||
message.unpack(raw)
|
||||
self._log.debug("Message received: {message}")
|
||||
self._inbound_queue.put(message)
|
||||
self._log.debug(f"Message received: {message}")
|
||||
return message
|
||||
|
||||
def get_mdu(self, link: _TLink) -> int:
|
||||
return self._link_mdu_getter(link) - 4
|
||||
|
||||
def get_rtt(self, link: _TLink) -> float:
|
||||
return self._link_rtt_getter(link)
|
||||
|
||||
def is_link_ready(self, link: _TLink) -> bool:
|
||||
if not self._link_usable_getter(link):
|
||||
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
|
||||
if not outlet.is_usuable:
|
||||
self._log.debug("is_outlet_ready outlet unusable")
|
||||
return False
|
||||
|
||||
with self._lock:
|
||||
for message in self._sent_messages:
|
||||
if message.link == link:
|
||||
if message.outlet == outlet and message.tracked and message.receipt \
|
||||
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
|
||||
self._log.debug("is_outlet_ready pending message found")
|
||||
return False
|
||||
return True
|
||||
|
||||
def send_message(self, link: _TLink, message: Message):
|
||||
def send(self, outlet: MessageOutletBase, message: Message):
|
||||
with self._lock:
|
||||
if not self.is_link_ready(link):
|
||||
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready")
|
||||
if not self.is_outlet_ready(outlet):
|
||||
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
|
||||
|
||||
if message in self._sent_messages:
|
||||
raise MessagingException(METype.ME_ALREADY_SENT)
|
||||
|
@ -279,43 +324,36 @@ class Messenger(contextlib.AbstractContextManager):
|
|||
|
||||
if not message.raw:
|
||||
message.raw = message.pack()
|
||||
message.link = link
|
||||
message.outlet = outlet
|
||||
|
||||
def send(tag: any, tries: int):
|
||||
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt)
|
||||
def send_inner(tag: any, tries: int):
|
||||
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
|
||||
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
||||
try:
|
||||
self._log.debug(f"Sending packet for {message}")
|
||||
message.receipt = self._packet_sender(link, message.raw)
|
||||
message.receipt = outlet.send(message.raw)
|
||||
except Exception as ex:
|
||||
self._log.exception(f"Error sending message {message}")
|
||||
elif state in [MessageState.MSGSTATE_SENT]:
|
||||
self._log.debug(f"Retry skipped, message still pending {message}")
|
||||
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
||||
latency = round(time.time() - message.ts, 1)
|
||||
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
||||
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
||||
with self._lock:
|
||||
self._sent_messages.remove(message)
|
||||
message.tracked = False
|
||||
self._retry_timer.complete(link)
|
||||
return link
|
||||
self._retry_timer.complete(outlet)
|
||||
return outlet
|
||||
|
||||
def timeout(tag: any, tries: int):
|
||||
latency = round(time.time() - message.ts, 1)
|
||||
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
||||
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
||||
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
||||
with self._lock:
|
||||
self._sent_messages.remove(message)
|
||||
message.tracked = False
|
||||
self._link_timeout_callback(link)
|
||||
|
||||
rtt = self._link_rtt_getter(link)
|
||||
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
|
||||
|
||||
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
|
||||
try:
|
||||
return self._inbound_queue.get(block=block, timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
outlet.timed_out()
|
||||
|
||||
rtt = outlet.rtt
|
||||
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ class RetryThread(AbstractContextManager):
|
|||
self._lock = threading.RLock()
|
||||
self._run = True
|
||||
self._finished: asyncio.Future = None
|
||||
self._thread = threading.Thread(name=name, target=self._thread_run)
|
||||
self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def is_alive(self):
|
||||
|
|
865
rnsh/rnsh.py
865
rnsh/rnsh.py
|
@ -26,10 +26,12 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import base64
|
||||
import enum
|
||||
import functools
|
||||
import importlib.metadata
|
||||
import logging as __logging
|
||||
import os
|
||||
import queue
|
||||
import shlex
|
||||
import signal
|
||||
import sys
|
||||
|
@ -44,10 +46,12 @@ import rnsh.process as process
|
|||
import rnsh.retry as retry
|
||||
import rnsh.rnslogging as rnslogging
|
||||
import rnsh.hacks as hacks
|
||||
import rnsh.session as session
|
||||
import re
|
||||
import contextlib
|
||||
import rnsh.args
|
||||
import pwd
|
||||
import rnsh.protocol as protocol
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
|
@ -141,13 +145,18 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
log.info(f"Using command {shlex.join(_cmd)}")
|
||||
|
||||
_no_remote_command = no_remote_command
|
||||
session.ListenerSession.allow_remote_command = not no_remote_command
|
||||
_remote_cmd_as_args = remote_cmd_as_args
|
||||
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
|
||||
and (_no_remote_command or _remote_cmd_as_args):
|
||||
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
|
||||
|
||||
session.ListenerSession.default_command = _cmd
|
||||
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
|
||||
|
||||
if disable_auth:
|
||||
_allow_all = True
|
||||
session.ListenerSession.allow_all = True
|
||||
else:
|
||||
if allowed is not None:
|
||||
for a in allowed:
|
||||
|
@ -161,6 +170,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
try:
|
||||
destination_hash = bytes.fromhex(a)
|
||||
_allowed_identity_hashes.append(destination_hash)
|
||||
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
|
||||
except Exception:
|
||||
raise ValueError("Invalid destination entered. Check your input.")
|
||||
except Exception as e:
|
||||
|
@ -170,23 +180,9 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
if len(_allowed_identity_hashes) < 1 and not disable_auth:
|
||||
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
||||
|
||||
_destination.set_link_established_callback(_listen_link_established)
|
||||
|
||||
if not _allow_all:
|
||||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
# response_generator=_listen_request,
|
||||
allow=RNS.Destination.ALLOW_LIST,
|
||||
allowed_list=_allowed_identity_hashes
|
||||
)
|
||||
else:
|
||||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
# response_generator=_listen_request,
|
||||
allow=RNS.Destination.ALLOW_ALL,
|
||||
)
|
||||
def link_established(lnk: RNS.Link):
|
||||
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), _loop)
|
||||
_destination.set_link_established_callback(link_established)
|
||||
|
||||
if await _check_finished():
|
||||
return
|
||||
|
@ -199,531 +195,22 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||
last = time.time()
|
||||
|
||||
try:
|
||||
while not await _check_finished(1.0):
|
||||
while not await _check_finished():
|
||||
if announce_period and 0 < announce_period < time.time() - last:
|
||||
last = time.time()
|
||||
_destination.announce()
|
||||
await session.ListenerSession.pump_all()
|
||||
await asyncio.sleep(0.01)
|
||||
finally:
|
||||
log.warning("Shutting down")
|
||||
for link in list(_destination.links):
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
proc = Session.get_for_tag(link.link_id)
|
||||
if proc is not None and proc.process.running:
|
||||
proc.process.terminate()
|
||||
await asyncio.sleep(0)
|
||||
await session.ListenerSession.terminate_all("Shutting down")
|
||||
await asyncio.sleep(1)
|
||||
session.ListenerSession.messenger.shutdown()
|
||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||
for link in links_still_active:
|
||||
if link.status != RNS.Link.CLOSED:
|
||||
if link.status not in [RNS.Link.CLOSED]:
|
||||
link.teardown()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
_PROTOCOL_MAGIC = 0xdeadbeef
|
||||
|
||||
|
||||
def _protocol_make_version(version: int):
|
||||
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
|
||||
|
||||
|
||||
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
||||
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
|
||||
_PROTOCOL_VERSION_2 = _protocol_make_version(2)
|
||||
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
|
||||
|
||||
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
|
||||
|
||||
def _protocol_split_version(version: int):
|
||||
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
||||
|
||||
|
||||
def _protocol_check_magic(value: int):
|
||||
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
||||
|
||||
|
||||
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
|
||||
if version >= _PROTOCOL_VERSION_2:
|
||||
return link_mdu - 64 # TODO: tune
|
||||
else:
|
||||
return link_mdu // 2
|
||||
|
||||
|
||||
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
|
||||
if version >= _PROTOCOL_VERSION_2:
|
||||
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
|
||||
else:
|
||||
return link_mdu // 2
|
||||
|
||||
|
||||
def _bitwise_or_if(value: int, condition: bool, orval: int):
|
||||
if not condition:
|
||||
return value
|
||||
return value | orval
|
||||
|
||||
|
||||
def _check_and(value: int, andval: int) -> bool:
|
||||
return (value & andval) > 0
|
||||
|
||||
|
||||
class Session:
|
||||
_processes: [(any, Session)] = []
|
||||
_lock = threading.RLock()
|
||||
|
||||
@classmethod
|
||||
def get_for_tag(cls, tag: any) -> Session | None:
|
||||
with cls._lock:
|
||||
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
|
||||
|
||||
@classmethod
|
||||
def put_for_tag(cls, tag: any, ps: Session):
|
||||
with cls._lock:
|
||||
cls.clear_tag(tag)
|
||||
cls._processes.append((tag, ps))
|
||||
|
||||
@classmethod
|
||||
def clear_tag(cls, tag: any):
|
||||
with cls._lock:
|
||||
with exception.permit(SystemExit):
|
||||
cls._processes.remove(tag)
|
||||
|
||||
def __init__(self,
|
||||
tag: any,
|
||||
cmd: [str],
|
||||
data_available_callback: callable,
|
||||
terminated_callback: callable,
|
||||
session_flags: int,
|
||||
term: str | None,
|
||||
remote_identity: str | None,
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
|
||||
self._log = _get_logger(self.__class__.__name__)
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
self._process = process.CallbackSubprocess(argv=cmd,
|
||||
env={"TERM": term or os.environ.get("TERM", None),
|
||||
"RNS_REMOTE_IDENTITY": remote_identity or ""},
|
||||
loop=loop,
|
||||
stdout_callback=self._stdout_data,
|
||||
stderr_callback=self._stderr_data,
|
||||
terminated_callback=terminated_callback,
|
||||
stdin_is_pipe=_check_and(session_flags,
|
||||
Session.REQUEST_FLAGS_PIPE_STDIN),
|
||||
stdout_is_pipe=_check_and(session_flags,
|
||||
Session.REQUEST_FLAGS_PIPE_STDOUT),
|
||||
stderr_is_pipe=_check_and(session_flags,
|
||||
Session.REQUEST_FLAGS_PIPE_STDERR))
|
||||
self._log.debug(f"Starting {cmd}")
|
||||
self._stdout_buffer = bytearray()
|
||||
self._stderr_buffer = bytearray()
|
||||
self._lock = threading.RLock()
|
||||
self._data_available_cb = data_available_callback
|
||||
self._terminated_cb = terminated_callback
|
||||
self._pending_receipt: RNS.PacketReceipt | None = None
|
||||
self._process.start()
|
||||
self._term_state: [int] = None
|
||||
self._session_flags: int = session_flags
|
||||
Session.put_for_tag(tag, self)
|
||||
|
||||
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
||||
return self._pending_receipt
|
||||
|
||||
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
|
||||
with self._lock:
|
||||
val = self._pending_receipt
|
||||
self._pending_receipt = None
|
||||
return val
|
||||
|
||||
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
|
||||
with self._lock:
|
||||
self._pending_receipt = receipt
|
||||
|
||||
@property
|
||||
def process(self) -> process.CallbackSubprocess:
|
||||
return self._process
|
||||
|
||||
@property
|
||||
def return_code(self) -> int | None:
|
||||
return self.process.return_code
|
||||
|
||||
@property
|
||||
def lock(self) -> threading.RLock:
|
||||
return self._lock
|
||||
|
||||
def read_stdout(self, count: int) -> bytes:
|
||||
with self.lock:
|
||||
initial_len = len(self._stdout_buffer)
|
||||
take = self._stdout_buffer[:count]
|
||||
self._stdout_buffer = self._stdout_buffer[count:]
|
||||
self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
|
||||
return take
|
||||
|
||||
def _stdout_data(self, data: bytes):
|
||||
with self.lock:
|
||||
self._stdout_buffer.extend(data)
|
||||
total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
|
||||
try:
|
||||
self._data_available_cb(total_available)
|
||||
except Exception as e:
|
||||
self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
|
||||
|
||||
def read_stderr(self, count: int) -> bytes:
|
||||
with self.lock:
|
||||
initial_len = len(self._stderr_buffer)
|
||||
take = self._stderr_buffer[:count]
|
||||
self._stderr_buffer = self._stderr_buffer[count:]
|
||||
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
|
||||
return take
|
||||
|
||||
def _stderr_data(self, data: bytes):
|
||||
with self.lock:
|
||||
self._stderr_buffer.extend(data)
|
||||
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
|
||||
try:
|
||||
self._data_available_cb(total_available)
|
||||
except Exception as e:
|
||||
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
|
||||
|
||||
TERMSTATE_IDX_TERM = 0
|
||||
TERMSTATE_IDX_TIOS = 1
|
||||
TERMSTATE_IDX_ROWS = 2
|
||||
TERMSTATE_IDX_COLS = 3
|
||||
TERMSTATE_IDX_HPIX = 4
|
||||
TERMSTATE_IDX_VPIX = 5
|
||||
|
||||
def _update_winsz(self):
|
||||
try:
|
||||
self.process.set_winsize(self._term_state[1],
|
||||
self._term_state[2],
|
||||
self._term_state[3],
|
||||
self._term_state[4])
|
||||
except Exception as e:
|
||||
self._log.debug(f"failed to update winsz: {e}")
|
||||
|
||||
REQUEST_IDX_VERS = 0
|
||||
REQUEST_IDX_STDIN = 1
|
||||
REQUEST_IDX_TERM = 2
|
||||
REQUEST_IDX_TIOS = 3
|
||||
REQUEST_IDX_ROWS = 4
|
||||
REQUEST_IDX_COLS = 5
|
||||
REQUEST_IDX_HPIX = 6
|
||||
REQUEST_IDX_VPIX = 7
|
||||
REQUEST_IDX_CMD = 8
|
||||
REQUEST_IDX_FLAGS = 9
|
||||
REQUEST_IDX_BYTES_AVAILABLE = 10
|
||||
REQUEST_FLAGS_PIPE_STDIN = 0x01
|
||||
REQUEST_FLAGS_PIPE_STDOUT = 0x02
|
||||
REQUEST_FLAGS_PIPE_STDERR = 0x04
|
||||
REQUEST_FLAGS_EOF_STDIN = 0x08
|
||||
|
||||
@staticmethod
|
||||
def default_request() -> [any]:
|
||||
global _tr
|
||||
request: list[any] = [
|
||||
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
None, # 4 terminal rows
|
||||
None, # 5 terminal cols
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
None, # 8 Command to run
|
||||
0, # 9 Flags
|
||||
0, # 10 Bytes Available
|
||||
].copy()
|
||||
|
||||
if os.isatty(0):
|
||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
|
||||
with contextlib.suppress(OSError):
|
||||
request[Session.REQUEST_IDX_ROWS], \
|
||||
request[Session.REQUEST_IDX_COLS], \
|
||||
request[Session.REQUEST_IDX_HPIX], \
|
||||
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
|
||||
Session.REQUEST_FLAGS_PIPE_STDIN)
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
|
||||
Session.REQUEST_FLAGS_PIPE_STDOUT)
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
|
||||
Session.REQUEST_FLAGS_PIPE_STDERR)
|
||||
return request
|
||||
|
||||
def process_request(self, data: [any], read_size: int) -> [any]:
|
||||
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
|
||||
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
||||
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
||||
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
||||
# cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
|
||||
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
||||
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
||||
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
|
||||
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
|
||||
flags = data[Session.REQUEST_IDX_FLAGS]
|
||||
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
|
||||
response = Session.default_response()
|
||||
|
||||
first_term_state = self._term_state is None
|
||||
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
|
||||
|
||||
response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
|
||||
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
|
||||
if self.process.running:
|
||||
if term_state != self._term_state:
|
||||
self._term_state = term_state
|
||||
if term_state is not None:
|
||||
self._update_winsz()
|
||||
if first_term_state is not None:
|
||||
# TODO: use a more specific error
|
||||
with contextlib.suppress(Exception):
|
||||
self.process.tcsetattr(termios.TCSADRAIN, term_state[0])
|
||||
if stdin is not None and len(stdin) > 0:
|
||||
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||
stdin = base64.b64decode(stdin)
|
||||
self.process.write(stdin)
|
||||
if stdin_eof and bytes_available == 0:
|
||||
module_logger.debug("Closing stdin")
|
||||
with contextlib.suppress(Exception):
|
||||
self.process.close_stdin()
|
||||
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||
|
||||
with self.lock:
|
||||
#prioritizing stderr
|
||||
stderr = self.read_stderr(read_size)
|
||||
stdout = self.read_stdout(read_size - len(stderr))
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
|
||||
|
||||
if stderr is not None and len(stderr) > 0:
|
||||
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
|
||||
if stdout is not None and len(stdout) > 0:
|
||||
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
|
||||
return response
|
||||
|
||||
RESPONSE_IDX_VERSION = 0
|
||||
RESPONSE_IDX_FLAGS = 1
|
||||
RESPONSE_IDX_RETCODE = 2
|
||||
RESPONSE_IDX_RDYBYTE = 3
|
||||
RESPONSE_IDX_STDERR = 4
|
||||
RESPONSE_IDX_STDOUT = 5
|
||||
RESPONSE_IDX_TMSTAMP = 6
|
||||
RESPONSE_FLAGS_RUNNING = 0x01
|
||||
RESPONSE_FLAGS_EOF_STDOUT = 0x02
|
||||
RESPONSE_FLAGS_EOF_STDERR = 0x04
|
||||
|
||||
@staticmethod
|
||||
def default_response() -> [any]:
|
||||
response: list[any] = [
|
||||
_PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stderr
|
||||
None, # 5: Stdout
|
||||
None, # 6: Timestamp
|
||||
].copy()
|
||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def error_response(cls, msg: str) -> [any]:
|
||||
response = cls.default_response()
|
||||
msg_bytes = f"{msg}\r\n".encode("utf-8")
|
||||
response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
|
||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||
return response
|
||||
|
||||
|
||||
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||
global _retry_timer
|
||||
log = _get_logger("_subproc_data_ready")
|
||||
session: Session = Session.get_for_tag(link.link_id)
|
||||
|
||||
def send(timeout: bool, tag: any, tries: int) -> any:
|
||||
# log.debug("send")
|
||||
def inner():
|
||||
# log.debug("inner")
|
||||
try:
|
||||
if link.status != RNS.Link.ACTIVE:
|
||||
_retry_timer.complete(link.link_id)
|
||||
session.pending_receipt_take()
|
||||
return
|
||||
|
||||
pr = session.pending_receipt_take()
|
||||
log.debug(f"send inner pr: {pr}")
|
||||
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
|
||||
if not timeout:
|
||||
_retry_timer.complete(tag)
|
||||
log.debug(f"Notification completed with status {pr.status} on link {link}")
|
||||
return
|
||||
else:
|
||||
if not timeout:
|
||||
log.debug(
|
||||
f"Notifying client try {tries} (retcode: {session.return_code} " +
|
||||
f"chars avail: {chars_available})")
|
||||
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
||||
packet.send()
|
||||
pr = packet.receipt
|
||||
session.pending_receipt_put(pr)
|
||||
else:
|
||||
log.error(f"Retry count exceeded, terminating link {link}")
|
||||
_retry_timer.complete(link.link_id)
|
||||
link.teardown()
|
||||
except Exception as e:
|
||||
log.error("Error notifying client: " + str(e))
|
||||
|
||||
_loop.call_soon_threadsafe(inner)
|
||||
return link.link_id
|
||||
|
||||
with session.lock:
|
||||
if not _retry_timer.has_tag(link.link_id):
|
||||
_retry_timer.begin(try_limit=15,
|
||||
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
|
||||
try_callback=functools.partial(send, False),
|
||||
timeout_callback=functools.partial(send, True))
|
||||
else:
|
||||
log.debug(f"Notification already pending for link {link}")
|
||||
|
||||
|
||||
def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||
global _loop
|
||||
log = _get_logger("_subproc_terminated")
|
||||
log.info(f"Subprocess returned {return_code} for link {link}")
|
||||
proc = Session.get_for_tag(link.link_id)
|
||||
if proc is None:
|
||||
log.debug(f"no proc for link {link}")
|
||||
return
|
||||
|
||||
def cleanup():
|
||||
def inner():
|
||||
log.debug(f"cleanup culled link {link}")
|
||||
if link and link.status != RNS.Link.CLOSED:
|
||||
with exception.permit(SystemExit):
|
||||
try:
|
||||
link.teardown()
|
||||
finally:
|
||||
Session.clear_tag(link.link_id)
|
||||
|
||||
_loop.call_later(300, inner)
|
||||
_loop.call_soon(_subproc_data_ready, link, 0)
|
||||
|
||||
_loop.call_soon_threadsafe(cleanup)
|
||||
|
||||
|
||||
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
|
||||
loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
|
||||
log = _get_logger("_listen_start_proc")
|
||||
try:
|
||||
return Session(tag=link.link_id,
|
||||
cmd=cmd,
|
||||
session_flags=session_flags,
|
||||
term=term,
|
||||
remote_identity=remote_identity,
|
||||
loop=loop,
|
||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||
except Exception as e:
|
||||
log.error("Failed to launch process: " + str(e))
|
||||
_subproc_terminated(link, 255)
|
||||
return None
|
||||
|
||||
|
||||
def _listen_link_established(link):
|
||||
global _allow_all
|
||||
log = _get_logger("_listen_link_established")
|
||||
link.set_remote_identified_callback(_initiator_identified)
|
||||
link.set_link_closed_callback(_listen_link_closed)
|
||||
log.info("Link " + str(link) + " established")
|
||||
|
||||
|
||||
def _listen_link_closed(link: RNS.Link):
|
||||
log = _get_logger("_listen_link_closed")
|
||||
# async def cleanup():
|
||||
log.info("Link " + str(link) + " closed")
|
||||
proc: Session | None = Session.get_for_tag(link.link_id)
|
||||
if proc is None:
|
||||
log.warning(f"No process for link {link}")
|
||||
else:
|
||||
try:
|
||||
proc.process.terminate()
|
||||
_retry_timer.complete(link.link_id)
|
||||
except Exception as e:
|
||||
log.error(f"Error closing process for link {link}: {e}")
|
||||
Session.clear_tag(link.link_id)
|
||||
|
||||
|
||||
def _initiator_identified(link, identity):
|
||||
global _allow_all, _cmd, _loop
|
||||
log = _get_logger("_initiator_identified")
|
||||
log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash))
|
||||
if not _allow_all and identity.hash not in _allowed_identity_hashes:
|
||||
log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING)
|
||||
link.teardown()
|
||||
|
||||
|
||||
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||
global _destination, _retry_timer, _loop, _cmd, _no_remote_command
|
||||
log = _get_logger("_listen_request")
|
||||
log.debug(
|
||||
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
|
||||
if not hasattr(data, "__len__") or len(data) < 1:
|
||||
raise Exception("Request data invalid")
|
||||
_retry_timer.complete(link_id)
|
||||
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
|
||||
if link is None:
|
||||
log.error(f"Invalid request {request_id}, no link found with id {link_id}")
|
||||
return None
|
||||
|
||||
remote_version = data[Session.REQUEST_IDX_VERS]
|
||||
if not _protocol_check_magic(remote_version):
|
||||
log.error("Request magic incorrect")
|
||||
link.teardown()
|
||||
return None
|
||||
|
||||
if not remote_version <= _PROTOCOL_VERSION_3:
|
||||
return Session.error_response("Listener<->initiator version mismatch")
|
||||
|
||||
cmd = _cmd.copy()
|
||||
remote_command = data[Session.REQUEST_IDX_CMD]
|
||||
if remote_command is not None and len(remote_command) > 0:
|
||||
if _no_remote_command:
|
||||
return Session.error_response("Listener does not permit initiator to provide command.")
|
||||
elif _remote_cmd_as_args:
|
||||
cmd.extend(remote_command)
|
||||
else:
|
||||
cmd = remote_command
|
||||
|
||||
if not _no_remote_command and (cmd is None or len(cmd) == 0):
|
||||
return Session.error_response("No command supplied and no default command available.")
|
||||
|
||||
session: Session | None = None
|
||||
try:
|
||||
term = data[Session.REQUEST_IDX_TERM]
|
||||
# sanitize
|
||||
if term is not None:
|
||||
term = re.sub('[^A-Za-z-0-9\-\_]','', term)
|
||||
session = Session.get_for_tag(link.link_id)
|
||||
if session is None:
|
||||
log.debug(f"Process not found for link {link}")
|
||||
session = _listen_start_proc(link=link,
|
||||
term=term,
|
||||
session_flags=data[Session.REQUEST_IDX_FLAGS],
|
||||
cmd=cmd,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
if session is None:
|
||||
return Session.error_response("Unable to start subprocess")
|
||||
|
||||
# leave significant headroom for metadata and encoding
|
||||
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
|
||||
return result
|
||||
# return ProcessState.default_response()
|
||||
except Exception as e:
|
||||
log.error(f"Error procesing request for link {link}: {e}")
|
||||
try:
|
||||
if session is not None and session.process.running:
|
||||
session.process.terminate()
|
||||
except Exception as ee:
|
||||
log.debug(f"Error terminating process for link {link}: {ee}")
|
||||
|
||||
return Session.default_response()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
||||
|
@ -731,7 +218,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
|||
timeout += time.time()
|
||||
|
||||
while (timeout is None or time.time() < timeout) and not until():
|
||||
if await _check_finished(0.001):
|
||||
if await _check_finished(0.01):
|
||||
raise asyncio.CancelledError()
|
||||
if timeout is not None and time.time() > timeout:
|
||||
return False
|
||||
|
@ -744,15 +231,28 @@ _remote_exec_grace = 2.0
|
|||
_new_data: asyncio.Event | None = None
|
||||
_tr: process.TTYRestorer | None = None
|
||||
|
||||
_pq = queue.Queue()
|
||||
|
||||
class InitiatorState(enum.IntEnum):
|
||||
IS_INITIAL = 0
|
||||
IS_LINKED = 1
|
||||
IS_WAIT_VERS = 2
|
||||
IS_RUNNING = 3
|
||||
IS_TERMINATE = 4
|
||||
IS_TEARDOWN = 5
|
||||
|
||||
|
||||
def _client_link_closed(link):
|
||||
log = _get_logger("_client_link_closed")
|
||||
_finished.set()
|
||||
|
||||
|
||||
def _client_packet_handler(message, packet):
|
||||
global _new_data
|
||||
log = _get_logger("_client_packet_handler")
|
||||
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG and _new_data is not None:
|
||||
log.debug("data available")
|
||||
_new_data.set()
|
||||
else:
|
||||
log.error(f"received unhandled packet")
|
||||
packet.prove()
|
||||
_pq.put(message)
|
||||
|
||||
|
||||
|
||||
class RemoteExecutionError(Exception):
|
||||
|
@ -760,15 +260,10 @@ class RemoteExecutionError(Exception):
|
|||
self.msg = msg
|
||||
|
||||
|
||||
def _response_handler(request_receipt: RNS.RequestReceipt):
|
||||
pass
|
||||
|
||||
|
||||
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
|
||||
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
|
||||
cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
|
||||
async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
|
||||
service_name="default", timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
|
||||
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
|
||||
log = _get_logger("_execute")
|
||||
log = _get_logger("_initiate_link")
|
||||
|
||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
|
||||
if len(destination) != dest_len:
|
||||
|
@ -809,6 +304,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
_link = RNS.Link(_destination)
|
||||
_link.did_identify = False
|
||||
|
||||
_link.set_link_closed_callback(_client_link_closed)
|
||||
|
||||
log.info(f"Establishing link...")
|
||||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||
|
@ -820,110 +317,6 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||
|
||||
_link.set_packet_callback(_client_packet_handler)
|
||||
|
||||
request = Session.default_request()
|
||||
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
|
||||
# log.debug(f"Sending {stdin} to listener")
|
||||
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
|
||||
request[Session.REQUEST_IDX_CMD] = cmd
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
|
||||
Session.REQUEST_FLAGS_EOF_STDIN)
|
||||
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
|
||||
|
||||
# TODO: Tune
|
||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||
|
||||
log.debug("Sending request")
|
||||
request_receipt = _link.request(
|
||||
path="data",
|
||||
data=request,
|
||||
timeout=timeout
|
||||
)
|
||||
timeout += 0.5
|
||||
|
||||
log.debug("Waiting for delivery")
|
||||
await _spin(
|
||||
until=lambda: _link.status == RNS.Link.CLOSED or (
|
||||
request_receipt.status != RNS.RequestReceipt.FAILED and
|
||||
request_receipt.status != RNS.RequestReceipt.SENT),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if _link.status == RNS.Link.CLOSED:
|
||||
raise RemoteExecutionError("Could not request remote execution, link was closed")
|
||||
|
||||
if request_receipt.status == RNS.RequestReceipt.FAILED:
|
||||
raise RemoteExecutionError("Could not request remote execution")
|
||||
|
||||
await _spin(
|
||||
until=lambda: request_receipt.status != RNS.RequestReceipt.DELIVERED,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if request_receipt.status == RNS.RequestReceipt.FAILED:
|
||||
raise RemoteExecutionError("No result was received")
|
||||
|
||||
if request_receipt.status == RNS.RequestReceipt.FAILED:
|
||||
raise RemoteExecutionError("Receiving result failed")
|
||||
|
||||
if request_receipt.response is not None:
|
||||
try:
|
||||
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
||||
if not _protocol_check_magic(version):
|
||||
raise RemoteExecutionError("Protocol error")
|
||||
elif version != _PROTOCOL_VERSION_3:
|
||||
raise RemoteExecutionError("Protocol version mismatch")
|
||||
|
||||
flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
|
||||
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
|
||||
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
|
||||
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
|
||||
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
||||
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
||||
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
||||
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
|
||||
# if stdout is not None:
|
||||
# stdout = base64.b64decode(stdout)
|
||||
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
||||
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
||||
except RemoteExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RemoteExecutionError(f"Received invalid response") from e
|
||||
|
||||
if stdout is not None:
|
||||
_tr.raw()
|
||||
log.debug(f"stdout: {stdout}")
|
||||
os.write(1, stdout)
|
||||
sys.stdout.flush()
|
||||
|
||||
if stderr is not None:
|
||||
_tr.raw()
|
||||
log.debug(f"stderr: {stderr}")
|
||||
os.write(2, stderr)
|
||||
sys.stderr.flush()
|
||||
|
||||
if stderr_eof and ready_bytes == 0:
|
||||
log.debug("Closing stderr")
|
||||
os.close(2)
|
||||
|
||||
if stdout_eof and ready_bytes == 0:
|
||||
log.debug("Closing stdout")
|
||||
os.close(1)
|
||||
|
||||
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
|
||||
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
|
||||
|
||||
if ready_bytes > 0:
|
||||
_new_data.set()
|
||||
|
||||
if (not running or return_code is not None) and (ready_bytes == 0):
|
||||
log.debug(f"returning running: {running}, return_code: {return_code}")
|
||||
return return_code or 255
|
||||
|
||||
return None
|
||||
|
||||
_pre_input = bytearray()
|
||||
|
||||
|
||||
async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
|
||||
service_name: str, timeout: float, command: [str] | None = None):
|
||||
|
@ -931,13 +324,52 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||
log = _get_logger("_initiate")
|
||||
loop = asyncio.get_running_loop()
|
||||
_new_data = asyncio.Event()
|
||||
|
||||
state = InitiatorState.IS_INITIAL
|
||||
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
|
||||
|
||||
await _initiate_link(
|
||||
configdir=configdir,
|
||||
identitypath=identitypath,
|
||||
verbosity=verbosity,
|
||||
quietness=quietness,
|
||||
noid=noid,
|
||||
destination=destination,
|
||||
service_name=service_name,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if not _link or _link.status != RNS.Link.ACTIVE:
|
||||
_finished.set()
|
||||
return 255
|
||||
|
||||
state = InitiatorState.IS_LINKED
|
||||
outlet = session.RNSOutlet(_link)
|
||||
with protocol.Messenger(retry_delay_min=5) as messenger:
|
||||
|
||||
# Next step after linking and identifying: send version
|
||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
||||
# print("Error bringing up link")
|
||||
# return 253
|
||||
|
||||
messenger.send(outlet, protocol.VersionInfoMessage())
|
||||
try:
|
||||
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
||||
vm = messenger.receive(vp)
|
||||
if not isinstance(vm, protocol.VersionInfoMessage):
|
||||
raise Exception("Invalid message received")
|
||||
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
|
||||
state = InitiatorState.IS_RUNNING
|
||||
except queue.Empty:
|
||||
print("Protocol error")
|
||||
return 254
|
||||
|
||||
winch = False
|
||||
def sigwinch_handler():
|
||||
nonlocal winch
|
||||
# log.debug("WindowChanged")
|
||||
if _new_data is not None:
|
||||
_new_data.set()
|
||||
winch = True
|
||||
# if _new_data is not None:
|
||||
# _new_data.set()
|
||||
|
||||
stdin_eof = False
|
||||
def stdin():
|
||||
|
@ -957,61 +389,98 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||
|
||||
await _check_finished()
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
|
||||
# leave a lot of overhead
|
||||
mdu = 64
|
||||
rtt = 5
|
||||
first_loop = True
|
||||
cmdline = " ".join(command or [])
|
||||
while not await _check_finished():
|
||||
tcattr = None
|
||||
rows, cols, hpix, vpix = (None, None, None, None)
|
||||
try:
|
||||
log.debug("top of client loop")
|
||||
stdin = data_buffer[:mdu]
|
||||
data_buffer = data_buffer[mdu:]
|
||||
_new_data.clear()
|
||||
log.debug("before _execute")
|
||||
return_code = await _execute(
|
||||
configdir=configdir,
|
||||
identitypath=identitypath,
|
||||
verbosity=verbosity,
|
||||
quietness=quietness,
|
||||
noid=noid,
|
||||
destination=destination,
|
||||
service_name=service_name,
|
||||
stdin=stdin,
|
||||
timeout=timeout,
|
||||
cmd=command,
|
||||
stdin_eof=stdin_eof,
|
||||
bytes_available=len(data_buffer)
|
||||
)
|
||||
tcattr = termios.tcgetattr(0)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(0)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(1)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(1)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(2)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(2)
|
||||
except:
|
||||
pass
|
||||
|
||||
if first_loop:
|
||||
first_loop = False
|
||||
mdu = _protocol_request_chars_take(_link.MDU,
|
||||
_PROTOCOL_VERSION_DEFAULT,
|
||||
os.environ.get("TERM", ""),
|
||||
cmdline)
|
||||
_new_data.set()
|
||||
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
|
||||
pipe_stdin=not os.isatty(0),
|
||||
pipe_stdout=not os.isatty(1),
|
||||
pipe_stderr=not os.isatty(2),
|
||||
tcflags=tcattr,
|
||||
term=os.environ.get("TERM", None),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
hpix=hpix,
|
||||
vpix=vpix))
|
||||
|
||||
if _link:
|
||||
rtt = _link.rtt
|
||||
|
||||
if return_code is not None:
|
||||
log.debug(f"received return code {return_code}, exiting")
|
||||
|
||||
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
mdu = _link.MDU - 16
|
||||
sent_eof = False
|
||||
|
||||
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
||||
try:
|
||||
try:
|
||||
packet = _pq.get_nowait()
|
||||
message = messenger.receive(packet)
|
||||
|
||||
if isinstance(message, protocol.StreamDataMessage):
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
|
||||
if message.data and len(message.data) > 0:
|
||||
_tr.raw()
|
||||
log.debug(f"stdout: {message.data}")
|
||||
os.write(1, message.data)
|
||||
sys.stdout.flush()
|
||||
if message.eof:
|
||||
os.close(1)
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
|
||||
if message.data and len(message.data) > 0:
|
||||
_tr.raw()
|
||||
log.debug(f"stdout: {message.data}")
|
||||
os.write(2, message.data)
|
||||
sys.stderr.flush()
|
||||
if message.eof:
|
||||
os.close(2)
|
||||
elif isinstance(message, protocol.CommandExitedMessage):
|
||||
log.debug(f"received return code {message.return_code}, exiting")
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
_link.teardown()
|
||||
|
||||
return return_code
|
||||
except asyncio.CancelledError:
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
return message.return_code
|
||||
elif isinstance(message, protocol.ErrorMessage):
|
||||
log.error(message.data)
|
||||
if message.fatal:
|
||||
_link.teardown()
|
||||
return 0
|
||||
return 200
|
||||
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
if messenger.is_outlet_ready(outlet):
|
||||
stdin = data_buffer[:mdu]
|
||||
data_buffer = data_buffer[mdu:]
|
||||
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
||||
if len(stdin) > 0 or eof:
|
||||
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||
stdin, eof))
|
||||
sent_eof = eof
|
||||
except RemoteExecutionError as e:
|
||||
print(e.msg)
|
||||
return 255
|
||||
except Exception as ex:
|
||||
print(f"Client exception: {ex}")
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
_link.teardown()
|
||||
return 127
|
||||
|
||||
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||
await asyncio.sleep(0.01)
|
||||
log.debug("after main loop")
|
||||
return 0
|
||||
|
||||
|
||||
|
|
423
rnsh/session.py
Normal file
423
rnsh/session.py
Normal file
|
@ -0,0 +1,423 @@
|
|||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import threading
|
||||
import rnsh.exception as exception
|
||||
import asyncio
|
||||
import rnsh.process as process
|
||||
import rnsh.helpers as helpers
|
||||
import rnsh.protocol as protocol
|
||||
import enum
|
||||
from typing import TypeVar, Generic, Callable, List
|
||||
from abc import abstractmethod, ABC
|
||||
from multiprocessing import Manager
|
||||
import os
|
||||
import RNS
|
||||
|
||||
import logging as __logging
|
||||
|
||||
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
_TLink = TypeVar("_TLink")
|
||||
|
||||
class SEType(enum.IntEnum):
|
||||
SE_LINK_CLOSED = 0
|
||||
|
||||
|
||||
class SessionException(Exception):
|
||||
def __init__(self, setype: SEType, msg: str, *args):
|
||||
super().__init__(msg, args)
|
||||
self.type = setype
|
||||
|
||||
|
||||
class LSState(enum.IntEnum):
|
||||
LSSTATE_WAIT_IDENT = 1
|
||||
LSSTATE_WAIT_VERS = 2
|
||||
LSSTATE_WAIT_CMD = 3
|
||||
LSSTATE_RUNNING = 4
|
||||
LSSTATE_ERROR = 5
|
||||
LSSTATE_TEARDOWN = 6
|
||||
|
||||
|
||||
_TIdentity = TypeVar("_TIdentity")
|
||||
|
||||
|
||||
class LSOutletBase(protocol.MessageOutletBase):
|
||||
@abstractmethod
|
||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def unset_link_closed_callback(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def teardown(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class ListenerSession:
|
||||
sessions: List[ListenerSession] = []
|
||||
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
|
||||
allowed_identity_hashes: [any] = []
|
||||
allow_all: bool = False
|
||||
allow_remote_command: bool = False
|
||||
default_command: [str] = []
|
||||
remote_cmd_as_args = False
|
||||
|
||||
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._log.info(f"Session started for {outlet}")
|
||||
self.outlet = outlet
|
||||
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||
self.outlet.set_link_closed_callback(self._link_closed)
|
||||
self.loop = loop
|
||||
self.state: LSState = None
|
||||
self.remote_identity = None
|
||||
self.term: str | None = None
|
||||
self.stdin_is_pipe: bool = False
|
||||
self.stdout_is_pipe: bool = False
|
||||
self.stderr_is_pipe: bool = False
|
||||
self.tcflags: [any] = None
|
||||
self.cmdline: [str] = None
|
||||
self.rows: int = 0
|
||||
self.cols: int = 0
|
||||
self.hpix: int = 0
|
||||
self.vpix: int = 0
|
||||
self.stdout_buf = bytearray()
|
||||
self.stdout_eof_sent = False
|
||||
self.stderr_buf = bytearray()
|
||||
self.stderr_eof_sent = False
|
||||
self.return_code: int | None = None
|
||||
self.return_code_sent = False
|
||||
self.process: process.CallbackSubprocess | None = None
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
self.sessions.append(self)
|
||||
|
||||
def _terminated(self, return_code: int):
|
||||
self.return_code = return_code
|
||||
|
||||
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
|
||||
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
|
||||
self._log.debug(f"Set state: {state.name}, timeout {timeout}")
|
||||
orig_state = self.state
|
||||
self.state = state
|
||||
if timeout_factor is not None:
|
||||
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
|
||||
|
||||
def _call(self, func: callable, delay: float = 0):
|
||||
def call_inner():
|
||||
# self._log.debug("call_inner")
|
||||
if delay == 0:
|
||||
func()
|
||||
else:
|
||||
self.loop.call_later(delay, func)
|
||||
self.loop.call_soon_threadsafe(call_inner)
|
||||
|
||||
def send(self, message: protocol.Message):
|
||||
self.messenger.send(self.outlet, message)
|
||||
|
||||
def _protocol_error(self, name: str):
|
||||
self.terminate(f"Protocol error ({name})")
|
||||
|
||||
def _protocol_timeout_error(self, name: str):
|
||||
self.terminate(f"Protocol timeout error: {name}")
|
||||
|
||||
def terminate(self, error: str = None):
|
||||
with contextlib.suppress(Exception):
|
||||
self._log.debug("Terminating session" + (f": {error}" if error else ""))
|
||||
if error and self.state != LSState.LSSTATE_TEARDOWN:
|
||||
with contextlib.suppress(Exception):
|
||||
self.send(protocol.ErrorMessage(error, True))
|
||||
self.state = LSState.LSSTATE_ERROR
|
||||
self._terminate_process()
|
||||
self._call(self._prune, max(self.outlet.rtt * 3, 5))
|
||||
|
||||
def _prune(self):
|
||||
self.state = LSState.LSSTATE_TEARDOWN
|
||||
with contextlib.suppress(ValueError):
|
||||
self.sessions.remove(self)
|
||||
with contextlib.suppress(Exception):
|
||||
self.outlet.teardown()
|
||||
|
||||
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
|
||||
timeout = True
|
||||
try:
|
||||
timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
|
||||
except Exception as ex:
|
||||
self._log.exception("Error in protocol timeout", ex)
|
||||
if timeout:
|
||||
self._protocol_timeout_error(name)
|
||||
|
||||
def _link_closed(self, outlet: LSOutletBase):
|
||||
outlet.unset_link_closed_callback()
|
||||
|
||||
if outlet != self.outlet:
|
||||
self._log.debug("Link closed received from incorrect outlet")
|
||||
return
|
||||
|
||||
self._log.debug(f"link_closed {outlet}")
|
||||
self.messenger.clear_retries(self.outlet)
|
||||
self.terminate()
|
||||
|
||||
def _initiator_identified(self, outlet, identity):
|
||||
if outlet != self.outlet:
|
||||
self._log.debug("Identity received from incorrect outlet")
|
||||
return
|
||||
|
||||
self._log.info(f"initiator_identified {identity} on link {outlet}")
|
||||
if self.state != LSState.LSSTATE_WAIT_IDENT:
|
||||
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
|
||||
|
||||
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
|
||||
self.terminate("Identity is not allowed.")
|
||||
|
||||
self.remote_identity = identity
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
|
||||
@classmethod
|
||||
async def pump_all(cls):
|
||||
for session in cls.sessions:
|
||||
session.pump()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def terminate_all(cls, reason: str):
|
||||
for session in cls.sessions:
|
||||
session.terminate(reason)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def pump(self):
|
||||
|
||||
try:
|
||||
if self.state != LSState.LSSTATE_RUNNING:
|
||||
return
|
||||
elif not self.messenger.is_outlet_ready(self.outlet):
|
||||
return
|
||||
elif len(self.stderr_buf) > 0:
|
||||
mdu = self.outlet.mdu - 16
|
||||
data = self.stderr_buf[:mdu]
|
||||
self.stderr_buf = self.stderr_buf[mdu:]
|
||||
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
|
||||
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
|
||||
data, send_eof)
|
||||
self.send(msg)
|
||||
if send_eof:
|
||||
self.stderr_eof_sent = True
|
||||
elif len(self.stdout_buf) > 0:
|
||||
mdu = self.outlet.mdu - 16
|
||||
data = self.stdout_buf[:mdu]
|
||||
self.stdout_buf = self.stdout_buf[mdu:]
|
||||
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
|
||||
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
|
||||
data, send_eof)
|
||||
self.send(msg)
|
||||
elif self.return_code is not None and not self.return_code_sent:
|
||||
msg = protocol.CommandExitedMessage(self.return_code)
|
||||
self.send(msg)
|
||||
self.return_code_sent = True
|
||||
self._call(functools.partial(self._check_protocol_timeout,
|
||||
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
|
||||
max(self.outlet.rtt * 5, 10))
|
||||
except Exception as ex:
|
||||
self._log.exception("Error during pump", ex)
|
||||
|
||||
def _terminate_process(self):
|
||||
with contextlib.suppress(Exception):
|
||||
if self.process and self.process.running:
|
||||
self.process.terminate()
|
||||
|
||||
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
|
||||
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
|
||||
|
||||
self.cmdline = self.default_command
|
||||
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
|
||||
self.terminate("Remote command line not allowed by listener")
|
||||
return
|
||||
|
||||
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
|
||||
self.cmdline.extend(cmdline)
|
||||
elif cmdline and len(cmdline) > 0:
|
||||
self.cmdline = cmdline
|
||||
|
||||
|
||||
self.stdin_is_pipe = pipe_stdin
|
||||
self.stdout_is_pipe = pipe_stdout
|
||||
self.stderr_is_pipe = pipe_stderr
|
||||
self.tcflags = tcflags
|
||||
self.term = term
|
||||
|
||||
def stdout(data: bytes):
|
||||
self.stdout_buf.extend(data)
|
||||
|
||||
def stderr(data: bytes):
|
||||
self.stderr_buf.extend(data)
|
||||
|
||||
try:
|
||||
self.process = process.CallbackSubprocess(argv=self.cmdline,
|
||||
env={"TERM": self.term or os.environ.get("TERM", None),
|
||||
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
|
||||
loop=self.loop,
|
||||
stdout_callback=stdout,
|
||||
stderr_callback=stderr,
|
||||
terminated_callback=self._terminated,
|
||||
stdin_is_pipe=self.stdin_is_pipe,
|
||||
stdout_is_pipe=self.stdout_is_pipe,
|
||||
stderr_is_pipe=self.stderr_is_pipe)
|
||||
self.process.start()
|
||||
self._set_window_size(rows, cols, hpix, vpix)
|
||||
except Exception as ex:
|
||||
self._log.exception(f"Unable to start process for link {self.outlet}", ex)
|
||||
self.terminate("Unable to start process")
|
||||
|
||||
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
|
||||
self.rows = rows
|
||||
self.cols = cols
|
||||
self.hpix = hpix
|
||||
self.vpix = vpix
|
||||
with contextlib.suppress(Exception):
|
||||
self.process.set_winsize(rows, cols, hpix, vpix)
|
||||
|
||||
def _received_stdin(self, data: bytes, eof: bool):
|
||||
if data and len(data) > 0:
|
||||
self.process.write(data)
|
||||
if eof:
|
||||
self.process.close_stdin()
|
||||
|
||||
def _handle_message(self, message: protocol.Message):
|
||||
if self.state == LSState.LSSTATE_WAIT_VERS:
|
||||
if not isinstance(message, protocol.VersionInfoMessage):
|
||||
self._protocol_error(self.state.name)
|
||||
return
|
||||
self._log.info(f"version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}")
|
||||
if message.protocol_version != protocol.PROTOCOL_VERSION:
|
||||
self.terminate("Incompatible protocol")
|
||||
return
|
||||
self.send(protocol.VersionInfoMessage())
|
||||
self._set_state(LSState.LSSTATE_WAIT_CMD)
|
||||
return
|
||||
elif self.state == LSState.LSSTATE_WAIT_CMD:
|
||||
if not isinstance(message, protocol.ExecuteCommandMesssage):
|
||||
return self._protocol_error(self.state.name)
|
||||
self._log.info(f"Execute command message on link {self.outlet}: {message.cmdline}")
|
||||
self._set_state(LSState.LSSTATE_RUNNING)
|
||||
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
|
||||
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
|
||||
return
|
||||
elif self.state == LSState.LSSTATE_RUNNING:
|
||||
if isinstance(message, protocol.WindowSizeMessage):
|
||||
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
|
||||
elif isinstance(message, protocol.StreamDataMessage):
|
||||
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
|
||||
self._log.error(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}")
|
||||
return self._protocol_error(self.state.name)
|
||||
self._received_stdin(message.data, message.eof)
|
||||
return
|
||||
elif isinstance(message, protocol.NoopMessage):
|
||||
# echo noop only on listener--used for keepalive/connectivity check
|
||||
self.send(message)
|
||||
return
|
||||
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
|
||||
self._log.error(f"Received packet, but in state {self.state.name}")
|
||||
return
|
||||
else:
|
||||
self._protocol_error("unexpected message")
|
||||
return
|
||||
|
||||
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
|
||||
if outlet != self.outlet:
|
||||
self._log.debug("Packet received from incorrect outlet")
|
||||
return
|
||||
|
||||
try:
|
||||
message = self.messenger.receive(raw)
|
||||
self._handle_message(message)
|
||||
except Exception as ex:
|
||||
self._protocol_error("unusable packet")
|
||||
|
||||
|
||||
class RNSOutlet(LSOutletBase):
|
||||
|
||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||
def inner_cb(link, identity: _TIdentity):
|
||||
cb(self, identity)
|
||||
|
||||
self.link.set_remote_identified_callback(inner_cb)
|
||||
|
||||
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||
def inner_cb(link):
|
||||
cb(self)
|
||||
|
||||
self.link.set_link_closed_callback(inner_cb)
|
||||
|
||||
def unset_link_closed_callback(self):
|
||||
self.link.set_link_closed_callback(None)
|
||||
|
||||
def teardown(self):
|
||||
self.link.teardown()
|
||||
|
||||
def send(self, raw: bytes) -> RNS.PacketReceipt:
|
||||
packet = RNS.Packet(self.link, raw)
|
||||
packet.send()
|
||||
return packet.receipt
|
||||
|
||||
@property
|
||||
def mdu(self) -> int:
|
||||
return self.link.MDU
|
||||
|
||||
@property
|
||||
def rtt(self) -> float:
|
||||
return self.link.rtt
|
||||
|
||||
@property
|
||||
def is_usuable(self):
|
||||
return True #self.link.status in [RNS.Link.ACTIVE]
|
||||
|
||||
def get_receipt_state(self, receipt: RNS.PacketReceipt) -> MessageState:
|
||||
status = receipt.get_status()
|
||||
if status == RNS.PacketReceipt.SENT:
|
||||
return protocol.MessageState.MSGSTATE_SENT
|
||||
if status == RNS.PacketReceipt.DELIVERED:
|
||||
return protocol.MessageState.MSGSTATE_DELIVERED
|
||||
if status == RNS.PacketReceipt.FAILED:
|
||||
return protocol.MessageState.MSGSTATE_FAILED
|
||||
else:
|
||||
raise Exception(f"Unexpected receipt state: {status}")
|
||||
|
||||
def timed_out(self):
|
||||
self.link.teardown()
|
||||
|
||||
def __str__(self):
|
||||
return f"Outlet RNS Link {self.link}"
|
||||
|
||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||
def inner_cb(message, packet: RNS.Packet):
|
||||
packet.prove()
|
||||
cb(self, message)
|
||||
|
||||
self.link.set_packet_callback(inner_cb)
|
||||
|
||||
def __init__(self, link: RNS.Link):
|
||||
self.link = link
|
||||
link.lsoutlet = self
|
||||
link.msgoutlet = self
|
||||
@staticmethod
|
||||
def get_outlet(link: RNS.Link):
|
||||
if hasattr(link, "lsoutlet"):
|
||||
return link.lsoutlet
|
||||
|
||||
return RNSOutlet(link)
|
|
@ -1,6 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from rnsh.protocol import _TReceipt, MessageState
|
||||
from typing import Callable
|
||||
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
import rnsh.protocol
|
||||
|
@ -14,64 +18,60 @@ import uuid
|
|||
module_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Link:
|
||||
class Receipt:
|
||||
def __init__(self, state: rnsh.protocol.MessageState, raw: bytes):
|
||||
self.state = state
|
||||
self.raw = raw
|
||||
|
||||
|
||||
class MessageOutletTest(rnsh.protocol.MessageOutletBase):
|
||||
def __init__(self, mdu: int, rtt: float):
|
||||
self.link_id = uuid.uuid4()
|
||||
self.timeout_callbacks = 0
|
||||
self.mdu = mdu
|
||||
self.rtt = rtt
|
||||
self.usable = True
|
||||
self._mdu = mdu
|
||||
self._rtt = rtt
|
||||
self._usable = True
|
||||
self.receipts = []
|
||||
self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None
|
||||
|
||||
def timeout_callback(self):
|
||||
def send(self, raw: bytes) -> Receipt:
|
||||
receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
|
||||
self.receipts.append(receipt)
|
||||
return receipt
|
||||
|
||||
def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]):
|
||||
self.packet_callback = cb
|
||||
|
||||
def receive(self, raw: bytes):
|
||||
if self.packet_callback:
|
||||
self.packet_callback(self, raw)
|
||||
|
||||
@property
|
||||
def mdu(self):
|
||||
return self._mdu
|
||||
|
||||
@property
|
||||
def rtt(self):
|
||||
return self._rtt
|
||||
|
||||
@property
|
||||
def is_usuable(self):
|
||||
return self._usable
|
||||
|
||||
def get_receipt_state(self, receipt: Receipt) -> MessageState:
|
||||
return receipt.state
|
||||
|
||||
def timed_out(self):
|
||||
self.timeout_callbacks += 1
|
||||
|
||||
def __str__(self):
|
||||
return str(self.link_id)
|
||||
|
||||
|
||||
class Receipt:
|
||||
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
|
||||
self.state = state
|
||||
self.raw = raw
|
||||
self.link = link
|
||||
|
||||
|
||||
class ProtocolHarness(contextlib.AbstractContextManager):
|
||||
def __init__(self, retry_delay_min: float = 1):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker,
|
||||
link_timeout_callback=self.link_timeout_callback,
|
||||
link_mdu_getter=self.link_mdu_getter,
|
||||
link_rtt_getter=self.link_rtt_getter,
|
||||
link_usable_getter=self.link_usable_getter,
|
||||
packet_sender=self.packet_sender,
|
||||
retry_delay_min=retry_delay_min)
|
||||
|
||||
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
|
||||
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
|
||||
link.receipts.append(receipt)
|
||||
return receipt
|
||||
|
||||
@staticmethod
|
||||
def link_mdu_getter(link: Link):
|
||||
return link.mdu
|
||||
|
||||
@staticmethod
|
||||
def link_rtt_getter(link: Link):
|
||||
return link.rtt
|
||||
|
||||
@staticmethod
|
||||
def link_usable_getter(link: Link):
|
||||
return link.usable
|
||||
|
||||
@staticmethod
|
||||
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
|
||||
return receipt.state
|
||||
|
||||
@staticmethod
|
||||
def link_timeout_callback(link: Link):
|
||||
link.timeout_callback()
|
||||
self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min)
|
||||
|
||||
def cleanup(self):
|
||||
self.messenger.shutdown()
|
||||
|
@ -83,49 +83,58 @@ class ProtocolHarness(contextlib.AbstractContextManager):
|
|||
return False
|
||||
|
||||
|
||||
def test_mdu():
|
||||
with ProtocolHarness() as h:
|
||||
mdu = 500
|
||||
link = Link(mdu=mdu, rtt=0.25)
|
||||
assert h.messenger.get_mdu(link) == mdu - 4
|
||||
link.mdu = mdu = 600
|
||||
assert h.messenger.get_mdu(link) == mdu - 4
|
||||
|
||||
|
||||
def test_rtt():
|
||||
with ProtocolHarness() as h:
|
||||
rtt = 0.25
|
||||
link = Link(mdu=500, rtt=rtt)
|
||||
assert h.messenger.get_rtt(link) == rtt
|
||||
|
||||
|
||||
def test_send_one_retry():
|
||||
rtt = 0.001
|
||||
retry_interval = rtt * 150
|
||||
message_content = b'Test'
|
||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||
link = Link(mdu=500, rtt=rtt)
|
||||
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||
data=message_content, eof=True)
|
||||
assert len(link.receipts) == 0
|
||||
h.messenger.send_message(link, message)
|
||||
assert len(outlet.receipts) == 0
|
||||
h.messenger.send(outlet, message)
|
||||
assert message.tracked
|
||||
assert message.raw is not None
|
||||
assert len(link.receipts) == 1
|
||||
receipt = link.receipts[0]
|
||||
assert len(outlet.receipts) == 1
|
||||
receipt = outlet.receipts[0]
|
||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||
assert receipt.raw == message.raw
|
||||
time.sleep(retry_interval * 1.5)
|
||||
assert len(link.receipts) == 1
|
||||
assert len(outlet.receipts) == 1
|
||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
|
||||
module_logger.info("set failed")
|
||||
time.sleep(retry_interval)
|
||||
assert len(link.receipts) == 2
|
||||
receipt = link.receipts[1]
|
||||
assert len(outlet.receipts) == 2
|
||||
receipt = outlet.receipts[1]
|
||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||
time.sleep(retry_interval)
|
||||
assert len(link.receipts) == 2
|
||||
assert len(outlet.receipts) == 2
|
||||
assert not message.tracked
|
||||
|
||||
|
||||
def test_send_timeout():
|
||||
rtt = 0.001
|
||||
retry_interval = rtt * 150
|
||||
message_content = b'Test'
|
||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||
data=message_content, eof=True)
|
||||
assert len(outlet.receipts) == 0
|
||||
h.messenger.send(outlet, message)
|
||||
assert message.tracked
|
||||
assert message.raw is not None
|
||||
assert len(outlet.receipts) == 1
|
||||
receipt = outlet.receipts[0]
|
||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||
assert receipt.raw == message.raw
|
||||
time.sleep(retry_interval * 1.5)
|
||||
assert outlet.timeout_callbacks == 0
|
||||
time.sleep(retry_interval * 7)
|
||||
assert len(outlet.receipts) == 1
|
||||
assert outlet.timeout_callbacks == 1
|
||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||
assert not message.tracked
|
||||
|
||||
|
||||
|
@ -133,24 +142,32 @@ def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[r
|
|||
rtt = 0.001
|
||||
retry_interval = rtt * 150
|
||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||
link = Link(mdu=500, rtt=rtt)
|
||||
assert len(link.receipts) == 0
|
||||
h.messenger.send_message(link, message)
|
||||
|
||||
decoded: [rnsh.protocol.Message] = []
|
||||
def packet(outlet, buffer):
|
||||
decoded.append(h.messenger.receive(buffer))
|
||||
|
||||
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||
outlet.set_packet_received_callback(packet)
|
||||
assert len(outlet.receipts) == 0
|
||||
h.messenger.send(outlet, message)
|
||||
assert message.tracked
|
||||
assert message.raw is not None
|
||||
assert len(link.receipts) == 1
|
||||
receipt = link.receipts[0]
|
||||
assert len(outlet.receipts) == 1
|
||||
receipt = outlet.receipts[0]
|
||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||
assert receipt.raw == message.raw
|
||||
module_logger.info("set delivered")
|
||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||
time.sleep(retry_interval * 2)
|
||||
assert len(link.receipts) == 1
|
||||
assert len(outlet.receipts) == 1
|
||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||
assert not message.tracked
|
||||
module_logger.info("injecting rx message")
|
||||
h.messenger.inbound(message.raw)
|
||||
rx_message = h.messenger.poll_inbound(block=False)
|
||||
assert len(decoded) == 0
|
||||
outlet.receive(message.raw)
|
||||
assert len(decoded) == 1
|
||||
rx_message = decoded[0]
|
||||
assert rx_message is not None
|
||||
assert isinstance(rx_message, message.__class__)
|
||||
assert rx_message.msgid != message.msgid
|
||||
|
@ -238,6 +255,16 @@ def test_send_receive_error():
|
|||
eat_own_dog_food(message, check)
|
||||
|
||||
|
||||
def test_send_receive_cmdexit():
|
||||
message = rnsh.protocol.CommandExitedMessage(5)
|
||||
|
||||
def check(rx_message: rnsh.protocol.Message):
|
||||
assert isinstance(rx_message, message.__class__)
|
||||
assert rx_message.return_code == message.return_code
|
||||
|
||||
eat_own_dog_food(message, check)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -12,19 +12,6 @@ import re
|
|||
import os
|
||||
|
||||
|
||||
|
||||
def test_check_magic():
|
||||
magic = rnsh.rnsh._PROTOCOL_VERSION_0
|
||||
# magic for version 0 is generated, make sure it comes out as expected
|
||||
assert magic == 0xdeadbeef00000000
|
||||
# verify the checker thinks it's right
|
||||
assert rnsh.rnsh._protocol_check_magic(magic)
|
||||
# scramble the magic
|
||||
magic = magic | 0x00ffff0000000000
|
||||
# make sure it fails now
|
||||
assert not rnsh.rnsh._protocol_check_magic(magic)
|
||||
|
||||
|
||||
def test_version():
|
||||
# version = importlib.metadata.version(rnsh.__version__)
|
||||
assert rnsh.__version__ != "0.0.0"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue