mirror of
https://github.com/markqvist/rnsh.git
synced 2025-02-25 09:21:15 -05:00
Got the new protocol working.
This commit is contained in:
parent
0ee305795f
commit
8edb4020b1
8
rnsh/helpers.py
Normal file
8
rnsh/helpers.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
def bitwise_or_if(value: int, condition: bool, orval: int):
|
||||||
|
if not condition:
|
||||||
|
return value
|
||||||
|
return value | orval
|
||||||
|
|
||||||
|
|
||||||
|
def check_and(value: int, andval: int) -> bool:
|
||||||
|
return (value & andval) > 0
|
@ -382,8 +382,11 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
|
|||||||
# Make PTY controlling if necessary
|
# Make PTY controlling if necessary
|
||||||
if child_fd is not None:
|
if child_fd is not None:
|
||||||
os.setsid()
|
os.setsid()
|
||||||
|
try:
|
||||||
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
|
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
|
||||||
os.close(tmp_fd)
|
os.close(tmp_fd)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
|
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
|
||||||
|
|
||||||
# Execute the command
|
# Execute the command
|
||||||
@ -445,7 +448,8 @@ class CallbackSubprocess:
|
|||||||
self._child_stdout: int = None
|
self._child_stdout: int = None
|
||||||
self._child_stderr: int = None
|
self._child_stderr: int = None
|
||||||
self._return_code: int = None
|
self._return_code: int = None
|
||||||
self._eof: bool = False
|
self._stdout_eof: bool = False
|
||||||
|
self._stderr_eof: bool = False
|
||||||
self._stdin_is_pipe = stdin_is_pipe
|
self._stdin_is_pipe = stdin_is_pipe
|
||||||
self._stdout_is_pipe = stdout_is_pipe
|
self._stdout_is_pipe = stdout_is_pipe
|
||||||
self._stderr_is_pipe = stderr_is_pipe
|
self._stderr_is_pipe = stderr_is_pipe
|
||||||
@ -455,6 +459,21 @@ class CallbackSubprocess:
|
|||||||
Terminate child process if running
|
Terminate child process if running
|
||||||
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
same = self._child_stdout == self._child_stderr
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
if not self._stdout_eof:
|
||||||
|
tty_unset_reader_callbacks(self._child_stdout)
|
||||||
|
os.close(self._child_stdout)
|
||||||
|
self._child_stdout = None
|
||||||
|
|
||||||
|
if not same:
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
if not self._stderr_eof:
|
||||||
|
tty_unset_reader_callbacks(self._child_stderr)
|
||||||
|
os.close(self._child_stderr)
|
||||||
|
self._child_stdout = None
|
||||||
|
|
||||||
self._log.debug("terminate()")
|
self._log.debug("terminate()")
|
||||||
if not self.running:
|
if not self.running:
|
||||||
return
|
return
|
||||||
@ -477,7 +496,7 @@ class CallbackSubprocess:
|
|||||||
os.waitpid(self._pid, 0)
|
os.waitpid(self._pid, 0)
|
||||||
self._log.debug("wait() finish")
|
self._log.debug("wait() finish")
|
||||||
|
|
||||||
threading.Thread(target=wait).start()
|
threading.Thread(target=wait, daemon=True).start()
|
||||||
|
|
||||||
def close_stdin(self):
|
def close_stdin(self):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
@ -592,26 +611,44 @@ class CallbackSubprocess:
|
|||||||
|
|
||||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||||
|
|
||||||
def reader(fd: int, callback: callable):
|
def stdout():
|
||||||
try:
|
try:
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit):
|
||||||
data = tty_read(fd)
|
data = tty_read_poll(self._child_stdout)
|
||||||
if data is not None and len(data) > 0:
|
if data is not None and len(data) > 0:
|
||||||
callback(data)
|
self._stdout_cb(data)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
self._eof = True
|
self._stdout_eof = True
|
||||||
tty_unset_reader_callbacks(self._child_stdout)
|
tty_unset_reader_callbacks(self._child_stdout)
|
||||||
callback(bytearray())
|
self._stdout_cb(bytearray())
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
os.close(self._child_stdout)
|
||||||
|
|
||||||
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb),
|
def stderr():
|
||||||
self._loop)
|
try:
|
||||||
|
with exception.permit(SystemExit):
|
||||||
|
data = tty_read_poll(self._child_stderr)
|
||||||
|
if data is not None and len(data) > 0:
|
||||||
|
self._stderr_cb(data)
|
||||||
|
except EOFError:
|
||||||
|
self._stderr_eof = True
|
||||||
|
tty_unset_reader_callbacks(self._child_stderr)
|
||||||
|
self._stdout_cb(bytearray())
|
||||||
|
with contextlib.suppress(OSError):
|
||||||
|
os.close(self._child_stderr)
|
||||||
|
|
||||||
|
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
|
||||||
if self._child_stderr != self._child_stdout:
|
if self._child_stderr != self._child_stdout:
|
||||||
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb),
|
tty_add_reader_callback(self._child_stderr, stderr, self._loop)
|
||||||
self._loop)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eof(self):
|
def stdout_eof(self):
|
||||||
return self._eof or not self.running
|
return self._stdout_eof or not self.running
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stderr_eof(self):
|
||||||
|
return self._stderr_eof or not self.running
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def return_code(self) -> int:
|
def return_code(self) -> int:
|
||||||
|
156
rnsh/protocol.py
156
rnsh/protocol.py
@ -15,6 +15,7 @@ import abc
|
|||||||
import contextlib
|
import contextlib
|
||||||
import struct
|
import struct
|
||||||
import logging as __logging
|
import logging as __logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -29,6 +30,43 @@ def _make_MSGTYPE(val: int):
|
|||||||
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageOutletBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def send(self, raw: bytes) -> _TReceipt:
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def mdu(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def rtt(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def is_usuable(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def timed_out(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __str__(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
|
||||||
class METype(enum.IntEnum):
|
class METype(enum.IntEnum):
|
||||||
ME_NO_MSG_TYPE = 0
|
ME_NO_MSG_TYPE = 0
|
||||||
ME_INVALID_MSG_TYPE = 1
|
ME_INVALID_MSG_TYPE = 1
|
||||||
@ -58,17 +96,17 @@ class Message(abc.ABC):
|
|||||||
self.msgid = uuid.uuid4()
|
self.msgid = uuid.uuid4()
|
||||||
self.raw: bytes | None = None
|
self.raw: bytes | None = None
|
||||||
self.receipt: _TReceipt = None
|
self.receipt: _TReceipt = None
|
||||||
self.link: _TLink = None
|
self.outlet: _TLink = None
|
||||||
self.tracked: bool = False
|
self.tracked: bool = False
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"{self.__class__.__name__} {self.msgid}"
|
return f"{self.__class__.__name__} {self.msgid}"
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abstractmethod
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abstractmethod
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
@ -124,7 +162,8 @@ class ExecuteCommandMesssage(Message):
|
|||||||
MSGTYPE = _make_MSGTYPE(3)
|
MSGTYPE = _make_MSGTYPE(3)
|
||||||
|
|
||||||
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
||||||
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None):
|
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
|
||||||
|
cols: int = None, hpix: int = None, vpix: int = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cmdline = cmdline
|
self.cmdline = cmdline
|
||||||
self.pipe_stdin = pipe_stdin
|
self.pipe_stdin = pipe_stdin
|
||||||
@ -132,16 +171,20 @@ class ExecuteCommandMesssage(Message):
|
|||||||
self.pipe_stderr = pipe_stderr
|
self.pipe_stderr = pipe_stderr
|
||||||
self.tcflags = tcflags
|
self.tcflags = tcflags
|
||||||
self.term = term
|
self.term = term
|
||||||
|
self.rows = rows
|
||||||
|
self.cols = cols
|
||||||
|
self.hpix = hpix
|
||||||
|
self.vpix = vpix
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||||
self.tcflags, self.term))
|
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||||
return self.wrap_MSGTYPE(raw)
|
return self.wrap_MSGTYPE(raw)
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
raw = self.unwrap_MSGTYPE(raw)
|
||||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \
|
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
||||||
= umsgpack.unpackb(raw)
|
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
class StreamDataMessage(Message):
|
class StreamDataMessage(Message):
|
||||||
MSGTYPE = _make_MSGTYPE(4)
|
MSGTYPE = _make_MSGTYPE(4)
|
||||||
@ -156,7 +199,7 @@ class StreamDataMessage(Message):
|
|||||||
self.eof = eof
|
self.eof = eof
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.stream_id, self.eof, self.data))
|
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
||||||
return self.wrap_MSGTYPE(raw)
|
return self.wrap_MSGTYPE(raw)
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
@ -169,7 +212,7 @@ class VersionInfoMessage(Message):
|
|||||||
|
|
||||||
def __init__(self, sw_version: str = None):
|
def __init__(self, sw_version: str = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.sw_version = sw_version
|
self.sw_version = sw_version or rnsh.__version__
|
||||||
self.protocol_version = PROTOCOL_VERSION
|
self.protocol_version = PROTOCOL_VERSION
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
@ -199,6 +242,22 @@ class ErrorMessage(Message):
|
|||||||
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandExitedMessage(Message):
|
||||||
|
MSGTYPE = _make_MSGTYPE(7)
|
||||||
|
|
||||||
|
def __init__(self, return_code: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.return_code = return_code
|
||||||
|
|
||||||
|
def pack(self) -> bytes:
|
||||||
|
raw = umsgpack.packb(self.return_code)
|
||||||
|
return self.wrap_MSGTYPE(raw)
|
||||||
|
|
||||||
|
def unpack(self, raw: bytes):
|
||||||
|
raw = self.unwrap_MSGTYPE(raw)
|
||||||
|
self.return_code = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class Messenger(contextlib.AbstractContextManager):
|
class Messenger(contextlib.AbstractContextManager):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -208,29 +267,16 @@ class Messenger(contextlib.AbstractContextManager):
|
|||||||
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
||||||
return subclass_tuples
|
return subclass_tuples
|
||||||
|
|
||||||
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState],
|
def __init__(self, retry_delay_min: float = 10.0):
|
||||||
link_timeout_callback: Callable[[_TLink], None],
|
|
||||||
link_mdu_getter: Callable[[_TLink], int],
|
|
||||||
link_rtt_getter: Callable[[_TLink], float],
|
|
||||||
link_usable_getter: Callable[[_TLink], bool],
|
|
||||||
packet_sender: Callable[[_TLink, bytes], _TReceipt],
|
|
||||||
retry_delay_min: float = 10.0):
|
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self._receipt_checker = receipt_checker
|
|
||||||
self._link_timeout_callback = link_timeout_callback
|
|
||||||
self._link_mdu_getter = link_mdu_getter
|
|
||||||
self._link_rtt_getter = link_rtt_getter
|
|
||||||
self._link_usable_getter = link_usable_getter
|
|
||||||
self._packet_sender = packet_sender
|
|
||||||
self._sent_messages: list[Message] = []
|
self._sent_messages: list[Message] = []
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._retry_timer = rnsh.retry.RetryThread()
|
self._retry_timer = rnsh.retry.RetryThread()
|
||||||
self._message_factories = dict(self.__class__._get_msg_constructors())
|
self._message_factories = dict(self.__class__._get_msg_constructors())
|
||||||
self._inbound_queue = queue.Queue()
|
|
||||||
self._retry_delay_min = retry_delay_min
|
self._retry_delay_min = retry_delay_min
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> Messenger:
|
||||||
pass
|
return self
|
||||||
|
|
||||||
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
||||||
__traceback: TracebackType | None) -> bool | None:
|
__traceback: TracebackType | None) -> bool | None:
|
||||||
@ -238,39 +284,38 @@ class Messenger(contextlib.AbstractContextManager):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self._run = False
|
|
||||||
self._retry_timer.close()
|
self._retry_timer.close()
|
||||||
|
|
||||||
def inbound(self, raw: bytes):
|
def clear_retries(self, outlet):
|
||||||
|
self._retry_timer.complete(outlet)
|
||||||
|
|
||||||
|
def receive(self, raw: bytes) -> Message:
|
||||||
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
||||||
ctor = self._message_factories.get(mid, None)
|
ctor = self._message_factories.get(mid, None)
|
||||||
if ctor is None:
|
if ctor is None:
|
||||||
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
||||||
message = ctor()
|
message = ctor()
|
||||||
message.unpack(raw)
|
message.unpack(raw)
|
||||||
self._log.debug("Message received: {message}")
|
self._log.debug(f"Message received: {message}")
|
||||||
self._inbound_queue.put(message)
|
return message
|
||||||
|
|
||||||
def get_mdu(self, link: _TLink) -> int:
|
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
|
||||||
return self._link_mdu_getter(link) - 4
|
if not outlet.is_usuable:
|
||||||
|
self._log.debug("is_outlet_ready outlet unusable")
|
||||||
def get_rtt(self, link: _TLink) -> float:
|
|
||||||
return self._link_rtt_getter(link)
|
|
||||||
|
|
||||||
def is_link_ready(self, link: _TLink) -> bool:
|
|
||||||
if not self._link_usable_getter(link):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for message in self._sent_messages:
|
for message in self._sent_messages:
|
||||||
if message.link == link:
|
if message.outlet == outlet and message.tracked and message.receipt \
|
||||||
|
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
|
||||||
|
self._log.debug("is_outlet_ready pending message found")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def send_message(self, link: _TLink, message: Message):
|
def send(self, outlet: MessageOutletBase, message: Message):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if not self.is_link_ready(link):
|
if not self.is_outlet_ready(outlet):
|
||||||
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready")
|
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
|
||||||
|
|
||||||
if message in self._sent_messages:
|
if message in self._sent_messages:
|
||||||
raise MessagingException(METype.ME_ALREADY_SENT)
|
raise MessagingException(METype.ME_ALREADY_SENT)
|
||||||
@ -279,43 +324,36 @@ class Messenger(contextlib.AbstractContextManager):
|
|||||||
|
|
||||||
if not message.raw:
|
if not message.raw:
|
||||||
message.raw = message.pack()
|
message.raw = message.pack()
|
||||||
message.link = link
|
message.outlet = outlet
|
||||||
|
|
||||||
def send(tag: any, tries: int):
|
def send_inner(tag: any, tries: int):
|
||||||
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt)
|
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
|
||||||
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
||||||
try:
|
try:
|
||||||
self._log.debug(f"Sending packet for {message}")
|
self._log.debug(f"Sending packet for {message}")
|
||||||
message.receipt = self._packet_sender(link, message.raw)
|
message.receipt = outlet.send(message.raw)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
self._log.exception(f"Error sending message {message}")
|
self._log.exception(f"Error sending message {message}")
|
||||||
elif state in [MessageState.MSGSTATE_SENT]:
|
elif state in [MessageState.MSGSTATE_SENT]:
|
||||||
self._log.debug(f"Retry skipped, message still pending {message}")
|
self._log.debug(f"Retry skipped, message still pending {message}")
|
||||||
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
||||||
latency = round(time.time() - message.ts, 1)
|
latency = round(time.time() - message.ts, 1)
|
||||||
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._sent_messages.remove(message)
|
self._sent_messages.remove(message)
|
||||||
message.tracked = False
|
message.tracked = False
|
||||||
self._retry_timer.complete(link)
|
self._retry_timer.complete(outlet)
|
||||||
return link
|
return outlet
|
||||||
|
|
||||||
def timeout(tag: any, tries: int):
|
def timeout(tag: any, tries: int):
|
||||||
latency = round(time.time() - message.ts, 1)
|
latency = round(time.time() - message.ts, 1)
|
||||||
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
||||||
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._sent_messages.remove(message)
|
self._sent_messages.remove(message)
|
||||||
message.tracked = False
|
message.tracked = False
|
||||||
self._link_timeout_callback(link)
|
outlet.timed_out()
|
||||||
|
|
||||||
rtt = self._link_rtt_getter(link)
|
|
||||||
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
|
|
||||||
|
|
||||||
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
|
|
||||||
try:
|
|
||||||
return self._inbound_queue.get(block=block, timeout=timeout)
|
|
||||||
except queue.Empty:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
rtt = outlet.rtt
|
||||||
|
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ class RetryThread(AbstractContextManager):
|
|||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._run = True
|
self._run = True
|
||||||
self._finished: asyncio.Future = None
|
self._finished: asyncio.Future = None
|
||||||
self._thread = threading.Thread(name=name, target=self._thread_run)
|
self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def is_alive(self):
|
def is_alive(self):
|
||||||
|
865
rnsh/rnsh.py
865
rnsh/rnsh.py
@ -26,10 +26,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import logging as __logging
|
import logging as __logging
|
||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
import shlex
|
import shlex
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@ -44,10 +46,12 @@ import rnsh.process as process
|
|||||||
import rnsh.retry as retry
|
import rnsh.retry as retry
|
||||||
import rnsh.rnslogging as rnslogging
|
import rnsh.rnslogging as rnslogging
|
||||||
import rnsh.hacks as hacks
|
import rnsh.hacks as hacks
|
||||||
|
import rnsh.session as session
|
||||||
import re
|
import re
|
||||||
import contextlib
|
import contextlib
|
||||||
import rnsh.args
|
import rnsh.args
|
||||||
import pwd
|
import pwd
|
||||||
|
import rnsh.protocol as protocol
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -141,13 +145,18 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
log.info(f"Using command {shlex.join(_cmd)}")
|
log.info(f"Using command {shlex.join(_cmd)}")
|
||||||
|
|
||||||
_no_remote_command = no_remote_command
|
_no_remote_command = no_remote_command
|
||||||
|
session.ListenerSession.allow_remote_command = not no_remote_command
|
||||||
_remote_cmd_as_args = remote_cmd_as_args
|
_remote_cmd_as_args = remote_cmd_as_args
|
||||||
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
|
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
|
||||||
and (_no_remote_command or _remote_cmd_as_args):
|
and (_no_remote_command or _remote_cmd_as_args):
|
||||||
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
|
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
|
||||||
|
|
||||||
|
session.ListenerSession.default_command = _cmd
|
||||||
|
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
|
||||||
|
|
||||||
if disable_auth:
|
if disable_auth:
|
||||||
_allow_all = True
|
_allow_all = True
|
||||||
|
session.ListenerSession.allow_all = True
|
||||||
else:
|
else:
|
||||||
if allowed is not None:
|
if allowed is not None:
|
||||||
for a in allowed:
|
for a in allowed:
|
||||||
@ -161,6 +170,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
try:
|
try:
|
||||||
destination_hash = bytes.fromhex(a)
|
destination_hash = bytes.fromhex(a)
|
||||||
_allowed_identity_hashes.append(destination_hash)
|
_allowed_identity_hashes.append(destination_hash)
|
||||||
|
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Invalid destination entered. Check your input.")
|
raise ValueError("Invalid destination entered. Check your input.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -170,23 +180,9 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
if len(_allowed_identity_hashes) < 1 and not disable_auth:
|
if len(_allowed_identity_hashes) < 1 and not disable_auth:
|
||||||
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
||||||
|
|
||||||
_destination.set_link_established_callback(_listen_link_established)
|
def link_established(lnk: RNS.Link):
|
||||||
|
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), _loop)
|
||||||
if not _allow_all:
|
_destination.set_link_established_callback(link_established)
|
||||||
_destination.register_request_handler(
|
|
||||||
path="data",
|
|
||||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
|
||||||
# response_generator=_listen_request,
|
|
||||||
allow=RNS.Destination.ALLOW_LIST,
|
|
||||||
allowed_list=_allowed_identity_hashes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_destination.register_request_handler(
|
|
||||||
path="data",
|
|
||||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
|
||||||
# response_generator=_listen_request,
|
|
||||||
allow=RNS.Destination.ALLOW_ALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
if await _check_finished():
|
if await _check_finished():
|
||||||
return
|
return
|
||||||
@ -199,531 +195,22 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
last = time.time()
|
last = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not await _check_finished(1.0):
|
while not await _check_finished():
|
||||||
if announce_period and 0 < announce_period < time.time() - last:
|
if announce_period and 0 < announce_period < time.time() - last:
|
||||||
last = time.time()
|
last = time.time()
|
||||||
_destination.announce()
|
_destination.announce()
|
||||||
|
await session.ListenerSession.pump_all()
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
finally:
|
finally:
|
||||||
log.warning("Shutting down")
|
log.warning("Shutting down")
|
||||||
for link in list(_destination.links):
|
await session.ListenerSession.terminate_all("Shutting down")
|
||||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
await asyncio.sleep(1)
|
||||||
proc = Session.get_for_tag(link.link_id)
|
session.ListenerSession.messenger.shutdown()
|
||||||
if proc is not None and proc.process.running:
|
|
||||||
proc.process.terminate()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||||
for link in links_still_active:
|
for link in links_still_active:
|
||||||
if link.status != RNS.Link.CLOSED:
|
if link.status not in [RNS.Link.CLOSED]:
|
||||||
link.teardown()
|
link.teardown()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
_PROTOCOL_MAGIC = 0xdeadbeef
|
|
||||||
|
|
||||||
|
|
||||||
def _protocol_make_version(version: int):
|
|
||||||
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
|
|
||||||
|
|
||||||
|
|
||||||
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
|
||||||
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
|
|
||||||
_PROTOCOL_VERSION_2 = _protocol_make_version(2)
|
|
||||||
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
|
|
||||||
|
|
||||||
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
|
|
||||||
|
|
||||||
def _protocol_split_version(version: int):
|
|
||||||
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
|
||||||
|
|
||||||
|
|
||||||
def _protocol_check_magic(value: int):
|
|
||||||
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
|
||||||
|
|
||||||
|
|
||||||
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
|
|
||||||
if version >= _PROTOCOL_VERSION_2:
|
|
||||||
return link_mdu - 64 # TODO: tune
|
|
||||||
else:
|
|
||||||
return link_mdu // 2
|
|
||||||
|
|
||||||
|
|
||||||
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
|
|
||||||
if version >= _PROTOCOL_VERSION_2:
|
|
||||||
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
|
|
||||||
else:
|
|
||||||
return link_mdu // 2
|
|
||||||
|
|
||||||
|
|
||||||
def _bitwise_or_if(value: int, condition: bool, orval: int):
|
|
||||||
if not condition:
|
|
||||||
return value
|
|
||||||
return value | orval
|
|
||||||
|
|
||||||
|
|
||||||
def _check_and(value: int, andval: int) -> bool:
|
|
||||||
return (value & andval) > 0
|
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
|
||||||
_processes: [(any, Session)] = []
|
|
||||||
_lock = threading.RLock()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_for_tag(cls, tag: any) -> Session | None:
|
|
||||||
with cls._lock:
|
|
||||||
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def put_for_tag(cls, tag: any, ps: Session):
|
|
||||||
with cls._lock:
|
|
||||||
cls.clear_tag(tag)
|
|
||||||
cls._processes.append((tag, ps))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def clear_tag(cls, tag: any):
|
|
||||||
with cls._lock:
|
|
||||||
with exception.permit(SystemExit):
|
|
||||||
cls._processes.remove(tag)
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
tag: any,
|
|
||||||
cmd: [str],
|
|
||||||
data_available_callback: callable,
|
|
||||||
terminated_callback: callable,
|
|
||||||
session_flags: int,
|
|
||||||
term: str | None,
|
|
||||||
remote_identity: str | None,
|
|
||||||
loop: asyncio.AbstractEventLoop = None):
|
|
||||||
|
|
||||||
self._log = _get_logger(self.__class__.__name__)
|
|
||||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
|
||||||
self._process = process.CallbackSubprocess(argv=cmd,
|
|
||||||
env={"TERM": term or os.environ.get("TERM", None),
|
|
||||||
"RNS_REMOTE_IDENTITY": remote_identity or ""},
|
|
||||||
loop=loop,
|
|
||||||
stdout_callback=self._stdout_data,
|
|
||||||
stderr_callback=self._stderr_data,
|
|
||||||
terminated_callback=terminated_callback,
|
|
||||||
stdin_is_pipe=_check_and(session_flags,
|
|
||||||
Session.REQUEST_FLAGS_PIPE_STDIN),
|
|
||||||
stdout_is_pipe=_check_and(session_flags,
|
|
||||||
Session.REQUEST_FLAGS_PIPE_STDOUT),
|
|
||||||
stderr_is_pipe=_check_and(session_flags,
|
|
||||||
Session.REQUEST_FLAGS_PIPE_STDERR))
|
|
||||||
self._log.debug(f"Starting {cmd}")
|
|
||||||
self._stdout_buffer = bytearray()
|
|
||||||
self._stderr_buffer = bytearray()
|
|
||||||
self._lock = threading.RLock()
|
|
||||||
self._data_available_cb = data_available_callback
|
|
||||||
self._terminated_cb = terminated_callback
|
|
||||||
self._pending_receipt: RNS.PacketReceipt | None = None
|
|
||||||
self._process.start()
|
|
||||||
self._term_state: [int] = None
|
|
||||||
self._session_flags: int = session_flags
|
|
||||||
Session.put_for_tag(tag, self)
|
|
||||||
|
|
||||||
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
|
||||||
return self._pending_receipt
|
|
||||||
|
|
||||||
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
|
|
||||||
with self._lock:
|
|
||||||
val = self._pending_receipt
|
|
||||||
self._pending_receipt = None
|
|
||||||
return val
|
|
||||||
|
|
||||||
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
|
|
||||||
with self._lock:
|
|
||||||
self._pending_receipt = receipt
|
|
||||||
|
|
||||||
@property
|
|
||||||
def process(self) -> process.CallbackSubprocess:
|
|
||||||
return self._process
|
|
||||||
|
|
||||||
@property
|
|
||||||
def return_code(self) -> int | None:
|
|
||||||
return self.process.return_code
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lock(self) -> threading.RLock:
|
|
||||||
return self._lock
|
|
||||||
|
|
||||||
def read_stdout(self, count: int) -> bytes:
|
|
||||||
with self.lock:
|
|
||||||
initial_len = len(self._stdout_buffer)
|
|
||||||
take = self._stdout_buffer[:count]
|
|
||||||
self._stdout_buffer = self._stdout_buffer[count:]
|
|
||||||
self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
|
|
||||||
return take
|
|
||||||
|
|
||||||
def _stdout_data(self, data: bytes):
|
|
||||||
with self.lock:
|
|
||||||
self._stdout_buffer.extend(data)
|
|
||||||
total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
|
|
||||||
try:
|
|
||||||
self._data_available_cb(total_available)
|
|
||||||
except Exception as e:
|
|
||||||
self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
|
|
||||||
|
|
||||||
def read_stderr(self, count: int) -> bytes:
|
|
||||||
with self.lock:
|
|
||||||
initial_len = len(self._stderr_buffer)
|
|
||||||
take = self._stderr_buffer[:count]
|
|
||||||
self._stderr_buffer = self._stderr_buffer[count:]
|
|
||||||
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
|
|
||||||
return take
|
|
||||||
|
|
||||||
def _stderr_data(self, data: bytes):
|
|
||||||
with self.lock:
|
|
||||||
self._stderr_buffer.extend(data)
|
|
||||||
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
|
|
||||||
try:
|
|
||||||
self._data_available_cb(total_available)
|
|
||||||
except Exception as e:
|
|
||||||
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
|
|
||||||
|
|
||||||
TERMSTATE_IDX_TERM = 0
|
|
||||||
TERMSTATE_IDX_TIOS = 1
|
|
||||||
TERMSTATE_IDX_ROWS = 2
|
|
||||||
TERMSTATE_IDX_COLS = 3
|
|
||||||
TERMSTATE_IDX_HPIX = 4
|
|
||||||
TERMSTATE_IDX_VPIX = 5
|
|
||||||
|
|
||||||
def _update_winsz(self):
|
|
||||||
try:
|
|
||||||
self.process.set_winsize(self._term_state[1],
|
|
||||||
self._term_state[2],
|
|
||||||
self._term_state[3],
|
|
||||||
self._term_state[4])
|
|
||||||
except Exception as e:
|
|
||||||
self._log.debug(f"failed to update winsz: {e}")
|
|
||||||
|
|
||||||
REQUEST_IDX_VERS = 0
|
|
||||||
REQUEST_IDX_STDIN = 1
|
|
||||||
REQUEST_IDX_TERM = 2
|
|
||||||
REQUEST_IDX_TIOS = 3
|
|
||||||
REQUEST_IDX_ROWS = 4
|
|
||||||
REQUEST_IDX_COLS = 5
|
|
||||||
REQUEST_IDX_HPIX = 6
|
|
||||||
REQUEST_IDX_VPIX = 7
|
|
||||||
REQUEST_IDX_CMD = 8
|
|
||||||
REQUEST_IDX_FLAGS = 9
|
|
||||||
REQUEST_IDX_BYTES_AVAILABLE = 10
|
|
||||||
REQUEST_FLAGS_PIPE_STDIN = 0x01
|
|
||||||
REQUEST_FLAGS_PIPE_STDOUT = 0x02
|
|
||||||
REQUEST_FLAGS_PIPE_STDERR = 0x04
|
|
||||||
REQUEST_FLAGS_EOF_STDIN = 0x08
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def default_request() -> [any]:
|
|
||||||
global _tr
|
|
||||||
request: list[any] = [
|
|
||||||
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
|
|
||||||
None, # 1 Stdin
|
|
||||||
None, # 2 TERM variable
|
|
||||||
None, # 3 termios attributes or something
|
|
||||||
None, # 4 terminal rows
|
|
||||||
None, # 5 terminal cols
|
|
||||||
None, # 6 terminal horizontal pixels
|
|
||||||
None, # 7 terminal vertical pixels
|
|
||||||
None, # 8 Command to run
|
|
||||||
0, # 9 Flags
|
|
||||||
0, # 10 Bytes Available
|
|
||||||
].copy()
|
|
||||||
|
|
||||||
if os.isatty(0):
|
|
||||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
|
||||||
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
request[Session.REQUEST_IDX_ROWS], \
|
|
||||||
request[Session.REQUEST_IDX_COLS], \
|
|
||||||
request[Session.REQUEST_IDX_HPIX], \
|
|
||||||
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
|
|
||||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
|
|
||||||
Session.REQUEST_FLAGS_PIPE_STDIN)
|
|
||||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
|
|
||||||
Session.REQUEST_FLAGS_PIPE_STDOUT)
|
|
||||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
|
|
||||||
Session.REQUEST_FLAGS_PIPE_STDERR)
|
|
||||||
return request
|
|
||||||
|
|
||||||
def process_request(self, data: [any], read_size: int) -> [any]:
|
|
||||||
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
|
|
||||||
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
|
||||||
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
|
||||||
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
|
||||||
# cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
|
|
||||||
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
|
||||||
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
|
||||||
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
|
|
||||||
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
|
|
||||||
flags = data[Session.REQUEST_IDX_FLAGS]
|
|
||||||
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
|
|
||||||
response = Session.default_response()
|
|
||||||
|
|
||||||
first_term_state = self._term_state is None
|
|
||||||
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
|
|
||||||
|
|
||||||
response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
|
|
||||||
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
|
|
||||||
if self.process.running:
|
|
||||||
if term_state != self._term_state:
|
|
||||||
self._term_state = term_state
|
|
||||||
if term_state is not None:
|
|
||||||
self._update_winsz()
|
|
||||||
if first_term_state is not None:
|
|
||||||
# TODO: use a more specific error
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
self.process.tcsetattr(termios.TCSADRAIN, term_state[0])
|
|
||||||
if stdin is not None and len(stdin) > 0:
|
|
||||||
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
|
||||||
stdin = base64.b64decode(stdin)
|
|
||||||
self.process.write(stdin)
|
|
||||||
if stdin_eof and bytes_available == 0:
|
|
||||||
module_logger.debug("Closing stdin")
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
self.process.close_stdin()
|
|
||||||
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
#prioritizing stderr
|
|
||||||
stderr = self.read_stderr(read_size)
|
|
||||||
stdout = self.read_stdout(read_size - len(stderr))
|
|
||||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
|
|
||||||
|
|
||||||
if stderr is not None and len(stderr) > 0:
|
|
||||||
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
|
|
||||||
if stdout is not None and len(stdout) > 0:
|
|
||||||
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
|
|
||||||
return response
|
|
||||||
|
|
||||||
RESPONSE_IDX_VERSION = 0
|
|
||||||
RESPONSE_IDX_FLAGS = 1
|
|
||||||
RESPONSE_IDX_RETCODE = 2
|
|
||||||
RESPONSE_IDX_RDYBYTE = 3
|
|
||||||
RESPONSE_IDX_STDERR = 4
|
|
||||||
RESPONSE_IDX_STDOUT = 5
|
|
||||||
RESPONSE_IDX_TMSTAMP = 6
|
|
||||||
RESPONSE_FLAGS_RUNNING = 0x01
|
|
||||||
RESPONSE_FLAGS_EOF_STDOUT = 0x02
|
|
||||||
RESPONSE_FLAGS_EOF_STDERR = 0x04
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def default_response() -> [any]:
|
|
||||||
response: list[any] = [
|
|
||||||
_PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
|
|
||||||
False, # 1: Process running
|
|
||||||
None, # 2: Return value
|
|
||||||
0, # 3: Number of outstanding bytes
|
|
||||||
None, # 4: Stderr
|
|
||||||
None, # 5: Stdout
|
|
||||||
None, # 6: Timestamp
|
|
||||||
].copy()
|
|
||||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
|
||||||
return response
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def error_response(cls, msg: str) -> [any]:
|
|
||||||
response = cls.default_response()
|
|
||||||
msg_bytes = f"{msg}\r\n".encode("utf-8")
|
|
||||||
response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
|
|
||||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
|
||||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
|
||||||
global _retry_timer
|
|
||||||
log = _get_logger("_subproc_data_ready")
|
|
||||||
session: Session = Session.get_for_tag(link.link_id)
|
|
||||||
|
|
||||||
def send(timeout: bool, tag: any, tries: int) -> any:
|
|
||||||
# log.debug("send")
|
|
||||||
def inner():
|
|
||||||
# log.debug("inner")
|
|
||||||
try:
|
|
||||||
if link.status != RNS.Link.ACTIVE:
|
|
||||||
_retry_timer.complete(link.link_id)
|
|
||||||
session.pending_receipt_take()
|
|
||||||
return
|
|
||||||
|
|
||||||
pr = session.pending_receipt_take()
|
|
||||||
log.debug(f"send inner pr: {pr}")
|
|
||||||
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
|
|
||||||
if not timeout:
|
|
||||||
_retry_timer.complete(tag)
|
|
||||||
log.debug(f"Notification completed with status {pr.status} on link {link}")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
if not timeout:
|
|
||||||
log.debug(
|
|
||||||
f"Notifying client try {tries} (retcode: {session.return_code} " +
|
|
||||||
f"chars avail: {chars_available})")
|
|
||||||
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
|
||||||
packet.send()
|
|
||||||
pr = packet.receipt
|
|
||||||
session.pending_receipt_put(pr)
|
|
||||||
else:
|
|
||||||
log.error(f"Retry count exceeded, terminating link {link}")
|
|
||||||
_retry_timer.complete(link.link_id)
|
|
||||||
link.teardown()
|
|
||||||
except Exception as e:
|
|
||||||
log.error("Error notifying client: " + str(e))
|
|
||||||
|
|
||||||
_loop.call_soon_threadsafe(inner)
|
|
||||||
return link.link_id
|
|
||||||
|
|
||||||
with session.lock:
|
|
||||||
if not _retry_timer.has_tag(link.link_id):
|
|
||||||
_retry_timer.begin(try_limit=15,
|
|
||||||
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
|
|
||||||
try_callback=functools.partial(send, False),
|
|
||||||
timeout_callback=functools.partial(send, True))
|
|
||||||
else:
|
|
||||||
log.debug(f"Notification already pending for link {link}")
|
|
||||||
|
|
||||||
|
|
||||||
def _subproc_terminated(link: RNS.Link, return_code: int):
|
|
||||||
global _loop
|
|
||||||
log = _get_logger("_subproc_terminated")
|
|
||||||
log.info(f"Subprocess returned {return_code} for link {link}")
|
|
||||||
proc = Session.get_for_tag(link.link_id)
|
|
||||||
if proc is None:
|
|
||||||
log.debug(f"no proc for link {link}")
|
|
||||||
return
|
|
||||||
|
|
||||||
def cleanup():
|
|
||||||
def inner():
|
|
||||||
log.debug(f"cleanup culled link {link}")
|
|
||||||
if link and link.status != RNS.Link.CLOSED:
|
|
||||||
with exception.permit(SystemExit):
|
|
||||||
try:
|
|
||||||
link.teardown()
|
|
||||||
finally:
|
|
||||||
Session.clear_tag(link.link_id)
|
|
||||||
|
|
||||||
_loop.call_later(300, inner)
|
|
||||||
_loop.call_soon(_subproc_data_ready, link, 0)
|
|
||||||
|
|
||||||
_loop.call_soon_threadsafe(cleanup)
|
|
||||||
|
|
||||||
|
|
||||||
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
|
|
||||||
loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
|
|
||||||
log = _get_logger("_listen_start_proc")
|
|
||||||
try:
|
|
||||||
return Session(tag=link.link_id,
|
|
||||||
cmd=cmd,
|
|
||||||
session_flags=session_flags,
|
|
||||||
term=term,
|
|
||||||
remote_identity=remote_identity,
|
|
||||||
loop=loop,
|
|
||||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
|
||||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
|
||||||
except Exception as e:
|
|
||||||
log.error("Failed to launch process: " + str(e))
|
|
||||||
_subproc_terminated(link, 255)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _listen_link_established(link):
|
|
||||||
global _allow_all
|
|
||||||
log = _get_logger("_listen_link_established")
|
|
||||||
link.set_remote_identified_callback(_initiator_identified)
|
|
||||||
link.set_link_closed_callback(_listen_link_closed)
|
|
||||||
log.info("Link " + str(link) + " established")
|
|
||||||
|
|
||||||
|
|
||||||
def _listen_link_closed(link: RNS.Link):
|
|
||||||
log = _get_logger("_listen_link_closed")
|
|
||||||
# async def cleanup():
|
|
||||||
log.info("Link " + str(link) + " closed")
|
|
||||||
proc: Session | None = Session.get_for_tag(link.link_id)
|
|
||||||
if proc is None:
|
|
||||||
log.warning(f"No process for link {link}")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
proc.process.terminate()
|
|
||||||
_retry_timer.complete(link.link_id)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error closing process for link {link}: {e}")
|
|
||||||
Session.clear_tag(link.link_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _initiator_identified(link, identity):
|
|
||||||
global _allow_all, _cmd, _loop
|
|
||||||
log = _get_logger("_initiator_identified")
|
|
||||||
log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash))
|
|
||||||
if not _allow_all and identity.hash not in _allowed_identity_hashes:
|
|
||||||
log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING)
|
|
||||||
link.teardown()
|
|
||||||
|
|
||||||
|
|
||||||
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
|
||||||
global _destination, _retry_timer, _loop, _cmd, _no_remote_command
|
|
||||||
log = _get_logger("_listen_request")
|
|
||||||
log.debug(
|
|
||||||
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
|
|
||||||
if not hasattr(data, "__len__") or len(data) < 1:
|
|
||||||
raise Exception("Request data invalid")
|
|
||||||
_retry_timer.complete(link_id)
|
|
||||||
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
|
|
||||||
if link is None:
|
|
||||||
log.error(f"Invalid request {request_id}, no link found with id {link_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
remote_version = data[Session.REQUEST_IDX_VERS]
|
|
||||||
if not _protocol_check_magic(remote_version):
|
|
||||||
log.error("Request magic incorrect")
|
|
||||||
link.teardown()
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not remote_version <= _PROTOCOL_VERSION_3:
|
|
||||||
return Session.error_response("Listener<->initiator version mismatch")
|
|
||||||
|
|
||||||
cmd = _cmd.copy()
|
|
||||||
remote_command = data[Session.REQUEST_IDX_CMD]
|
|
||||||
if remote_command is not None and len(remote_command) > 0:
|
|
||||||
if _no_remote_command:
|
|
||||||
return Session.error_response("Listener does not permit initiator to provide command.")
|
|
||||||
elif _remote_cmd_as_args:
|
|
||||||
cmd.extend(remote_command)
|
|
||||||
else:
|
|
||||||
cmd = remote_command
|
|
||||||
|
|
||||||
if not _no_remote_command and (cmd is None or len(cmd) == 0):
|
|
||||||
return Session.error_response("No command supplied and no default command available.")
|
|
||||||
|
|
||||||
session: Session | None = None
|
|
||||||
try:
|
|
||||||
term = data[Session.REQUEST_IDX_TERM]
|
|
||||||
# sanitize
|
|
||||||
if term is not None:
|
|
||||||
term = re.sub('[^A-Za-z-0-9\-\_]','', term)
|
|
||||||
session = Session.get_for_tag(link.link_id)
|
|
||||||
if session is None:
|
|
||||||
log.debug(f"Process not found for link {link}")
|
|
||||||
session = _listen_start_proc(link=link,
|
|
||||||
term=term,
|
|
||||||
session_flags=data[Session.REQUEST_IDX_FLAGS],
|
|
||||||
cmd=cmd,
|
|
||||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
|
||||||
loop=_loop)
|
|
||||||
if session is None:
|
|
||||||
return Session.error_response("Unable to start subprocess")
|
|
||||||
|
|
||||||
# leave significant headroom for metadata and encoding
|
|
||||||
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
|
|
||||||
return result
|
|
||||||
# return ProcessState.default_response()
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error procesing request for link {link}: {e}")
|
|
||||||
try:
|
|
||||||
if session is not None and session.process.running:
|
|
||||||
session.process.terminate()
|
|
||||||
except Exception as ee:
|
|
||||||
log.debug(f"Error terminating process for link {link}: {ee}")
|
|
||||||
|
|
||||||
return Session.default_response()
|
|
||||||
|
|
||||||
|
|
||||||
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
||||||
@ -731,7 +218,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
|||||||
timeout += time.time()
|
timeout += time.time()
|
||||||
|
|
||||||
while (timeout is None or time.time() < timeout) and not until():
|
while (timeout is None or time.time() < timeout) and not until():
|
||||||
if await _check_finished(0.001):
|
if await _check_finished(0.01):
|
||||||
raise asyncio.CancelledError()
|
raise asyncio.CancelledError()
|
||||||
if timeout is not None and time.time() > timeout:
|
if timeout is not None and time.time() > timeout:
|
||||||
return False
|
return False
|
||||||
@ -744,15 +231,28 @@ _remote_exec_grace = 2.0
|
|||||||
_new_data: asyncio.Event | None = None
|
_new_data: asyncio.Event | None = None
|
||||||
_tr: process.TTYRestorer | None = None
|
_tr: process.TTYRestorer | None = None
|
||||||
|
|
||||||
|
_pq = queue.Queue()
|
||||||
|
|
||||||
|
class InitiatorState(enum.IntEnum):
|
||||||
|
IS_INITIAL = 0
|
||||||
|
IS_LINKED = 1
|
||||||
|
IS_WAIT_VERS = 2
|
||||||
|
IS_RUNNING = 3
|
||||||
|
IS_TERMINATE = 4
|
||||||
|
IS_TEARDOWN = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _client_link_closed(link):
|
||||||
|
log = _get_logger("_client_link_closed")
|
||||||
|
_finished.set()
|
||||||
|
|
||||||
|
|
||||||
def _client_packet_handler(message, packet):
|
def _client_packet_handler(message, packet):
|
||||||
global _new_data
|
global _new_data
|
||||||
log = _get_logger("_client_packet_handler")
|
log = _get_logger("_client_packet_handler")
|
||||||
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG and _new_data is not None:
|
packet.prove()
|
||||||
log.debug("data available")
|
_pq.put(message)
|
||||||
_new_data.set()
|
|
||||||
else:
|
|
||||||
log.error(f"received unhandled packet")
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteExecutionError(Exception):
|
class RemoteExecutionError(Exception):
|
||||||
@ -760,15 +260,10 @@ class RemoteExecutionError(Exception):
|
|||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
def _response_handler(request_receipt: RNS.RequestReceipt):
|
async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
|
||||||
pass
|
service_name="default", timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
|
||||||
|
|
||||||
|
|
||||||
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
|
|
||||||
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
|
|
||||||
cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
|
|
||||||
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
|
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
|
||||||
log = _get_logger("_execute")
|
log = _get_logger("_initiate_link")
|
||||||
|
|
||||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
|
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
|
||||||
if len(destination) != dest_len:
|
if len(destination) != dest_len:
|
||||||
@ -809,6 +304,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
_link = RNS.Link(_destination)
|
_link = RNS.Link(_destination)
|
||||||
_link.did_identify = False
|
_link.did_identify = False
|
||||||
|
|
||||||
|
_link.set_link_closed_callback(_client_link_closed)
|
||||||
|
|
||||||
log.info(f"Establishing link...")
|
log.info(f"Establishing link...")
|
||||||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
||||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||||
@ -820,110 +317,6 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
|
|
||||||
_link.set_packet_callback(_client_packet_handler)
|
_link.set_packet_callback(_client_packet_handler)
|
||||||
|
|
||||||
request = Session.default_request()
|
|
||||||
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
|
|
||||||
# log.debug(f"Sending {stdin} to listener")
|
|
||||||
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
|
|
||||||
request[Session.REQUEST_IDX_CMD] = cmd
|
|
||||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
|
|
||||||
Session.REQUEST_FLAGS_EOF_STDIN)
|
|
||||||
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
|
|
||||||
|
|
||||||
# TODO: Tune
|
|
||||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
|
||||||
|
|
||||||
log.debug("Sending request")
|
|
||||||
request_receipt = _link.request(
|
|
||||||
path="data",
|
|
||||||
data=request,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
timeout += 0.5
|
|
||||||
|
|
||||||
log.debug("Waiting for delivery")
|
|
||||||
await _spin(
|
|
||||||
until=lambda: _link.status == RNS.Link.CLOSED or (
|
|
||||||
request_receipt.status != RNS.RequestReceipt.FAILED and
|
|
||||||
request_receipt.status != RNS.RequestReceipt.SENT),
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if _link.status == RNS.Link.CLOSED:
|
|
||||||
raise RemoteExecutionError("Could not request remote execution, link was closed")
|
|
||||||
|
|
||||||
if request_receipt.status == RNS.RequestReceipt.FAILED:
|
|
||||||
raise RemoteExecutionError("Could not request remote execution")
|
|
||||||
|
|
||||||
await _spin(
|
|
||||||
until=lambda: request_receipt.status != RNS.RequestReceipt.DELIVERED,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if request_receipt.status == RNS.RequestReceipt.FAILED:
|
|
||||||
raise RemoteExecutionError("No result was received")
|
|
||||||
|
|
||||||
if request_receipt.status == RNS.RequestReceipt.FAILED:
|
|
||||||
raise RemoteExecutionError("Receiving result failed")
|
|
||||||
|
|
||||||
if request_receipt.response is not None:
|
|
||||||
try:
|
|
||||||
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
|
||||||
if not _protocol_check_magic(version):
|
|
||||||
raise RemoteExecutionError("Protocol error")
|
|
||||||
elif version != _PROTOCOL_VERSION_3:
|
|
||||||
raise RemoteExecutionError("Protocol version mismatch")
|
|
||||||
|
|
||||||
flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
|
|
||||||
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
|
|
||||||
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
|
|
||||||
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
|
|
||||||
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
|
||||||
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
|
||||||
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
|
||||||
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
|
|
||||||
# if stdout is not None:
|
|
||||||
# stdout = base64.b64decode(stdout)
|
|
||||||
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
|
||||||
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
|
||||||
except RemoteExecutionError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
raise RemoteExecutionError(f"Received invalid response") from e
|
|
||||||
|
|
||||||
if stdout is not None:
|
|
||||||
_tr.raw()
|
|
||||||
log.debug(f"stdout: {stdout}")
|
|
||||||
os.write(1, stdout)
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
if stderr is not None:
|
|
||||||
_tr.raw()
|
|
||||||
log.debug(f"stderr: {stderr}")
|
|
||||||
os.write(2, stderr)
|
|
||||||
sys.stderr.flush()
|
|
||||||
|
|
||||||
if stderr_eof and ready_bytes == 0:
|
|
||||||
log.debug("Closing stderr")
|
|
||||||
os.close(2)
|
|
||||||
|
|
||||||
if stdout_eof and ready_bytes == 0:
|
|
||||||
log.debug("Closing stdout")
|
|
||||||
os.close(1)
|
|
||||||
|
|
||||||
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
|
|
||||||
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
|
|
||||||
|
|
||||||
if ready_bytes > 0:
|
|
||||||
_new_data.set()
|
|
||||||
|
|
||||||
if (not running or return_code is not None) and (ready_bytes == 0):
|
|
||||||
log.debug(f"returning running: {running}, return_code: {return_code}")
|
|
||||||
return return_code or 255
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
_pre_input = bytearray()
|
|
||||||
|
|
||||||
|
|
||||||
async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
|
async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
|
||||||
service_name: str, timeout: float, command: [str] | None = None):
|
service_name: str, timeout: float, command: [str] | None = None):
|
||||||
@ -931,13 +324,52 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
log = _get_logger("_initiate")
|
log = _get_logger("_initiate")
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
_new_data = asyncio.Event()
|
_new_data = asyncio.Event()
|
||||||
|
state = InitiatorState.IS_INITIAL
|
||||||
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
|
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
|
||||||
|
|
||||||
|
await _initiate_link(
|
||||||
|
configdir=configdir,
|
||||||
|
identitypath=identitypath,
|
||||||
|
verbosity=verbosity,
|
||||||
|
quietness=quietness,
|
||||||
|
noid=noid,
|
||||||
|
destination=destination,
|
||||||
|
service_name=service_name,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _link or _link.status != RNS.Link.ACTIVE:
|
||||||
|
_finished.set()
|
||||||
|
return 255
|
||||||
|
|
||||||
|
state = InitiatorState.IS_LINKED
|
||||||
|
outlet = session.RNSOutlet(_link)
|
||||||
|
with protocol.Messenger(retry_delay_min=5) as messenger:
|
||||||
|
|
||||||
|
# Next step after linking and identifying: send version
|
||||||
|
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
||||||
|
# print("Error bringing up link")
|
||||||
|
# return 253
|
||||||
|
|
||||||
|
messenger.send(outlet, protocol.VersionInfoMessage())
|
||||||
|
try:
|
||||||
|
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
||||||
|
vm = messenger.receive(vp)
|
||||||
|
if not isinstance(vm, protocol.VersionInfoMessage):
|
||||||
|
raise Exception("Invalid message received")
|
||||||
|
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
|
||||||
|
state = InitiatorState.IS_RUNNING
|
||||||
|
except queue.Empty:
|
||||||
|
print("Protocol error")
|
||||||
|
return 254
|
||||||
|
|
||||||
|
winch = False
|
||||||
def sigwinch_handler():
|
def sigwinch_handler():
|
||||||
|
nonlocal winch
|
||||||
# log.debug("WindowChanged")
|
# log.debug("WindowChanged")
|
||||||
if _new_data is not None:
|
winch = True
|
||||||
_new_data.set()
|
# if _new_data is not None:
|
||||||
|
# _new_data.set()
|
||||||
|
|
||||||
stdin_eof = False
|
stdin_eof = False
|
||||||
def stdin():
|
def stdin():
|
||||||
@ -957,61 +389,98 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||||
|
|
||||||
await _check_finished()
|
await _check_finished()
|
||||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
|
||||||
|
|
||||||
# leave a lot of overhead
|
tcattr = None
|
||||||
mdu = 64
|
rows, cols, hpix, vpix = (None, None, None, None)
|
||||||
rtt = 5
|
|
||||||
first_loop = True
|
|
||||||
cmdline = " ".join(command or [])
|
|
||||||
while not await _check_finished():
|
|
||||||
try:
|
try:
|
||||||
log.debug("top of client loop")
|
tcattr = termios.tcgetattr(0)
|
||||||
stdin = data_buffer[:mdu]
|
rows, cols, hpix, vpix = process.tty_get_winsize(0)
|
||||||
data_buffer = data_buffer[mdu:]
|
except:
|
||||||
_new_data.clear()
|
try:
|
||||||
log.debug("before _execute")
|
tcattr = termios.tcgetattr(1)
|
||||||
return_code = await _execute(
|
rows, cols, hpix, vpix = process.tty_get_winsize(1)
|
||||||
configdir=configdir,
|
except:
|
||||||
identitypath=identitypath,
|
try:
|
||||||
verbosity=verbosity,
|
tcattr = termios.tcgetattr(2)
|
||||||
quietness=quietness,
|
rows, cols, hpix, vpix = process.tty_get_winsize(2)
|
||||||
noid=noid,
|
except:
|
||||||
destination=destination,
|
pass
|
||||||
service_name=service_name,
|
|
||||||
stdin=stdin,
|
|
||||||
timeout=timeout,
|
|
||||||
cmd=command,
|
|
||||||
stdin_eof=stdin_eof,
|
|
||||||
bytes_available=len(data_buffer)
|
|
||||||
)
|
|
||||||
|
|
||||||
if first_loop:
|
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
|
||||||
first_loop = False
|
pipe_stdin=not os.isatty(0),
|
||||||
mdu = _protocol_request_chars_take(_link.MDU,
|
pipe_stdout=not os.isatty(1),
|
||||||
_PROTOCOL_VERSION_DEFAULT,
|
pipe_stderr=not os.isatty(2),
|
||||||
os.environ.get("TERM", ""),
|
tcflags=tcattr,
|
||||||
cmdline)
|
term=os.environ.get("TERM", None),
|
||||||
_new_data.set()
|
rows=rows,
|
||||||
|
cols=cols,
|
||||||
|
hpix=hpix,
|
||||||
|
vpix=vpix))
|
||||||
|
|
||||||
if _link:
|
|
||||||
rtt = _link.rtt
|
|
||||||
|
|
||||||
if return_code is not None:
|
|
||||||
log.debug(f"received return code {return_code}, exiting")
|
|
||||||
|
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||||
|
mdu = _link.MDU - 16
|
||||||
|
sent_eof = False
|
||||||
|
|
||||||
|
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
packet = _pq.get_nowait()
|
||||||
|
message = messenger.receive(packet)
|
||||||
|
|
||||||
|
if isinstance(message, protocol.StreamDataMessage):
|
||||||
|
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
|
||||||
|
if message.data and len(message.data) > 0:
|
||||||
|
_tr.raw()
|
||||||
|
log.debug(f"stdout: {message.data}")
|
||||||
|
os.write(1, message.data)
|
||||||
|
sys.stdout.flush()
|
||||||
|
if message.eof:
|
||||||
|
os.close(1)
|
||||||
|
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
|
||||||
|
if message.data and len(message.data) > 0:
|
||||||
|
_tr.raw()
|
||||||
|
log.debug(f"stdout: {message.data}")
|
||||||
|
os.write(2, message.data)
|
||||||
|
sys.stderr.flush()
|
||||||
|
if message.eof:
|
||||||
|
os.close(2)
|
||||||
|
elif isinstance(message, protocol.CommandExitedMessage):
|
||||||
|
log.debug(f"received return code {message.return_code}, exiting")
|
||||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||||
_link.teardown()
|
_link.teardown()
|
||||||
|
return message.return_code
|
||||||
return return_code
|
elif isinstance(message, protocol.ErrorMessage):
|
||||||
except asyncio.CancelledError:
|
log.error(message.data)
|
||||||
if _link and _link.status != RNS.Link.CLOSED:
|
if message.fatal:
|
||||||
_link.teardown()
|
_link.teardown()
|
||||||
return 0
|
return 200
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if messenger.is_outlet_ready(outlet):
|
||||||
|
stdin = data_buffer[:mdu]
|
||||||
|
data_buffer = data_buffer[mdu:]
|
||||||
|
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
||||||
|
if len(stdin) > 0 or eof:
|
||||||
|
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||||
|
stdin, eof))
|
||||||
|
sent_eof = eof
|
||||||
except RemoteExecutionError as e:
|
except RemoteExecutionError as e:
|
||||||
print(e.msg)
|
print(e.msg)
|
||||||
return 255
|
return 255
|
||||||
|
except Exception as ex:
|
||||||
|
print(f"Client exception: {ex}")
|
||||||
|
if _link and _link.status != RNS.Link.CLOSED:
|
||||||
|
_link.teardown()
|
||||||
|
return 127
|
||||||
|
|
||||||
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
log.debug("after main loop")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
423
rnsh/session.py
Normal file
423
rnsh/session.py
Normal file
@ -0,0 +1,423 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import threading
|
||||||
|
import rnsh.exception as exception
|
||||||
|
import asyncio
|
||||||
|
import rnsh.process as process
|
||||||
|
import rnsh.helpers as helpers
|
||||||
|
import rnsh.protocol as protocol
|
||||||
|
import enum
|
||||||
|
from typing import TypeVar, Generic, Callable, List
|
||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from multiprocessing import Manager
|
||||||
|
import os
|
||||||
|
import RNS
|
||||||
|
|
||||||
|
import logging as __logging
|
||||||
|
|
||||||
|
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
|
||||||
|
|
||||||
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TLink = TypeVar("_TLink")
|
||||||
|
|
||||||
|
class SEType(enum.IntEnum):
|
||||||
|
SE_LINK_CLOSED = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SessionException(Exception):
|
||||||
|
def __init__(self, setype: SEType, msg: str, *args):
|
||||||
|
super().__init__(msg, args)
|
||||||
|
self.type = setype
|
||||||
|
|
||||||
|
|
||||||
|
class LSState(enum.IntEnum):
|
||||||
|
LSSTATE_WAIT_IDENT = 1
|
||||||
|
LSSTATE_WAIT_VERS = 2
|
||||||
|
LSSTATE_WAIT_CMD = 3
|
||||||
|
LSSTATE_RUNNING = 4
|
||||||
|
LSSTATE_ERROR = 5
|
||||||
|
LSSTATE_TEARDOWN = 6
|
||||||
|
|
||||||
|
|
||||||
|
_TIdentity = TypeVar("_TIdentity")
|
||||||
|
|
||||||
|
|
||||||
|
class LSOutletBase(protocol.MessageOutletBase):
|
||||||
|
@abstractmethod
|
||||||
|
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unset_link_closed_callback(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def teardown(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplemented()
|
||||||
|
|
||||||
|
|
||||||
|
class ListenerSession:
|
||||||
|
sessions: List[ListenerSession] = []
|
||||||
|
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
|
||||||
|
allowed_identity_hashes: [any] = []
|
||||||
|
allow_all: bool = False
|
||||||
|
allow_remote_command: bool = False
|
||||||
|
default_command: [str] = []
|
||||||
|
remote_cmd_as_args = False
|
||||||
|
|
||||||
|
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
|
||||||
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
|
self._log.info(f"Session started for {outlet}")
|
||||||
|
self.outlet = outlet
|
||||||
|
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||||
|
self.outlet.set_link_closed_callback(self._link_closed)
|
||||||
|
self.loop = loop
|
||||||
|
self.state: LSState = None
|
||||||
|
self.remote_identity = None
|
||||||
|
self.term: str | None = None
|
||||||
|
self.stdin_is_pipe: bool = False
|
||||||
|
self.stdout_is_pipe: bool = False
|
||||||
|
self.stderr_is_pipe: bool = False
|
||||||
|
self.tcflags: [any] = None
|
||||||
|
self.cmdline: [str] = None
|
||||||
|
self.rows: int = 0
|
||||||
|
self.cols: int = 0
|
||||||
|
self.hpix: int = 0
|
||||||
|
self.vpix: int = 0
|
||||||
|
self.stdout_buf = bytearray()
|
||||||
|
self.stdout_eof_sent = False
|
||||||
|
self.stderr_buf = bytearray()
|
||||||
|
self.stderr_eof_sent = False
|
||||||
|
self.return_code: int | None = None
|
||||||
|
self.return_code_sent = False
|
||||||
|
self.process: process.CallbackSubprocess | None = None
|
||||||
|
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||||
|
self.sessions.append(self)
|
||||||
|
|
||||||
|
def _terminated(self, return_code: int):
|
||||||
|
self.return_code = return_code
|
||||||
|
|
||||||
|
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
|
||||||
|
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
|
||||||
|
self._log.debug(f"Set state: {state.name}, timeout {timeout}")
|
||||||
|
orig_state = self.state
|
||||||
|
self.state = state
|
||||||
|
if timeout_factor is not None:
|
||||||
|
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
|
||||||
|
|
||||||
|
def _call(self, func: callable, delay: float = 0):
|
||||||
|
def call_inner():
|
||||||
|
# self._log.debug("call_inner")
|
||||||
|
if delay == 0:
|
||||||
|
func()
|
||||||
|
else:
|
||||||
|
self.loop.call_later(delay, func)
|
||||||
|
self.loop.call_soon_threadsafe(call_inner)
|
||||||
|
|
||||||
|
def send(self, message: protocol.Message):
|
||||||
|
self.messenger.send(self.outlet, message)
|
||||||
|
|
||||||
|
def _protocol_error(self, name: str):
|
||||||
|
self.terminate(f"Protocol error ({name})")
|
||||||
|
|
||||||
|
def _protocol_timeout_error(self, name: str):
|
||||||
|
self.terminate(f"Protocol timeout error: {name}")
|
||||||
|
|
||||||
|
def terminate(self, error: str = None):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self._log.debug("Terminating session" + (f": {error}" if error else ""))
|
||||||
|
if error and self.state != LSState.LSSTATE_TEARDOWN:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.send(protocol.ErrorMessage(error, True))
|
||||||
|
self.state = LSState.LSSTATE_ERROR
|
||||||
|
self._terminate_process()
|
||||||
|
self._call(self._prune, max(self.outlet.rtt * 3, 5))
|
||||||
|
|
||||||
|
def _prune(self):
|
||||||
|
self.state = LSState.LSSTATE_TEARDOWN
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
self.sessions.remove(self)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.outlet.teardown()
|
||||||
|
|
||||||
|
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
|
||||||
|
timeout = True
|
||||||
|
try:
|
||||||
|
timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
|
||||||
|
except Exception as ex:
|
||||||
|
self._log.exception("Error in protocol timeout", ex)
|
||||||
|
if timeout:
|
||||||
|
self._protocol_timeout_error(name)
|
||||||
|
|
||||||
|
def _link_closed(self, outlet: LSOutletBase):
|
||||||
|
outlet.unset_link_closed_callback()
|
||||||
|
|
||||||
|
if outlet != self.outlet:
|
||||||
|
self._log.debug("Link closed received from incorrect outlet")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._log.debug(f"link_closed {outlet}")
|
||||||
|
self.messenger.clear_retries(self.outlet)
|
||||||
|
self.terminate()
|
||||||
|
|
||||||
|
def _initiator_identified(self, outlet, identity):
|
||||||
|
if outlet != self.outlet:
|
||||||
|
self._log.debug("Identity received from incorrect outlet")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._log.info(f"initiator_identified {identity} on link {outlet}")
|
||||||
|
if self.state != LSState.LSSTATE_WAIT_IDENT:
|
||||||
|
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
|
||||||
|
|
||||||
|
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
|
||||||
|
self.terminate("Identity is not allowed.")
|
||||||
|
|
||||||
|
self.remote_identity = identity
|
||||||
|
self.outlet.set_packet_received_callback(self._packet_received)
|
||||||
|
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def pump_all(cls):
|
||||||
|
for session in cls.sessions:
|
||||||
|
session.pump()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def terminate_all(cls, reason: str):
|
||||||
|
for session in cls.sessions:
|
||||||
|
session.terminate(reason)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def pump(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.state != LSState.LSSTATE_RUNNING:
|
||||||
|
return
|
||||||
|
elif not self.messenger.is_outlet_ready(self.outlet):
|
||||||
|
return
|
||||||
|
elif len(self.stderr_buf) > 0:
|
||||||
|
mdu = self.outlet.mdu - 16
|
||||||
|
data = self.stderr_buf[:mdu]
|
||||||
|
self.stderr_buf = self.stderr_buf[mdu:]
|
||||||
|
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||||
|
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
|
||||||
|
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
|
||||||
|
data, send_eof)
|
||||||
|
self.send(msg)
|
||||||
|
if send_eof:
|
||||||
|
self.stderr_eof_sent = True
|
||||||
|
elif len(self.stdout_buf) > 0:
|
||||||
|
mdu = self.outlet.mdu - 16
|
||||||
|
data = self.stdout_buf[:mdu]
|
||||||
|
self.stdout_buf = self.stdout_buf[mdu:]
|
||||||
|
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||||
|
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
|
||||||
|
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
|
||||||
|
data, send_eof)
|
||||||
|
self.send(msg)
|
||||||
|
elif self.return_code is not None and not self.return_code_sent:
|
||||||
|
msg = protocol.CommandExitedMessage(self.return_code)
|
||||||
|
self.send(msg)
|
||||||
|
self.return_code_sent = True
|
||||||
|
self._call(functools.partial(self._check_protocol_timeout,
|
||||||
|
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
|
||||||
|
max(self.outlet.rtt * 5, 10))
|
||||||
|
except Exception as ex:
|
||||||
|
self._log.exception("Error during pump", ex)
|
||||||
|
|
||||||
|
def _terminate_process(self):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
if self.process and self.process.running:
|
||||||
|
self.process.terminate()
|
||||||
|
|
||||||
|
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
|
||||||
|
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
|
||||||
|
|
||||||
|
self.cmdline = self.default_command
|
||||||
|
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
|
||||||
|
self.terminate("Remote command line not allowed by listener")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
|
||||||
|
self.cmdline.extend(cmdline)
|
||||||
|
elif cmdline and len(cmdline) > 0:
|
||||||
|
self.cmdline = cmdline
|
||||||
|
|
||||||
|
|
||||||
|
self.stdin_is_pipe = pipe_stdin
|
||||||
|
self.stdout_is_pipe = pipe_stdout
|
||||||
|
self.stderr_is_pipe = pipe_stderr
|
||||||
|
self.tcflags = tcflags
|
||||||
|
self.term = term
|
||||||
|
|
||||||
|
def stdout(data: bytes):
|
||||||
|
self.stdout_buf.extend(data)
|
||||||
|
|
||||||
|
def stderr(data: bytes):
|
||||||
|
self.stderr_buf.extend(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.process = process.CallbackSubprocess(argv=self.cmdline,
|
||||||
|
env={"TERM": self.term or os.environ.get("TERM", None),
|
||||||
|
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
|
||||||
|
loop=self.loop,
|
||||||
|
stdout_callback=stdout,
|
||||||
|
stderr_callback=stderr,
|
||||||
|
terminated_callback=self._terminated,
|
||||||
|
stdin_is_pipe=self.stdin_is_pipe,
|
||||||
|
stdout_is_pipe=self.stdout_is_pipe,
|
||||||
|
stderr_is_pipe=self.stderr_is_pipe)
|
||||||
|
self.process.start()
|
||||||
|
self._set_window_size(rows, cols, hpix, vpix)
|
||||||
|
except Exception as ex:
|
||||||
|
self._log.exception(f"Unable to start process for link {self.outlet}", ex)
|
||||||
|
self.terminate("Unable to start process")
|
||||||
|
|
||||||
|
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
|
||||||
|
self.rows = rows
|
||||||
|
self.cols = cols
|
||||||
|
self.hpix = hpix
|
||||||
|
self.vpix = vpix
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self.process.set_winsize(rows, cols, hpix, vpix)
|
||||||
|
|
||||||
|
def _received_stdin(self, data: bytes, eof: bool):
|
||||||
|
if data and len(data) > 0:
|
||||||
|
self.process.write(data)
|
||||||
|
if eof:
|
||||||
|
self.process.close_stdin()
|
||||||
|
|
||||||
|
def _handle_message(self, message: protocol.Message):
|
||||||
|
if self.state == LSState.LSSTATE_WAIT_VERS:
|
||||||
|
if not isinstance(message, protocol.VersionInfoMessage):
|
||||||
|
self._protocol_error(self.state.name)
|
||||||
|
return
|
||||||
|
self._log.info(f"version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}")
|
||||||
|
if message.protocol_version != protocol.PROTOCOL_VERSION:
|
||||||
|
self.terminate("Incompatible protocol")
|
||||||
|
return
|
||||||
|
self.send(protocol.VersionInfoMessage())
|
||||||
|
self._set_state(LSState.LSSTATE_WAIT_CMD)
|
||||||
|
return
|
||||||
|
elif self.state == LSState.LSSTATE_WAIT_CMD:
|
||||||
|
if not isinstance(message, protocol.ExecuteCommandMesssage):
|
||||||
|
return self._protocol_error(self.state.name)
|
||||||
|
self._log.info(f"Execute command message on link {self.outlet}: {message.cmdline}")
|
||||||
|
self._set_state(LSState.LSSTATE_RUNNING)
|
||||||
|
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
|
||||||
|
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
|
||||||
|
return
|
||||||
|
elif self.state == LSState.LSSTATE_RUNNING:
|
||||||
|
if isinstance(message, protocol.WindowSizeMessage):
|
||||||
|
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
|
||||||
|
elif isinstance(message, protocol.StreamDataMessage):
|
||||||
|
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
|
||||||
|
self._log.error(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}")
|
||||||
|
return self._protocol_error(self.state.name)
|
||||||
|
self._received_stdin(message.data, message.eof)
|
||||||
|
return
|
||||||
|
elif isinstance(message, protocol.NoopMessage):
|
||||||
|
# echo noop only on listener--used for keepalive/connectivity check
|
||||||
|
self.send(message)
|
||||||
|
return
|
||||||
|
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
|
||||||
|
self._log.error(f"Received packet, but in state {self.state.name}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._protocol_error("unexpected message")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
|
||||||
|
if outlet != self.outlet:
|
||||||
|
self._log.debug("Packet received from incorrect outlet")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = self.messenger.receive(raw)
|
||||||
|
self._handle_message(message)
|
||||||
|
except Exception as ex:
|
||||||
|
self._protocol_error("unusable packet")
|
||||||
|
|
||||||
|
|
||||||
|
class RNSOutlet(LSOutletBase):
|
||||||
|
|
||||||
|
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||||
|
def inner_cb(link, identity: _TIdentity):
|
||||||
|
cb(self, identity)
|
||||||
|
|
||||||
|
self.link.set_remote_identified_callback(inner_cb)
|
||||||
|
|
||||||
|
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||||
|
def inner_cb(link):
|
||||||
|
cb(self)
|
||||||
|
|
||||||
|
self.link.set_link_closed_callback(inner_cb)
|
||||||
|
|
||||||
|
def unset_link_closed_callback(self):
|
||||||
|
self.link.set_link_closed_callback(None)
|
||||||
|
|
||||||
|
def teardown(self):
|
||||||
|
self.link.teardown()
|
||||||
|
|
||||||
|
def send(self, raw: bytes) -> RNS.PacketReceipt:
|
||||||
|
packet = RNS.Packet(self.link, raw)
|
||||||
|
packet.send()
|
||||||
|
return packet.receipt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mdu(self) -> int:
|
||||||
|
return self.link.MDU
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rtt(self) -> float:
|
||||||
|
return self.link.rtt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_usuable(self):
|
||||||
|
return True #self.link.status in [RNS.Link.ACTIVE]
|
||||||
|
|
||||||
|
def get_receipt_state(self, receipt: RNS.PacketReceipt) -> MessageState:
|
||||||
|
status = receipt.get_status()
|
||||||
|
if status == RNS.PacketReceipt.SENT:
|
||||||
|
return protocol.MessageState.MSGSTATE_SENT
|
||||||
|
if status == RNS.PacketReceipt.DELIVERED:
|
||||||
|
return protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
|
if status == RNS.PacketReceipt.FAILED:
|
||||||
|
return protocol.MessageState.MSGSTATE_FAILED
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected receipt state: {status}")
|
||||||
|
|
||||||
|
def timed_out(self):
|
||||||
|
self.link.teardown()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Outlet RNS Link {self.link}"
|
||||||
|
|
||||||
|
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||||
|
def inner_cb(message, packet: RNS.Packet):
|
||||||
|
packet.prove()
|
||||||
|
cb(self, message)
|
||||||
|
|
||||||
|
self.link.set_packet_callback(inner_cb)
|
||||||
|
|
||||||
|
def __init__(self, link: RNS.Link):
|
||||||
|
self.link = link
|
||||||
|
link.lsoutlet = self
|
||||||
|
link.msgoutlet = self
|
||||||
|
@staticmethod
|
||||||
|
def get_outlet(link: RNS.Link):
|
||||||
|
if hasattr(link, "lsoutlet"):
|
||||||
|
return link.lsoutlet
|
||||||
|
|
||||||
|
return RNSOutlet(link)
|
@ -1,6 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from rnsh.protocol import _TReceipt, MessageState
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
import rnsh.protocol
|
import rnsh.protocol
|
||||||
@ -14,64 +18,60 @@ import uuid
|
|||||||
module_logger = logging.getLogger(__name__)
|
module_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Link:
|
class Receipt:
|
||||||
|
def __init__(self, state: rnsh.protocol.MessageState, raw: bytes):
|
||||||
|
self.state = state
|
||||||
|
self.raw = raw
|
||||||
|
|
||||||
|
|
||||||
|
class MessageOutletTest(rnsh.protocol.MessageOutletBase):
|
||||||
def __init__(self, mdu: int, rtt: float):
|
def __init__(self, mdu: int, rtt: float):
|
||||||
self.link_id = uuid.uuid4()
|
self.link_id = uuid.uuid4()
|
||||||
self.timeout_callbacks = 0
|
self.timeout_callbacks = 0
|
||||||
self.mdu = mdu
|
self._mdu = mdu
|
||||||
self.rtt = rtt
|
self._rtt = rtt
|
||||||
self.usable = True
|
self._usable = True
|
||||||
self.receipts = []
|
self.receipts = []
|
||||||
|
self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None
|
||||||
|
|
||||||
def timeout_callback(self):
|
def send(self, raw: bytes) -> Receipt:
|
||||||
|
receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
|
||||||
|
self.receipts.append(receipt)
|
||||||
|
return receipt
|
||||||
|
|
||||||
|
def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]):
|
||||||
|
self.packet_callback = cb
|
||||||
|
|
||||||
|
def receive(self, raw: bytes):
|
||||||
|
if self.packet_callback:
|
||||||
|
self.packet_callback(self, raw)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mdu(self):
|
||||||
|
return self._mdu
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rtt(self):
|
||||||
|
return self._rtt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_usuable(self):
|
||||||
|
return self._usable
|
||||||
|
|
||||||
|
def get_receipt_state(self, receipt: Receipt) -> MessageState:
|
||||||
|
return receipt.state
|
||||||
|
|
||||||
|
def timed_out(self):
|
||||||
self.timeout_callbacks += 1
|
self.timeout_callbacks += 1
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.link_id)
|
return str(self.link_id)
|
||||||
|
|
||||||
|
|
||||||
class Receipt:
|
|
||||||
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
|
|
||||||
self.state = state
|
|
||||||
self.raw = raw
|
|
||||||
self.link = link
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolHarness(contextlib.AbstractContextManager):
|
class ProtocolHarness(contextlib.AbstractContextManager):
|
||||||
def __init__(self, retry_delay_min: float = 1):
|
def __init__(self, retry_delay_min: float = 1):
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker,
|
self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min)
|
||||||
link_timeout_callback=self.link_timeout_callback,
|
|
||||||
link_mdu_getter=self.link_mdu_getter,
|
|
||||||
link_rtt_getter=self.link_rtt_getter,
|
|
||||||
link_usable_getter=self.link_usable_getter,
|
|
||||||
packet_sender=self.packet_sender,
|
|
||||||
retry_delay_min=retry_delay_min)
|
|
||||||
|
|
||||||
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
|
|
||||||
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
|
|
||||||
link.receipts.append(receipt)
|
|
||||||
return receipt
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_mdu_getter(link: Link):
|
|
||||||
return link.mdu
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_rtt_getter(link: Link):
|
|
||||||
return link.rtt
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_usable_getter(link: Link):
|
|
||||||
return link.usable
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
|
|
||||||
return receipt.state
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def link_timeout_callback(link: Link):
|
|
||||||
link.timeout_callback()
|
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
self.messenger.shutdown()
|
self.messenger.shutdown()
|
||||||
@ -83,49 +83,58 @@ class ProtocolHarness(contextlib.AbstractContextManager):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def test_mdu():
|
|
||||||
with ProtocolHarness() as h:
|
|
||||||
mdu = 500
|
|
||||||
link = Link(mdu=mdu, rtt=0.25)
|
|
||||||
assert h.messenger.get_mdu(link) == mdu - 4
|
|
||||||
link.mdu = mdu = 600
|
|
||||||
assert h.messenger.get_mdu(link) == mdu - 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtt():
|
|
||||||
with ProtocolHarness() as h:
|
|
||||||
rtt = 0.25
|
|
||||||
link = Link(mdu=500, rtt=rtt)
|
|
||||||
assert h.messenger.get_rtt(link) == rtt
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_one_retry():
|
def test_send_one_retry():
|
||||||
rtt = 0.001
|
rtt = 0.001
|
||||||
retry_interval = rtt * 150
|
retry_interval = rtt * 150
|
||||||
message_content = b'Test'
|
message_content = b'Test'
|
||||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||||
link = Link(mdu=500, rtt=rtt)
|
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||||
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||||
data=message_content, eof=True)
|
data=message_content, eof=True)
|
||||||
assert len(link.receipts) == 0
|
assert len(outlet.receipts) == 0
|
||||||
h.messenger.send_message(link, message)
|
h.messenger.send(outlet, message)
|
||||||
assert message.tracked
|
assert message.tracked
|
||||||
assert message.raw is not None
|
assert message.raw is not None
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
receipt = link.receipts[0]
|
receipt = outlet.receipts[0]
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
assert receipt.raw == message.raw
|
assert receipt.raw == message.raw
|
||||||
time.sleep(retry_interval * 1.5)
|
time.sleep(retry_interval * 1.5)
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
|
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
|
||||||
module_logger.info("set failed")
|
module_logger.info("set failed")
|
||||||
time.sleep(retry_interval)
|
time.sleep(retry_interval)
|
||||||
assert len(link.receipts) == 2
|
assert len(outlet.receipts) == 2
|
||||||
receipt = link.receipts[1]
|
receipt = outlet.receipts[1]
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
time.sleep(retry_interval)
|
time.sleep(retry_interval)
|
||||||
assert len(link.receipts) == 2
|
assert len(outlet.receipts) == 2
|
||||||
|
assert not message.tracked
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_timeout():
|
||||||
|
rtt = 0.001
|
||||||
|
retry_interval = rtt * 150
|
||||||
|
message_content = b'Test'
|
||||||
|
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||||
|
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||||
|
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||||
|
data=message_content, eof=True)
|
||||||
|
assert len(outlet.receipts) == 0
|
||||||
|
h.messenger.send(outlet, message)
|
||||||
|
assert message.tracked
|
||||||
|
assert message.raw is not None
|
||||||
|
assert len(outlet.receipts) == 1
|
||||||
|
receipt = outlet.receipts[0]
|
||||||
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
|
assert receipt.raw == message.raw
|
||||||
|
time.sleep(retry_interval * 1.5)
|
||||||
|
assert outlet.timeout_callbacks == 0
|
||||||
|
time.sleep(retry_interval * 7)
|
||||||
|
assert len(outlet.receipts) == 1
|
||||||
|
assert outlet.timeout_callbacks == 1
|
||||||
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
assert not message.tracked
|
assert not message.tracked
|
||||||
|
|
||||||
|
|
||||||
@ -133,24 +142,32 @@ def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[r
|
|||||||
rtt = 0.001
|
rtt = 0.001
|
||||||
retry_interval = rtt * 150
|
retry_interval = rtt * 150
|
||||||
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
with ProtocolHarness(retry_delay_min=retry_interval) as h:
|
||||||
link = Link(mdu=500, rtt=rtt)
|
|
||||||
assert len(link.receipts) == 0
|
decoded: [rnsh.protocol.Message] = []
|
||||||
h.messenger.send_message(link, message)
|
def packet(outlet, buffer):
|
||||||
|
decoded.append(h.messenger.receive(buffer))
|
||||||
|
|
||||||
|
outlet = MessageOutletTest(mdu=500, rtt=rtt)
|
||||||
|
outlet.set_packet_received_callback(packet)
|
||||||
|
assert len(outlet.receipts) == 0
|
||||||
|
h.messenger.send(outlet, message)
|
||||||
assert message.tracked
|
assert message.tracked
|
||||||
assert message.raw is not None
|
assert message.raw is not None
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
receipt = link.receipts[0]
|
receipt = outlet.receipts[0]
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
|
||||||
assert receipt.raw == message.raw
|
assert receipt.raw == message.raw
|
||||||
module_logger.info("set delivered")
|
module_logger.info("set delivered")
|
||||||
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
time.sleep(retry_interval * 2)
|
time.sleep(retry_interval * 2)
|
||||||
assert len(link.receipts) == 1
|
assert len(outlet.receipts) == 1
|
||||||
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
|
||||||
assert not message.tracked
|
assert not message.tracked
|
||||||
module_logger.info("injecting rx message")
|
module_logger.info("injecting rx message")
|
||||||
h.messenger.inbound(message.raw)
|
assert len(decoded) == 0
|
||||||
rx_message = h.messenger.poll_inbound(block=False)
|
outlet.receive(message.raw)
|
||||||
|
assert len(decoded) == 1
|
||||||
|
rx_message = decoded[0]
|
||||||
assert rx_message is not None
|
assert rx_message is not None
|
||||||
assert isinstance(rx_message, message.__class__)
|
assert isinstance(rx_message, message.__class__)
|
||||||
assert rx_message.msgid != message.msgid
|
assert rx_message.msgid != message.msgid
|
||||||
@ -238,6 +255,16 @@ def test_send_receive_error():
|
|||||||
eat_own_dog_food(message, check)
|
eat_own_dog_food(message, check)
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_receive_cmdexit():
|
||||||
|
message = rnsh.protocol.CommandExitedMessage(5)
|
||||||
|
|
||||||
|
def check(rx_message: rnsh.protocol.Message):
|
||||||
|
assert isinstance(rx_message, message.__class__)
|
||||||
|
assert rx_message.return_code == message.return_code
|
||||||
|
|
||||||
|
eat_own_dog_food(message, check)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,19 +12,6 @@ import re
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_magic():
|
|
||||||
magic = rnsh.rnsh._PROTOCOL_VERSION_0
|
|
||||||
# magic for version 0 is generated, make sure it comes out as expected
|
|
||||||
assert magic == 0xdeadbeef00000000
|
|
||||||
# verify the checker thinks it's right
|
|
||||||
assert rnsh.rnsh._protocol_check_magic(magic)
|
|
||||||
# scramble the magic
|
|
||||||
magic = magic | 0x00ffff0000000000
|
|
||||||
# make sure it fails now
|
|
||||||
assert not rnsh.rnsh._protocol_check_magic(magic)
|
|
||||||
|
|
||||||
|
|
||||||
def test_version():
|
def test_version():
|
||||||
# version = importlib.metadata.version(rnsh.__version__)
|
# version = importlib.metadata.version(rnsh.__version__)
|
||||||
assert rnsh.__version__ != "0.0.0"
|
assert rnsh.__version__ != "0.0.0"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user