Got the new protocol working.

This commit is contained in:
Aaron Heise 2023-02-18 00:09:28 -06:00
parent 0ee305795f
commit 8edb4020b1
8 changed files with 876 additions and 887 deletions

8
rnsh/helpers.py Normal file
View file

@ -0,0 +1,8 @@
def bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def check_and(value: int, andval: int) -> bool:
return (value & andval) > 0

View file

@ -382,8 +382,11 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
# Make PTY controlling if necessary
if child_fd is not None:
os.setsid()
try:
tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR)
os.close(tmp_fd)
except:
pass
# fcntl.ioctl(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2), os.O_RDWR, termios.TIOCSCTTY, 0)
# Execute the command
@ -445,7 +448,8 @@ class CallbackSubprocess:
self._child_stdout: int = None
self._child_stderr: int = None
self._return_code: int = None
self._eof: bool = False
self._stdout_eof: bool = False
self._stderr_eof: bool = False
self._stdin_is_pipe = stdin_is_pipe
self._stdout_is_pipe = stdout_is_pipe
self._stderr_is_pipe = stderr_is_pipe
@ -455,6 +459,21 @@ class CallbackSubprocess:
Terminate child process if running
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
"""
same = self._child_stdout == self._child_stderr
with contextlib.suppress(OSError):
if not self._stdout_eof:
tty_unset_reader_callbacks(self._child_stdout)
os.close(self._child_stdout)
self._child_stdout = None
if not same:
with contextlib.suppress(OSError):
if not self._stderr_eof:
tty_unset_reader_callbacks(self._child_stderr)
os.close(self._child_stderr)
self._child_stdout = None
self._log.debug("terminate()")
if not self.running:
return
@ -477,7 +496,7 @@ class CallbackSubprocess:
os.waitpid(self._pid, 0)
self._log.debug("wait() finish")
threading.Thread(target=wait).start()
threading.Thread(target=wait, daemon=True).start()
def close_stdin(self):
with contextlib.suppress(Exception):
@ -592,26 +611,44 @@ class CallbackSubprocess:
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable):
def stdout():
try:
with exception.permit(SystemExit):
data = tty_read(fd)
data = tty_read_poll(self._child_stdout)
if data is not None and len(data) > 0:
callback(data)
self._stdout_cb(data)
except EOFError:
self._eof = True
self._stdout_eof = True
tty_unset_reader_callbacks(self._child_stdout)
callback(bytearray())
self._stdout_cb(bytearray())
with contextlib.suppress(OSError):
os.close(self._child_stdout)
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb),
self._loop)
def stderr():
try:
with exception.permit(SystemExit):
data = tty_read_poll(self._child_stderr)
if data is not None and len(data) > 0:
self._stderr_cb(data)
except EOFError:
self._stderr_eof = True
tty_unset_reader_callbacks(self._child_stderr)
self._stdout_cb(bytearray())
with contextlib.suppress(OSError):
os.close(self._child_stderr)
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
if self._child_stderr != self._child_stdout:
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb),
self._loop)
tty_add_reader_callback(self._child_stderr, stderr, self._loop)
@property
def eof(self):
return self._eof or not self.running
def stdout_eof(self):
return self._stdout_eof or not self.running
@property
def stderr_eof(self):
return self._stderr_eof or not self.running
@property
def return_code(self) -> int:

View file

@ -15,6 +15,7 @@ import abc
import contextlib
import struct
import logging as __logging
from abc import ABC, abstractmethod
module_logger = __logging.getLogger(__name__)
@ -22,13 +23,50 @@ module_logger = __logging.getLogger(__name__)
_TReceipt = TypeVar("_TReceipt")
_TLink = TypeVar("_TLink")
MSG_MAGIC = 0xac
PROTOCOL_VERSION=1
PROTOCOL_VERSION = 1
def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class MessageOutletBase(ABC):
@abstractmethod
def send(self, raw: bytes) -> _TReceipt:
raise NotImplemented()
@property
@abstractmethod
def mdu(self):
raise NotImplemented()
@property
@abstractmethod
def rtt(self):
raise NotImplemented()
@property
@abstractmethod
def is_usuable(self):
raise NotImplemented()
@abstractmethod
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
raise NotImplemented()
@abstractmethod
def timed_out(self):
raise NotImplemented()
@abstractmethod
def __str__(self):
raise NotImplemented()
@abstractmethod
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
raise NotImplemented()
class METype(enum.IntEnum):
ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1
@ -58,17 +96,17 @@ class Message(abc.ABC):
self.msgid = uuid.uuid4()
self.raw: bytes | None = None
self.receipt: _TReceipt = None
self.link: _TLink = None
self.outlet: _TLink = None
self.tracked: bool = False
def __str__(self):
return f"{self.__class__.__name__} {self.msgid}"
@abc.abstractmethod
@abstractmethod
def pack(self) -> bytes:
raise NotImplemented()
@abc.abstractmethod
@abstractmethod
def unpack(self, raw):
raise NotImplemented()
@ -124,7 +162,8 @@ class ExecuteCommandMesssage(Message):
MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None):
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
cols: int = None, hpix: int = None, vpix: int = None):
super().__init__()
self.cmdline = cmdline
self.pipe_stdin = pipe_stdin
@ -132,16 +171,20 @@ class ExecuteCommandMesssage(Message):
self.pipe_stderr = pipe_stderr
self.tcflags = tcflags
self.term = term
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
def pack(self) -> bytes:
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term))
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term \
= umsgpack.unpackb(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class StreamDataMessage(Message):
MSGTYPE = _make_MSGTYPE(4)
@ -156,7 +199,7 @@ class StreamDataMessage(Message):
self.eof = eof
def pack(self) -> bytes:
raw = umsgpack.packb((self.stream_id, self.eof, self.data))
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw):
@ -169,7 +212,7 @@ class VersionInfoMessage(Message):
def __init__(self, sw_version: str = None):
super().__init__()
self.sw_version = sw_version
self.sw_version = sw_version or rnsh.__version__
self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes:
@ -199,6 +242,22 @@ class ErrorMessage(Message):
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class CommandExitedMessage(Message):
MSGTYPE = _make_MSGTYPE(7)
def __init__(self, return_code: int = None):
super().__init__()
self.return_code = return_code
def pack(self) -> bytes:
raw = umsgpack.packb(self.return_code)
return self.wrap_MSGTYPE(raw)
def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.return_code = umsgpack.unpackb(raw)
class Messenger(contextlib.AbstractContextManager):
@staticmethod
@ -208,29 +267,16 @@ class Messenger(contextlib.AbstractContextManager):
subclass_tuples.append((subclass.MSGTYPE, subclass))
return subclass_tuples
def __init__(self, receipt_checker: Callable[[_TReceipt], MessageState],
link_timeout_callback: Callable[[_TLink], None],
link_mdu_getter: Callable[[_TLink], int],
link_rtt_getter: Callable[[_TLink], float],
link_usable_getter: Callable[[_TLink], bool],
packet_sender: Callable[[_TLink, bytes], _TReceipt],
retry_delay_min: float = 10.0):
def __init__(self, retry_delay_min: float = 10.0):
self._log = module_logger.getChild(self.__class__.__name__)
self._receipt_checker = receipt_checker
self._link_timeout_callback = link_timeout_callback
self._link_mdu_getter = link_mdu_getter
self._link_rtt_getter = link_rtt_getter
self._link_usable_getter = link_usable_getter
self._packet_sender = packet_sender
self._sent_messages: list[Message] = []
self._lock = threading.RLock()
self._retry_timer = rnsh.retry.RetryThread()
self._message_factories = dict(self.__class__._get_msg_constructors())
self._inbound_queue = queue.Queue()
self._retry_delay_min = retry_delay_min
def __enter__(self):
pass
def __enter__(self) -> Messenger:
return self
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None:
@ -238,39 +284,38 @@ class Messenger(contextlib.AbstractContextManager):
return False
def shutdown(self):
self._run = False
self._retry_timer.close()
def inbound(self, raw: bytes):
def clear_retries(self, outlet):
self._retry_timer.complete(outlet)
def receive(self, raw: bytes) -> Message:
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
ctor = self._message_factories.get(mid, None)
if ctor is None:
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
message = ctor()
message.unpack(raw)
self._log.debug("Message received: {message}")
self._inbound_queue.put(message)
self._log.debug(f"Message received: {message}")
return message
def get_mdu(self, link: _TLink) -> int:
return self._link_mdu_getter(link) - 4
def get_rtt(self, link: _TLink) -> float:
return self._link_rtt_getter(link)
def is_link_ready(self, link: _TLink) -> bool:
if not self._link_usable_getter(link):
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
if not outlet.is_usuable:
self._log.debug("is_outlet_ready outlet unusable")
return False
with self._lock:
for message in self._sent_messages:
if message.link == link:
if message.outlet == outlet and message.tracked and message.receipt \
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
self._log.debug("is_outlet_ready pending message found")
return False
return True
def send_message(self, link: _TLink, message: Message):
def send(self, outlet: MessageOutletBase, message: Message):
with self._lock:
if not self.is_link_ready(link):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {link} not ready")
if not self.is_outlet_ready(outlet):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
if message in self._sent_messages:
raise MessagingException(METype.ME_ALREADY_SENT)
@ -279,43 +324,36 @@ class Messenger(contextlib.AbstractContextManager):
if not message.raw:
message.raw = message.pack()
message.link = link
message.outlet = outlet
def send(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else self._receipt_checker(message.receipt)
def send_inner(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
try:
self._log.debug(f"Sending packet for {message}")
message.receipt = self._packet_sender(link, message.raw)
message.receipt = outlet.send(message.raw)
except Exception as ex:
self._log.exception(f"Error sending message {message}")
elif state in [MessageState.MSGSTATE_SENT]:
self._log.debug(f"Retry skipped, message still pending {message}")
elif state in [MessageState.MSGSTATE_DELIVERED]:
latency = round(time.time() - message.ts, 1)
self._log.debug(f"Message delivered {message.msgid} after {tries-1} tries/{latency} seconds")
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
self._retry_timer.complete(link)
return link
self._retry_timer.complete(outlet)
return outlet
def timeout(tag: any, tries: int):
latency = round(time.time() - message.ts, 1)
msg = "delivered" if message.receipt and self._receipt_checker(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
self._link_timeout_callback(link)
rtt = self._link_rtt_getter(link)
self._retry_timer.begin(5, min(rtt * 100, max(rtt * 2, self._retry_delay_min)), send, timeout)
def poll_inbound(self, block: bool = True, timeout: float = None) -> Message | None:
try:
return self._inbound_queue.get(block=block, timeout=timeout)
except queue.Empty:
return None
outlet.timed_out()
rtt = outlet.rtt
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)

View file

@ -79,7 +79,7 @@ class RetryThread(AbstractContextManager):
self._lock = threading.RLock()
self._run = True
self._finished: asyncio.Future = None
self._thread = threading.Thread(name=name, target=self._thread_run)
self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
self._thread.start()
def is_alive(self):

View file

@ -26,10 +26,12 @@ from __future__ import annotations
import asyncio
import base64
import enum
import functools
import importlib.metadata
import logging as __logging
import os
import queue
import shlex
import signal
import sys
@ -44,10 +46,12 @@ import rnsh.process as process
import rnsh.retry as retry
import rnsh.rnslogging as rnslogging
import rnsh.hacks as hacks
import rnsh.session as session
import re
import contextlib
import rnsh.args
import pwd
import rnsh.protocol as protocol
module_logger = __logging.getLogger(__name__)
@ -141,13 +145,18 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.info(f"Using command {shlex.join(_cmd)}")
_no_remote_command = no_remote_command
session.ListenerSession.allow_remote_command = not no_remote_command
_remote_cmd_as_args = remote_cmd_as_args
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
and (_no_remote_command or _remote_cmd_as_args):
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
session.ListenerSession.default_command = _cmd
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
if disable_auth:
_allow_all = True
session.ListenerSession.allow_all = True
else:
if allowed is not None:
for a in allowed:
@ -161,6 +170,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
try:
destination_hash = bytes.fromhex(a)
_allowed_identity_hashes.append(destination_hash)
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
except Exception:
raise ValueError("Invalid destination entered. Check your input.")
except Exception as e:
@ -170,23 +180,9 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
if len(_allowed_identity_hashes) < 1 and not disable_auth:
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
_destination.set_link_established_callback(_listen_link_established)
if not _allow_all:
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes
)
else:
_destination.register_request_handler(
path="data",
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
# response_generator=_listen_request,
allow=RNS.Destination.ALLOW_ALL,
)
def link_established(lnk: RNS.Link):
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), _loop)
_destination.set_link_established_callback(link_established)
if await _check_finished():
return
@ -199,531 +195,22 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
last = time.time()
try:
while not await _check_finished(1.0):
while not await _check_finished():
if announce_period and 0 < announce_period < time.time() - last:
last = time.time()
_destination.announce()
await session.ListenerSession.pump_all()
await asyncio.sleep(0.01)
finally:
log.warning("Shutting down")
for link in list(_destination.links):
with exception.permit(SystemExit, KeyboardInterrupt):
proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running:
proc.process.terminate()
await asyncio.sleep(0)
await session.ListenerSession.terminate_all("Shutting down")
await asyncio.sleep(1)
session.ListenerSession.messenger.shutdown()
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status != RNS.Link.CLOSED:
if link.status not in [RNS.Link.CLOSED]:
link.teardown()
await asyncio.sleep(0)
_PROTOCOL_MAGIC = 0xdeadbeef
def _protocol_make_version(version: int):
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
_PROTOCOL_VERSION_2 = _protocol_make_version(2)
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
def _protocol_check_magic(value: int):
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
if version >= _PROTOCOL_VERSION_2:
return link_mdu - 64 # TODO: tune
else:
return link_mdu // 2
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
if version >= _PROTOCOL_VERSION_2:
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
else:
return link_mdu // 2
def _bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def _check_and(value: int, andval: int) -> bool:
return (value & andval) > 0
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock()
@classmethod
def get_for_tag(cls, tag: any) -> Session | None:
with cls._lock:
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
@classmethod
def put_for_tag(cls, tag: any, ps: Session):
with cls._lock:
cls.clear_tag(tag)
cls._processes.append((tag, ps))
@classmethod
def clear_tag(cls, tag: any):
with cls._lock:
with exception.permit(SystemExit):
cls._processes.remove(tag)
def __init__(self,
tag: any,
cmd: [str],
data_available_callback: callable,
terminated_callback: callable,
session_flags: int,
term: str | None,
remote_identity: str | None,
loop: asyncio.AbstractEventLoop = None):
self._log = _get_logger(self.__class__.__name__)
self._loop = loop if loop is not None else asyncio.get_running_loop()
self._process = process.CallbackSubprocess(argv=cmd,
env={"TERM": term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": remote_identity or ""},
loop=loop,
stdout_callback=self._stdout_data,
stderr_callback=self._stderr_data,
terminated_callback=terminated_callback,
stdin_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDIN),
stdout_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDOUT),
stderr_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDERR))
self._log.debug(f"Starting {cmd}")
self._stdout_buffer = bytearray()
self._stderr_buffer = bytearray()
self._lock = threading.RLock()
self._data_available_cb = data_available_callback
self._terminated_cb = terminated_callback
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] = None
self._session_flags: int = session_flags
Session.put_for_tag(tag, self)
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
return self._pending_receipt
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
with self._lock:
val = self._pending_receipt
self._pending_receipt = None
return val
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
with self._lock:
self._pending_receipt = receipt
@property
def process(self) -> process.CallbackSubprocess:
return self._process
@property
def return_code(self) -> int | None:
return self.process.return_code
@property
def lock(self) -> threading.RLock:
return self._lock
def read_stdout(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._stdout_buffer)
take = self._stdout_buffer[:count]
self._stdout_buffer = self._stdout_buffer[count:]
self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
return take
def _stdout_data(self, data: bytes):
with self.lock:
self._stdout_buffer.extend(data)
total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
def read_stderr(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._stderr_buffer)
take = self._stderr_buffer[:count]
self._stderr_buffer = self._stderr_buffer[count:]
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
return take
def _stderr_data(self, data: bytes):
with self.lock:
self._stderr_buffer.extend(data)
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
TERMSTATE_IDX_TERM = 0
TERMSTATE_IDX_TIOS = 1
TERMSTATE_IDX_ROWS = 2
TERMSTATE_IDX_COLS = 3
TERMSTATE_IDX_HPIX = 4
TERMSTATE_IDX_VPIX = 5
def _update_winsz(self):
try:
self.process.set_winsize(self._term_state[1],
self._term_state[2],
self._term_state[3],
self._term_state[4])
except Exception as e:
self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_VERS = 0
REQUEST_IDX_STDIN = 1
REQUEST_IDX_TERM = 2
REQUEST_IDX_TIOS = 3
REQUEST_IDX_ROWS = 4
REQUEST_IDX_COLS = 5
REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7
REQUEST_IDX_CMD = 8
REQUEST_IDX_FLAGS = 9
REQUEST_IDX_BYTES_AVAILABLE = 10
REQUEST_FLAGS_PIPE_STDIN = 0x01
REQUEST_FLAGS_PIPE_STDOUT = 0x02
REQUEST_FLAGS_PIPE_STDERR = 0x04
REQUEST_FLAGS_EOF_STDIN = 0x08
@staticmethod
def default_request() -> [any]:
global _tr
request: list[any] = [
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
None, # 8 Command to run
0, # 9 Flags
0, # 10 Bytes Available
].copy()
if os.isatty(0):
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
with contextlib.suppress(OSError):
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
Session.REQUEST_FLAGS_PIPE_STDIN)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
Session.REQUEST_FLAGS_PIPE_STDOUT)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
Session.REQUEST_FLAGS_PIPE_STDERR)
return request
def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
# cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
flags = data[Session.REQUEST_IDX_FLAGS]
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
response = Session.default_response()
first_term_state = self._term_state is None
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
if self.process.running:
if term_state != self._term_state:
self._term_state = term_state
if term_state is not None:
self._update_winsz()
if first_term_state is not None:
# TODO: use a more specific error
with contextlib.suppress(Exception):
self.process.tcsetattr(termios.TCSADRAIN, term_state[0])
if stdin is not None and len(stdin) > 0:
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
if stdin_eof and bytes_available == 0:
module_logger.debug("Closing stdin")
with contextlib.suppress(Exception):
self.process.close_stdin()
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock:
#prioritizing stderr
stderr = self.read_stderr(read_size)
stdout = self.read_stdout(read_size - len(stderr))
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
if stderr is not None and len(stderr) > 0:
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
if stdout is not None and len(stdout) > 0:
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
return response
RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_FLAGS = 1
RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_STDERR = 4
RESPONSE_IDX_STDOUT = 5
RESPONSE_IDX_TMSTAMP = 6
RESPONSE_FLAGS_RUNNING = 0x01
RESPONSE_FLAGS_EOF_STDOUT = 0x02
RESPONSE_FLAGS_EOF_STDERR = 0x04
@staticmethod
def default_response() -> [any]:
response: list[any] = [
_PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stderr
None, # 5: Stdout
None, # 6: Timestamp
].copy()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response
@classmethod
def error_response(cls, msg: str) -> [any]:
response = cls.default_response()
msg_bytes = f"{msg}\r\n".encode("utf-8")
response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer
log = _get_logger("_subproc_data_ready")
session: Session = Session.get_for_tag(link.link_id)
def send(timeout: bool, tag: any, tries: int) -> any:
# log.debug("send")
def inner():
# log.debug("inner")
try:
if link.status != RNS.Link.ACTIVE:
_retry_timer.complete(link.link_id)
session.pending_receipt_take()
return
pr = session.pending_receipt_take()
log.debug(f"send inner pr: {pr}")
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
if not timeout:
_retry_timer.complete(tag)
log.debug(f"Notification completed with status {pr.status} on link {link}")
return
else:
if not timeout:
log.debug(
f"Notifying client try {tries} (retcode: {session.return_code} " +
f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
session.pending_receipt_put(pr)
else:
log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id)
link.teardown()
except Exception as e:
log.error("Error notifying client: " + str(e))
_loop.call_soon_threadsafe(inner)
return link.link_id
with session.lock:
if not _retry_timer.has_tag(link.link_id):
_retry_timer.begin(try_limit=15,
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
try_callback=functools.partial(send, False),
timeout_callback=functools.partial(send, True))
else:
log.debug(f"Notification already pending for link {link}")
def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop
log = _get_logger("_subproc_terminated")
log.info(f"Subprocess returned {return_code} for link {link}")
proc = Session.get_for_tag(link.link_id)
if proc is None:
log.debug(f"no proc for link {link}")
return
def cleanup():
def inner():
log.debug(f"cleanup culled link {link}")
if link and link.status != RNS.Link.CLOSED:
with exception.permit(SystemExit):
try:
link.teardown()
finally:
Session.clear_tag(link.link_id)
_loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0)
_loop.call_soon_threadsafe(cleanup)
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
log = _get_logger("_listen_start_proc")
try:
return Session(tag=link.link_id,
cmd=cmd,
session_flags=session_flags,
term=term,
remote_identity=remote_identity,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
except Exception as e:
log.error("Failed to launch process: " + str(e))
_subproc_terminated(link, 255)
return None
def _listen_link_established(link):
global _allow_all
log = _get_logger("_listen_link_established")
link.set_remote_identified_callback(_initiator_identified)
link.set_link_closed_callback(_listen_link_closed)
log.info("Link " + str(link) + " established")
def _listen_link_closed(link: RNS.Link):
log = _get_logger("_listen_link_closed")
# async def cleanup():
log.info("Link " + str(link) + " closed")
proc: Session | None = Session.get_for_tag(link.link_id)
if proc is None:
log.warning(f"No process for link {link}")
else:
try:
proc.process.terminate()
_retry_timer.complete(link.link_id)
except Exception as e:
log.error(f"Error closing process for link {link}: {e}")
Session.clear_tag(link.link_id)
def _initiator_identified(link, identity):
global _allow_all, _cmd, _loop
log = _get_logger("_initiator_identified")
log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash))
if not _allow_all and identity.hash not in _allowed_identity_hashes:
log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer, _loop, _cmd, _no_remote_command
log = _get_logger("_listen_request")
log.debug(
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None:
log.error(f"Invalid request {request_id}, no link found with id {link_id}")
return None
remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version):
log.error("Request magic incorrect")
link.teardown()
return None
if not remote_version <= _PROTOCOL_VERSION_3:
return Session.error_response("Listener<->initiator version mismatch")
cmd = _cmd.copy()
remote_command = data[Session.REQUEST_IDX_CMD]
if remote_command is not None and len(remote_command) > 0:
if _no_remote_command:
return Session.error_response("Listener does not permit initiator to provide command.")
elif _remote_cmd_as_args:
cmd.extend(remote_command)
else:
cmd = remote_command
if not _no_remote_command and (cmd is None or len(cmd) == 0):
return Session.error_response("No command supplied and no default command available.")
session: Session | None = None
try:
term = data[Session.REQUEST_IDX_TERM]
# sanitize
if term is not None:
term = re.sub('[^A-Za-z-0-9\-\_]','', term)
session = Session.get_for_tag(link.link_id)
if session is None:
log.debug(f"Process not found for link {link}")
session = _listen_start_proc(link=link,
term=term,
session_flags=data[Session.REQUEST_IDX_FLAGS],
cmd=cmd,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
if session is None:
return Session.error_response("Unable to start subprocess")
# leave significant headroom for metadata and encoding
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
return result
# return ProcessState.default_response()
except Exception as e:
log.error(f"Error procesing request for link {link}: {e}")
try:
if session is not None and session.process.running:
session.process.terminate()
except Exception as ee:
log.debug(f"Error terminating process for link {link}: {ee}")
return Session.default_response()
await asyncio.sleep(0.01)
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
@ -731,7 +218,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
timeout += time.time()
while (timeout is None or time.time() < timeout) and not until():
if await _check_finished(0.001):
if await _check_finished(0.01):
raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout:
return False
@ -744,15 +231,28 @@ _remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None
_tr: process.TTYRestorer | None = None
_pq = queue.Queue()
class InitiatorState(enum.IntEnum):
IS_INITIAL = 0
IS_LINKED = 1
IS_WAIT_VERS = 2
IS_RUNNING = 3
IS_TERMINATE = 4
IS_TEARDOWN = 5
def _client_link_closed(link):
log = _get_logger("_client_link_closed")
_finished.set()
def _client_packet_handler(message, packet):
global _new_data
log = _get_logger("_client_packet_handler")
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG and _new_data is not None:
log.debug("data available")
_new_data.set()
else:
log.error(f"received unhandled packet")
packet.prove()
_pq.put(message)
class RemoteExecutionError(Exception):
@ -760,15 +260,10 @@ class RemoteExecutionError(Exception):
self.msg = msg
def _response_handler(request_receipt: RNS.RequestReceipt):
pass
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
service_name="default", timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
log = _get_logger("_execute")
log = _get_logger("_initiate_link")
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(destination) != dest_len:
@ -809,6 +304,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link = RNS.Link(_destination)
_link.did_identify = False
_link.set_link_closed_callback(_client_link_closed)
log.info(f"Establishing link...")
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
@ -820,110 +317,6 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler)
request = Session.default_request()
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
# log.debug(f"Sending {stdin} to listener")
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
request[Session.REQUEST_IDX_CMD] = cmd
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
Session.REQUEST_FLAGS_EOF_STDIN)
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
log.debug("Sending request")
request_receipt = _link.request(
path="data",
data=request,
timeout=timeout
)
timeout += 0.5
log.debug("Waiting for delivery")
await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and
request_receipt.status != RNS.RequestReceipt.SENT),
timeout=timeout
)
if _link.status == RNS.Link.CLOSED:
raise RemoteExecutionError("Could not request remote execution, link was closed")
if request_receipt.status == RNS.RequestReceipt.FAILED:
raise RemoteExecutionError("Could not request remote execution")
await _spin(
until=lambda: request_receipt.status != RNS.RequestReceipt.DELIVERED,
timeout=timeout
)
if request_receipt.status == RNS.RequestReceipt.FAILED:
raise RemoteExecutionError("No result was received")
if request_receipt.status == RNS.RequestReceipt.FAILED:
raise RemoteExecutionError("Receiving result failed")
if request_receipt.response is not None:
try:
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_3:
raise RemoteExecutionError("Protocol version mismatch")
flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
# if stdout is not None:
# stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except RemoteExecutionError:
raise
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
if stdout is not None:
_tr.raw()
log.debug(f"stdout: {stdout}")
os.write(1, stdout)
sys.stdout.flush()
if stderr is not None:
_tr.raw()
log.debug(f"stderr: {stderr}")
os.write(2, stderr)
sys.stderr.flush()
if stderr_eof and ready_bytes == 0:
log.debug("Closing stderr")
os.close(2)
if stdout_eof and ready_bytes == 0:
log.debug("Closing stdout")
os.close(1)
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0:
_new_data.set()
if (not running or return_code is not None) and (ready_bytes == 0):
log.debug(f"returning running: {running}, return_code: {return_code}")
return return_code or 255
return None
_pre_input = bytearray()
async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
service_name: str, timeout: float, command: [str] | None = None):
@ -931,13 +324,52 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
log = _get_logger("_initiate")
loop = asyncio.get_running_loop()
_new_data = asyncio.Event()
state = InitiatorState.IS_INITIAL
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
await _initiate_link(
configdir=configdir,
identitypath=identitypath,
verbosity=verbosity,
quietness=quietness,
noid=noid,
destination=destination,
service_name=service_name,
timeout=timeout,
)
if not _link or _link.status != RNS.Link.ACTIVE:
_finished.set()
return 255
state = InitiatorState.IS_LINKED
outlet = session.RNSOutlet(_link)
with protocol.Messenger(retry_delay_min=5) as messenger:
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# print("Error bringing up link")
# return 253
messenger.send(outlet, protocol.VersionInfoMessage())
try:
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
vm = messenger.receive(vp)
if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received")
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
state = InitiatorState.IS_RUNNING
except queue.Empty:
print("Protocol error")
return 254
winch = False
def sigwinch_handler():
nonlocal winch
# log.debug("WindowChanged")
if _new_data is not None:
_new_data.set()
winch = True
# if _new_data is not None:
# _new_data.set()
stdin_eof = False
def stdin():
@ -957,61 +389,98 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
await _check_finished()
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
# leave a lot of overhead
mdu = 64
rtt = 5
first_loop = True
cmdline = " ".join(command or [])
while not await _check_finished():
tcattr = None
rows, cols, hpix, vpix = (None, None, None, None)
try:
log.debug("top of client loop")
stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:]
_new_data.clear()
log.debug("before _execute")
return_code = await _execute(
configdir=configdir,
identitypath=identitypath,
verbosity=verbosity,
quietness=quietness,
noid=noid,
destination=destination,
service_name=service_name,
stdin=stdin,
timeout=timeout,
cmd=command,
stdin_eof=stdin_eof,
bytes_available=len(data_buffer)
)
tcattr = termios.tcgetattr(0)
rows, cols, hpix, vpix = process.tty_get_winsize(0)
except:
try:
tcattr = termios.tcgetattr(1)
rows, cols, hpix, vpix = process.tty_get_winsize(1)
except:
try:
tcattr = termios.tcgetattr(2)
rows, cols, hpix, vpix = process.tty_get_winsize(2)
except:
pass
if first_loop:
first_loop = False
mdu = _protocol_request_chars_take(_link.MDU,
_PROTOCOL_VERSION_DEFAULT,
os.environ.get("TERM", ""),
cmdline)
_new_data.set()
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1),
pipe_stderr=not os.isatty(2),
tcflags=tcattr,
term=os.environ.get("TERM", None),
rows=rows,
cols=cols,
hpix=hpix,
vpix=vpix))
if _link:
rtt = _link.rtt
if return_code is not None:
log.debug(f"received return code {return_code}, exiting")
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
mdu = _link.MDU - 16
sent_eof = False
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
try:
try:
packet = _pq.get_nowait()
message = messenger.receive(packet)
if isinstance(message, protocol.StreamDataMessage):
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
if message.data and len(message.data) > 0:
_tr.raw()
log.debug(f"stdout: {message.data}")
os.write(1, message.data)
sys.stdout.flush()
if message.eof:
os.close(1)
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
if message.data and len(message.data) > 0:
_tr.raw()
log.debug(f"stdout: {message.data}")
os.write(2, message.data)
sys.stderr.flush()
if message.eof:
os.close(2)
elif isinstance(message, protocol.CommandExitedMessage):
log.debug(f"received return code {message.return_code}, exiting")
with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown()
return return_code
except asyncio.CancelledError:
if _link and _link.status != RNS.Link.CLOSED:
return message.return_code
elif isinstance(message, protocol.ErrorMessage):
log.error(message.data)
if message.fatal:
_link.teardown()
return 0
return 200
except queue.Empty:
pass
if messenger.is_outlet_ready(outlet):
stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:]
eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof:
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
stdin, eof))
sent_eof = eof
except RemoteExecutionError as e:
print(e.msg)
return 255
except Exception as ex:
print(f"Client exception: {ex}")
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 127
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
await asyncio.sleep(0.01)
log.debug("after main loop")
return 0

423
rnsh/session.py Normal file
View file

@ -0,0 +1,423 @@
from __future__ import annotations
import contextlib
import functools
import threading
import rnsh.exception as exception
import asyncio
import rnsh.process as process
import rnsh.helpers as helpers
import rnsh.protocol as protocol
import enum
from typing import TypeVar, Generic, Callable, List
from abc import abstractmethod, ABC
from multiprocessing import Manager
import os
import RNS
import logging as __logging
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
module_logger = __logging.getLogger(__name__)
_TLink = TypeVar("_TLink")
class SEType(enum.IntEnum):
SE_LINK_CLOSED = 0
class SessionException(Exception):
def __init__(self, setype: SEType, msg: str, *args):
super().__init__(msg, args)
self.type = setype
class LSState(enum.IntEnum):
LSSTATE_WAIT_IDENT = 1
LSSTATE_WAIT_VERS = 2
LSSTATE_WAIT_CMD = 3
LSSTATE_RUNNING = 4
LSSTATE_ERROR = 5
LSSTATE_TEARDOWN = 6
_TIdentity = TypeVar("_TIdentity")
class LSOutletBase(protocol.MessageOutletBase):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
raise NotImplemented()
@abstractmethod
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
raise NotImplemented()
@abstractmethod
def unset_link_closed_callback(self):
raise NotImplemented()
@abstractmethod
def teardown(self):
raise NotImplemented()
@abstractmethod
def __init__(self):
raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
allowed_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._log.info(f"Session started for {outlet}")
self.outlet = outlet
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
self.state: LSState = None
self.remote_identity = None
self.term: str | None = None
self.stdin_is_pipe: bool = False
self.stdout_is_pipe: bool = False
self.stderr_is_pipe: bool = False
self.tcflags: [any] = None
self.cmdline: [str] = None
self.rows: int = 0
self.cols: int = 0
self.hpix: int = 0
self.vpix: int = 0
self.stdout_buf = bytearray()
self.stdout_eof_sent = False
self.stderr_buf = bytearray()
self.stderr_eof_sent = False
self.return_code: int | None = None
self.return_code_sent = False
self.process: process.CallbackSubprocess | None = None
self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
def _terminated(self, return_code: int):
self.return_code = return_code
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
self._log.debug(f"Set state: {state.name}, timeout {timeout}")
orig_state = self.state
self.state = state
if timeout_factor is not None:
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
def _call(self, func: callable, delay: float = 0):
def call_inner():
# self._log.debug("call_inner")
if delay == 0:
func()
else:
self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: protocol.Message):
self.messenger.send(self.outlet, message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
def _protocol_timeout_error(self, name: str):
self.terminate(f"Protocol timeout error: {name}")
def terminate(self, error: str = None):
with contextlib.suppress(Exception):
self._log.debug("Terminating session" + (f": {error}" if error else ""))
if error and self.state != LSState.LSSTATE_TEARDOWN:
with contextlib.suppress(Exception):
self.send(protocol.ErrorMessage(error, True))
self.state = LSState.LSSTATE_ERROR
self._terminate_process()
self._call(self._prune, max(self.outlet.rtt * 3, 5))
def _prune(self):
self.state = LSState.LSSTATE_TEARDOWN
with contextlib.suppress(ValueError):
self.sessions.remove(self)
with contextlib.suppress(Exception):
self.outlet.teardown()
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
timeout = True
try:
timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
except Exception as ex:
self._log.exception("Error in protocol timeout", ex)
if timeout:
self._protocol_timeout_error(name)
def _link_closed(self, outlet: LSOutletBase):
outlet.unset_link_closed_callback()
if outlet != self.outlet:
self._log.debug("Link closed received from incorrect outlet")
return
self._log.debug(f"link_closed {outlet}")
self.messenger.clear_retries(self.outlet)
self.terminate()
def _initiator_identified(self, outlet, identity):
if outlet != self.outlet:
self._log.debug("Identity received from incorrect outlet")
return
self._log.info(f"initiator_identified {identity} on link {outlet}")
if self.state != LSState.LSSTATE_WAIT_IDENT:
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
self.terminate("Identity is not allowed.")
self.remote_identity = identity
self.outlet.set_packet_received_callback(self._packet_received)
self._set_state(LSState.LSSTATE_WAIT_VERS)
@classmethod
async def pump_all(cls):
for session in cls.sessions:
session.pump()
await asyncio.sleep(0)
@classmethod
async def terminate_all(cls, reason: str):
for session in cls.sessions:
session.terminate(reason)
await asyncio.sleep(0)
def pump(self):
try:
if self.state != LSState.LSSTATE_RUNNING:
return
elif not self.messenger.is_outlet_ready(self.outlet):
return
elif len(self.stderr_buf) > 0:
mdu = self.outlet.mdu - 16
data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
data, send_eof)
self.send(msg)
if send_eof:
self.stderr_eof_sent = True
elif len(self.stdout_buf) > 0:
mdu = self.outlet.mdu - 16
data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
data, send_eof)
self.send(msg)
elif self.return_code is not None and not self.return_code_sent:
msg = protocol.CommandExitedMessage(self.return_code)
self.send(msg)
self.return_code_sent = True
self._call(functools.partial(self._check_protocol_timeout,
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
max(self.outlet.rtt * 5, 10))
except Exception as ex:
self._log.exception("Error during pump", ex)
def _terminate_process(self):
with contextlib.suppress(Exception):
if self.process and self.process.running:
self.process.terminate()
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
self.cmdline = self.default_command
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
self.terminate("Remote command line not allowed by listener")
return
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
self.cmdline.extend(cmdline)
elif cmdline and len(cmdline) > 0:
self.cmdline = cmdline
self.stdin_is_pipe = pipe_stdin
self.stdout_is_pipe = pipe_stdout
self.stderr_is_pipe = pipe_stderr
self.tcflags = tcflags
self.term = term
def stdout(data: bytes):
self.stdout_buf.extend(data)
def stderr(data: bytes):
self.stderr_buf.extend(data)
try:
self.process = process.CallbackSubprocess(argv=self.cmdline,
env={"TERM": self.term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
loop=self.loop,
stdout_callback=stdout,
stderr_callback=stderr,
terminated_callback=self._terminated,
stdin_is_pipe=self.stdin_is_pipe,
stdout_is_pipe=self.stdout_is_pipe,
stderr_is_pipe=self.stderr_is_pipe)
self.process.start()
self._set_window_size(rows, cols, hpix, vpix)
except Exception as ex:
self._log.exception(f"Unable to start process for link {self.outlet}", ex)
self.terminate("Unable to start process")
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
with contextlib.suppress(Exception):
self.process.set_winsize(rows, cols, hpix, vpix)
def _received_stdin(self, data: bytes, eof: bool):
if data and len(data) > 0:
self.process.write(data)
if eof:
self.process.close_stdin()
def _handle_message(self, message: protocol.Message):
if self.state == LSState.LSSTATE_WAIT_VERS:
if not isinstance(message, protocol.VersionInfoMessage):
self._protocol_error(self.state.name)
return
self._log.info(f"version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}")
if message.protocol_version != protocol.PROTOCOL_VERSION:
self.terminate("Incompatible protocol")
return
self.send(protocol.VersionInfoMessage())
self._set_state(LSState.LSSTATE_WAIT_CMD)
return
elif self.state == LSState.LSSTATE_WAIT_CMD:
if not isinstance(message, protocol.ExecuteCommandMesssage):
return self._protocol_error(self.state.name)
self._log.info(f"Execute command message on link {self.outlet}: {message.cmdline}")
self._set_state(LSState.LSSTATE_RUNNING)
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
return
elif self.state == LSState.LSSTATE_RUNNING:
if isinstance(message, protocol.WindowSizeMessage):
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
elif isinstance(message, protocol.StreamDataMessage):
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
self._log.error(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}")
return self._protocol_error(self.state.name)
self._received_stdin(message.data, message.eof)
return
elif isinstance(message, protocol.NoopMessage):
# echo noop only on listener--used for keepalive/connectivity check
self.send(message)
return
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
self._log.error(f"Received packet, but in state {self.state.name}")
return
else:
self._protocol_error("unexpected message")
return
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
if outlet != self.outlet:
self._log.debug("Packet received from incorrect outlet")
return
try:
message = self.messenger.receive(raw)
self._handle_message(message)
except Exception as ex:
self._protocol_error("unusable packet")
class RNSOutlet(LSOutletBase):
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
def inner_cb(link, identity: _TIdentity):
cb(self, identity)
self.link.set_remote_identified_callback(inner_cb)
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
def inner_cb(link):
cb(self)
self.link.set_link_closed_callback(inner_cb)
def unset_link_closed_callback(self):
self.link.set_link_closed_callback(None)
def teardown(self):
self.link.teardown()
def send(self, raw: bytes) -> RNS.PacketReceipt:
packet = RNS.Packet(self.link, raw)
packet.send()
return packet.receipt
@property
def mdu(self) -> int:
return self.link.MDU
@property
def rtt(self) -> float:
return self.link.rtt
@property
def is_usuable(self):
return True #self.link.status in [RNS.Link.ACTIVE]
def get_receipt_state(self, receipt: RNS.PacketReceipt) -> MessageState:
status = receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return protocol.MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return protocol.MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return protocol.MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self):
return f"Outlet RNS Link {self.link}"
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
def inner_cb(message, packet: RNS.Packet):
packet.prove()
cb(self, message)
self.link.set_packet_callback(inner_cb)
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
link.msgoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):
return link.lsoutlet
return RNSOutlet(link)

View file

@ -1,6 +1,10 @@
from __future__ import annotations
import logging
from rnsh.protocol import _TReceipt, MessageState
from typing import Callable
logging.getLogger().setLevel(logging.DEBUG)
import rnsh.protocol
@ -14,64 +18,60 @@ import uuid
module_logger = logging.getLogger(__name__)
class Link:
class Receipt:
def __init__(self, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
class MessageOutletTest(rnsh.protocol.MessageOutletBase):
def __init__(self, mdu: int, rtt: float):
self.link_id = uuid.uuid4()
self.timeout_callbacks = 0
self.mdu = mdu
self.rtt = rtt
self.usable = True
self._mdu = mdu
self._rtt = rtt
self._usable = True
self.receipts = []
self.packet_callback: Callable[[rnsh.protocol.MessageOutletBase, bytes], None] | None = None
def timeout_callback(self):
def send(self, raw: bytes) -> Receipt:
receipt = Receipt(rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
self.receipts.append(receipt)
return receipt
def set_packet_received_callback(self, cb: Callable[[rnsh.protocol.MessageOutletBase, bytes], None]):
self.packet_callback = cb
def receive(self, raw: bytes):
if self.packet_callback:
self.packet_callback(self, raw)
@property
def mdu(self):
return self._mdu
@property
def rtt(self):
return self._rtt
@property
def is_usuable(self):
return self._usable
def get_receipt_state(self, receipt: Receipt) -> MessageState:
return receipt.state
def timed_out(self):
self.timeout_callbacks += 1
def __str__(self):
return str(self.link_id)
class Receipt:
def __init__(self, link: Link, state: rnsh.protocol.MessageState, raw: bytes):
self.state = state
self.raw = raw
self.link = link
class ProtocolHarness(contextlib.AbstractContextManager):
def __init__(self, retry_delay_min: float = 1):
self._log = module_logger.getChild(self.__class__.__name__)
self.messenger = rnsh.protocol.Messenger(receipt_checker=self.receipt_checker,
link_timeout_callback=self.link_timeout_callback,
link_mdu_getter=self.link_mdu_getter,
link_rtt_getter=self.link_rtt_getter,
link_usable_getter=self.link_usable_getter,
packet_sender=self.packet_sender,
retry_delay_min=retry_delay_min)
def packet_sender(self, link: Link, raw: bytes) -> Receipt:
receipt = Receipt(link, rnsh.protocol.MessageState.MSGSTATE_SENT, raw)
link.receipts.append(receipt)
return receipt
@staticmethod
def link_mdu_getter(link: Link):
return link.mdu
@staticmethod
def link_rtt_getter(link: Link):
return link.rtt
@staticmethod
def link_usable_getter(link: Link):
return link.usable
@staticmethod
def receipt_checker(receipt: Receipt) -> rnsh.protocol.MessageState:
return receipt.state
@staticmethod
def link_timeout_callback(link: Link):
link.timeout_callback()
self.messenger = rnsh.protocol.Messenger(retry_delay_min=retry_delay_min)
def cleanup(self):
self.messenger.shutdown()
@ -83,49 +83,58 @@ class ProtocolHarness(contextlib.AbstractContextManager):
return False
def test_mdu():
with ProtocolHarness() as h:
mdu = 500
link = Link(mdu=mdu, rtt=0.25)
assert h.messenger.get_mdu(link) == mdu - 4
link.mdu = mdu = 600
assert h.messenger.get_mdu(link) == mdu - 4
def test_rtt():
with ProtocolHarness() as h:
rtt = 0.25
link = Link(mdu=500, rtt=rtt)
assert h.messenger.get_rtt(link) == rtt
def test_send_one_retry():
rtt = 0.001
retry_interval = rtt * 150
message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt)
outlet = MessageOutletTest(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True)
assert len(link.receipts) == 0
h.messenger.send_message(link, message)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked
assert message.raw is not None
assert len(link.receipts) == 1
receipt = link.receipts[0]
assert len(outlet.receipts) == 1
receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5)
assert len(link.receipts) == 1
assert len(outlet.receipts) == 1
receipt.state = rnsh.protocol.MessageState.MSGSTATE_FAILED
module_logger.info("set failed")
time.sleep(retry_interval)
assert len(link.receipts) == 2
receipt = link.receipts[1]
assert len(outlet.receipts) == 2
receipt = outlet.receipts[1]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval)
assert len(link.receipts) == 2
assert len(outlet.receipts) == 2
assert not message.tracked
def test_send_timeout():
rtt = 0.001
retry_interval = rtt * 150
message_content = b'Test'
with ProtocolHarness(retry_delay_min=retry_interval) as h:
outlet = MessageOutletTest(mdu=500, rtt=rtt)
message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN,
data=message_content, eof=True)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked
assert message.raw is not None
assert len(outlet.receipts) == 1
receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
time.sleep(retry_interval * 1.5)
assert outlet.timeout_callbacks == 0
time.sleep(retry_interval * 7)
assert len(outlet.receipts) == 1
assert outlet.timeout_callbacks == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert not message.tracked
@ -133,24 +142,32 @@ def eat_own_dog_food(message: rnsh.protocol.Message, checker: typing.Callable[[r
rtt = 0.001
retry_interval = rtt * 150
with ProtocolHarness(retry_delay_min=retry_interval) as h:
link = Link(mdu=500, rtt=rtt)
assert len(link.receipts) == 0
h.messenger.send_message(link, message)
decoded: [rnsh.protocol.Message] = []
def packet(outlet, buffer):
decoded.append(h.messenger.receive(buffer))
outlet = MessageOutletTest(mdu=500, rtt=rtt)
outlet.set_packet_received_callback(packet)
assert len(outlet.receipts) == 0
h.messenger.send(outlet, message)
assert message.tracked
assert message.raw is not None
assert len(link.receipts) == 1
receipt = link.receipts[0]
assert len(outlet.receipts) == 1
receipt = outlet.receipts[0]
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_SENT
assert receipt.raw == message.raw
module_logger.info("set delivered")
receipt.state = rnsh.protocol.MessageState.MSGSTATE_DELIVERED
time.sleep(retry_interval * 2)
assert len(link.receipts) == 1
assert len(outlet.receipts) == 1
assert receipt.state == rnsh.protocol.MessageState.MSGSTATE_DELIVERED
assert not message.tracked
module_logger.info("injecting rx message")
h.messenger.inbound(message.raw)
rx_message = h.messenger.poll_inbound(block=False)
assert len(decoded) == 0
outlet.receive(message.raw)
assert len(decoded) == 1
rx_message = decoded[0]
assert rx_message is not None
assert isinstance(rx_message, message.__class__)
assert rx_message.msgid != message.msgid
@ -238,6 +255,16 @@ def test_send_receive_error():
eat_own_dog_food(message, check)
def test_send_receive_cmdexit():
message = rnsh.protocol.CommandExitedMessage(5)
def check(rx_message: rnsh.protocol.Message):
assert isinstance(rx_message, message.__class__)
assert rx_message.return_code == message.return_code
eat_own_dog_food(message, check)

View file

@ -12,19 +12,6 @@ import re
import os
def test_check_magic():
magic = rnsh.rnsh._PROTOCOL_VERSION_0
# magic for version 0 is generated, make sure it comes out as expected
assert magic == 0xdeadbeef00000000
# verify the checker thinks it's right
assert rnsh.rnsh._protocol_check_magic(magic)
# scramble the magic
magic = magic | 0x00ffff0000000000
# make sure it fails now
assert not rnsh.rnsh._protocol_check_magic(magic)
def test_version():
# version = importlib.metadata.version(rnsh.__version__)
assert rnsh.__version__ != "0.0.0"