mirror of
https://github.com/markqvist/rnsh.git
synced 2025-06-21 04:24:19 -04:00
Switch to RNS-provided Channel
This commit is contained in:
parent
b6a22cd2a7
commit
5bca575a4b
6 changed files with 229 additions and 482 deletions
|
@ -9,7 +9,7 @@ readme = "README.md"
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
docopt = "^0.6.2"
|
docopt = "^0.6.2"
|
||||||
rns = "^0.4.9"
|
rns = { git = "https://github.com/acehoss/Reticulum.git", branch = "feature/channel" } #{ path = "../Reticulum/", develop = true } #
|
||||||
tomli = "^2.0.1"
|
tomli = "^2.0.1"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
|
|
@ -143,12 +143,12 @@ class InitiatorState(enum.IntEnum):
|
||||||
|
|
||||||
def _client_link_closed(link):
|
def _client_link_closed(link):
|
||||||
log = _get_logger("_client_link_closed")
|
log = _get_logger("_client_link_closed")
|
||||||
|
if _finished:
|
||||||
_finished.set()
|
_finished.set()
|
||||||
|
|
||||||
|
|
||||||
def _client_packet_handler(message, packet):
|
def _client_message_handler(message: RNS.MessageBase):
|
||||||
log = _get_logger("_client_packet_handler")
|
log = _get_logger("_client_message_handler")
|
||||||
packet.prove()
|
|
||||||
_pq.put(message)
|
_pq.put(message)
|
||||||
|
|
||||||
|
|
||||||
|
@ -213,10 +213,8 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
|
||||||
_link.identify(_identity)
|
_link.identify(_identity)
|
||||||
_link.did_identify = True
|
_link.did_identify = True
|
||||||
|
|
||||||
_link.set_packet_callback(_client_packet_handler)
|
|
||||||
|
|
||||||
|
async def _handle_error(errmsg: RNS.MessageBase):
|
||||||
async def _handle_error(errmsg: protocol.Message):
|
|
||||||
if isinstance(errmsg, protocol.ErrorMessage):
|
if isinstance(errmsg, protocol.ErrorMessage):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
if _link and _link.status == RNS.Link.ACTIVE:
|
if _link and _link.status == RNS.Link.ACTIVE:
|
||||||
|
@ -249,17 +247,18 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||||
|
|
||||||
state = InitiatorState.IS_LINKED
|
state = InitiatorState.IS_LINKED
|
||||||
outlet = session.RNSOutlet(_link)
|
outlet = session.RNSOutlet(_link)
|
||||||
with protocol.Messenger(retry_delay_min=5) as messenger:
|
channel = _link.get_channel()
|
||||||
|
protocol.register_message_types(channel)
|
||||||
|
channel.add_message_handler(_client_message_handler)
|
||||||
|
|
||||||
# Next step after linking and identifying: send version
|
# Next step after linking and identifying: send version
|
||||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
||||||
# print("Error bringing up link")
|
# print("Error bringing up link")
|
||||||
# return 253
|
# return 253
|
||||||
|
|
||||||
messenger.send(outlet, protocol.VersionInfoMessage())
|
channel.send(protocol.VersionInfoMessage())
|
||||||
try:
|
try:
|
||||||
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
||||||
vm = messenger.receive(vp)
|
|
||||||
await _handle_error(vm)
|
await _handle_error(vm)
|
||||||
if not isinstance(vm, protocol.VersionInfoMessage):
|
if not isinstance(vm, protocol.VersionInfoMessage):
|
||||||
raise Exception("Invalid message received")
|
raise Exception("Invalid message received")
|
||||||
|
@ -307,7 +306,8 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
|
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1)
|
||||||
|
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
|
||||||
pipe_stdin=not os.isatty(0),
|
pipe_stdin=not os.isatty(0),
|
||||||
pipe_stdout=not os.isatty(1),
|
pipe_stdout=not os.isatty(1),
|
||||||
pipe_stderr=not os.isatty(2),
|
pipe_stderr=not os.isatty(2),
|
||||||
|
@ -330,8 +330,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||||
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
packet = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
|
message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
|
||||||
message = messenger.receive(packet)
|
|
||||||
await _handle_error(message)
|
await _handle_error(message)
|
||||||
processed = True
|
processed = True
|
||||||
if isinstance(message, protocol.StreamDataMessage):
|
if isinstance(message, protocol.StreamDataMessage):
|
||||||
|
@ -353,8 +352,6 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||||
os.close(2)
|
os.close(2)
|
||||||
elif isinstance(message, protocol.CommandExitedMessage):
|
elif isinstance(message, protocol.CommandExitedMessage):
|
||||||
log.debug(f"received return code {message.return_code}, exiting")
|
log.debug(f"received return code {message.return_code}, exiting")
|
||||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
|
||||||
_link.teardown()
|
|
||||||
return message.return_code
|
return message.return_code
|
||||||
elif isinstance(message, protocol.ErrorMessage):
|
elif isinstance(message, protocol.ErrorMessage):
|
||||||
log.error(message.data)
|
log.error(message.data)
|
||||||
|
@ -365,13 +362,12 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
processed = False
|
processed = False
|
||||||
|
|
||||||
if messenger.is_outlet_ready(outlet):
|
if channel.is_ready_to_send():
|
||||||
stdin = data_buffer[:mdu]
|
stdin = data_buffer[:mdu]
|
||||||
data_buffer = data_buffer[mdu:]
|
data_buffer = data_buffer[mdu:]
|
||||||
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
||||||
if len(stdin) > 0 or eof:
|
if len(stdin) > 0 or eof:
|
||||||
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
|
channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof))
|
||||||
stdin, eof))
|
|
||||||
sent_eof = eof
|
sent_eof = eof
|
||||||
processed = True
|
processed = True
|
||||||
|
|
||||||
|
@ -381,7 +377,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||||
winch = False
|
winch = False
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
r, c, h, v = process.tty_get_winsize(0)
|
r, c, h, v = process.tty_get_winsize(0)
|
||||||
messenger.send(outlet, protocol.WindowSizeMessage(r, c, h, v))
|
channel.send(protocol.WindowSizeMessage(r, c, h, v))
|
||||||
processed = True
|
processed = True
|
||||||
except RemoteExecutionError as e:
|
except RemoteExecutionError as e:
|
||||||
print(e.msg)
|
print(e.msg)
|
||||||
|
|
|
@ -159,7 +159,7 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
|
||||||
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
||||||
|
|
||||||
def link_established(lnk: RNS.Link):
|
def link_established(lnk: RNS.Link):
|
||||||
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), loop)
|
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
|
||||||
_destination.set_link_established_callback(link_established)
|
_destination.set_link_established_callback(link_established)
|
||||||
|
|
||||||
_finished = asyncio.Event()
|
_finished = asyncio.Event()
|
||||||
|
@ -188,7 +188,6 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
|
||||||
log.warning("Shutting down")
|
log.warning("Shutting down")
|
||||||
await session.ListenerSession.terminate_all("Shutting down")
|
await session.ListenerSession.terminate_all("Shutting down")
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
session.ListenerSession.messenger.shutdown()
|
|
||||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||||
for link in links_still_active:
|
for link in links_still_active:
|
||||||
if link.status not in [RNS.Link.CLOSED]:
|
if link.status not in [RNS.Link.CLOSED]:
|
||||||
|
|
262
rnsh/protocol.py
262
rnsh/protocol.py
|
@ -19,9 +19,6 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_TReceipt = TypeVar("_TReceipt")
|
|
||||||
_TLink = TypeVar("_TLink")
|
|
||||||
MSG_MAGIC = 0xac
|
MSG_MAGIC = 0xac
|
||||||
PROTOCOL_VERSION = 1
|
PROTOCOL_VERSION = 1
|
||||||
|
|
||||||
|
@ -30,120 +27,17 @@ def _make_MSGTYPE(val: int):
|
||||||
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
||||||
|
|
||||||
|
|
||||||
class MessageOutletBase(ABC):
|
class NoopMessage(RNS.MessageBase):
|
||||||
@abstractmethod
|
|
||||||
def send(self, raw: bytes) -> _TReceipt:
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def resend(self, receipt: _TReceipt) -> _TReceipt:
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def mdu(self):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def rtt(self):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def is_usuable(self):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def timed_out(self):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __str__(self):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
|
|
||||||
class METype(enum.IntEnum):
|
|
||||||
ME_NO_MSG_TYPE = 0
|
|
||||||
ME_INVALID_MSG_TYPE = 1
|
|
||||||
ME_NOT_REGISTERED = 2
|
|
||||||
ME_LINK_NOT_READY = 3
|
|
||||||
ME_ALREADY_SENT = 4
|
|
||||||
|
|
||||||
|
|
||||||
class MessagingException(Exception):
|
|
||||||
def __init__(self, type: METype, *args):
|
|
||||||
super().__init__(args)
|
|
||||||
self.type = type
|
|
||||||
|
|
||||||
|
|
||||||
class MessageState(enum.IntEnum):
|
|
||||||
MSGSTATE_NEW = 0
|
|
||||||
MSGSTATE_SENT = 1
|
|
||||||
MSGSTATE_DELIVERED = 2
|
|
||||||
MSGSTATE_FAILED = 3
|
|
||||||
|
|
||||||
|
|
||||||
class Message(abc.ABC):
|
|
||||||
MSGTYPE = None
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.ts = time.time()
|
|
||||||
self.msgid = uuid.uuid4()
|
|
||||||
self.raw: bytes | None = None
|
|
||||||
self.receipt: _TReceipt = None
|
|
||||||
self.outlet: _TLink = None
|
|
||||||
self.tracked: bool = False
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.__class__.__name__} {self.msgid}"
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def pack(self) -> bytes:
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def unpack(self, raw):
|
|
||||||
raise NotImplemented()
|
|
||||||
|
|
||||||
def unwrap_MSGTYPE(self, raw: bytes) -> bytes:
|
|
||||||
if self.MSGTYPE is None:
|
|
||||||
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
|
|
||||||
mid, raw = self.static_unwrap_MSGTYPE(raw)
|
|
||||||
if mid != self.MSGTYPE:
|
|
||||||
raise MessagingException(METype.ME_INVALID_MSG_TYPE,
|
|
||||||
f"invalid msg id, expected {hex(self.MSGTYPE)} got {hex(mid)}")
|
|
||||||
return raw
|
|
||||||
|
|
||||||
def wrap_MSGTYPE(self, raw: bytes) -> bytes:
|
|
||||||
if self.__class__.MSGTYPE is None:
|
|
||||||
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
|
|
||||||
return struct.pack(">H", self.MSGTYPE) + raw
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def static_unwrap_MSGTYPE(raw: bytes) -> (int, bytes):
|
|
||||||
return struct.unpack(">H", raw[:2])[0], raw[2:]
|
|
||||||
|
|
||||||
|
|
||||||
class NoopMessage(Message):
|
|
||||||
MSGTYPE = _make_MSGTYPE(0)
|
MSGTYPE = _make_MSGTYPE(0)
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
return self.wrap_MSGTYPE(bytes())
|
return bytes()
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
self.unwrap_MSGTYPE(raw)
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WindowSizeMessage(Message):
|
class WindowSizeMessage(RNS.MessageBase):
|
||||||
MSGTYPE = _make_MSGTYPE(2)
|
MSGTYPE = _make_MSGTYPE(2)
|
||||||
|
|
||||||
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
|
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
|
||||||
|
@ -154,15 +48,13 @@ class WindowSizeMessage(Message):
|
||||||
self.vpix = vpix
|
self.vpix = vpix
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
|
return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
|
||||||
return self.wrap_MSGTYPE(raw)
|
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
|
||||||
self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class ExecuteCommandMesssage(Message):
|
class ExecuteCommandMesssage(RNS.MessageBase):
|
||||||
MSGTYPE = _make_MSGTYPE(3)
|
MSGTYPE = _make_MSGTYPE(3)
|
||||||
|
|
||||||
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
||||||
|
@ -181,20 +73,20 @@ class ExecuteCommandMesssage(Message):
|
||||||
self.vpix = vpix
|
self.vpix = vpix
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||||
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||||
return self.wrap_MSGTYPE(raw)
|
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
|
||||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
||||||
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
class StreamDataMessage(Message):
|
|
||||||
|
class StreamDataMessage(RNS.MessageBase):
|
||||||
MSGTYPE = _make_MSGTYPE(4)
|
MSGTYPE = _make_MSGTYPE(4)
|
||||||
STREAM_ID_STDIN = 0
|
STREAM_ID_STDIN = 0
|
||||||
STREAM_ID_STDOUT = 1
|
STREAM_ID_STDOUT = 1
|
||||||
STREAM_ID_STDERR = 2
|
STREAM_ID_STDERR = 2
|
||||||
|
OVERHEAD = 0
|
||||||
|
|
||||||
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
|
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -203,15 +95,19 @@ class StreamDataMessage(Message):
|
||||||
self.eof = eof
|
self.eof = eof
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
return umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
||||||
return self.wrap_MSGTYPE(raw)
|
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
|
||||||
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
|
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class VersionInfoMessage(Message):
|
_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8")
|
||||||
|
StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \
|
||||||
|
- len(_link_sized_bytes)
|
||||||
|
_link_sized_bytes = None
|
||||||
|
|
||||||
|
|
||||||
|
class VersionInfoMessage(RNS.MessageBase):
|
||||||
MSGTYPE = _make_MSGTYPE(5)
|
MSGTYPE = _make_MSGTYPE(5)
|
||||||
|
|
||||||
def __init__(self, sw_version: str = None):
|
def __init__(self, sw_version: str = None):
|
||||||
|
@ -220,15 +116,13 @@ class VersionInfoMessage(Message):
|
||||||
self.protocol_version = PROTOCOL_VERSION
|
self.protocol_version = PROTOCOL_VERSION
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.sw_version, self.protocol_version))
|
return umsgpack.packb((self.sw_version, self.protocol_version))
|
||||||
return self.wrap_MSGTYPE(raw)
|
|
||||||
|
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
|
||||||
self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
|
self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class ErrorMessage(Message):
|
class ErrorMessage(RNS.MessageBase):
|
||||||
MSGTYPE = _make_MSGTYPE(6)
|
MSGTYPE = _make_MSGTYPE(6)
|
||||||
|
|
||||||
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
|
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
|
||||||
|
@ -238,15 +132,13 @@ class ErrorMessage(Message):
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb((self.msg, self.fatal, self.data))
|
return umsgpack.packb((self.msg, self.fatal, self.data))
|
||||||
return self.wrap_MSGTYPE(raw)
|
|
||||||
|
|
||||||
def unpack(self, raw: bytes):
|
def unpack(self, raw: bytes):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
|
||||||
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class CommandExitedMessage(Message):
|
class CommandExitedMessage(RNS.MessageBase):
|
||||||
MSGTYPE = _make_MSGTYPE(7)
|
MSGTYPE = _make_MSGTYPE(7)
|
||||||
|
|
||||||
def __init__(self, return_code: int = None):
|
def __init__(self, return_code: int = None):
|
||||||
|
@ -254,114 +146,16 @@ class CommandExitedMessage(Message):
|
||||||
self.return_code = return_code
|
self.return_code = return_code
|
||||||
|
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
raw = umsgpack.packb(self.return_code)
|
return umsgpack.packb(self.return_code)
|
||||||
return self.wrap_MSGTYPE(raw)
|
|
||||||
|
|
||||||
def unpack(self, raw: bytes):
|
def unpack(self, raw: bytes):
|
||||||
raw = self.unwrap_MSGTYPE(raw)
|
|
||||||
self.return_code = umsgpack.unpackb(raw)
|
self.return_code = umsgpack.unpackb(raw)
|
||||||
|
|
||||||
|
|
||||||
class Messenger(contextlib.AbstractContextManager):
|
message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage,
|
||||||
|
CommandExitedMessage, ErrorMessage]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_msg_constructors() -> (int, Type[Message]):
|
|
||||||
subclass_tuples = []
|
|
||||||
for subclass in Message.__subclasses__():
|
|
||||||
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
|
||||||
return subclass_tuples
|
|
||||||
|
|
||||||
def __init__(self, retry_delay_min: float = 10.0):
|
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
|
||||||
self._sent_messages: list[Message] = []
|
|
||||||
self._lock = threading.RLock()
|
|
||||||
self._retry_timer = rnsh.retry.RetryThread()
|
|
||||||
self._message_factories = dict(self.__class__._get_msg_constructors())
|
|
||||||
self._retry_delay_min = retry_delay_min
|
|
||||||
|
|
||||||
def __enter__(self) -> Messenger:
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
|
||||||
__traceback: TracebackType | None) -> bool | None:
|
|
||||||
self.shutdown()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
self._retry_timer.close()
|
|
||||||
|
|
||||||
def clear_retries(self, outlet):
|
|
||||||
self._retry_timer.complete(outlet)
|
|
||||||
|
|
||||||
def receive(self, raw: bytes) -> Message:
|
|
||||||
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
|
||||||
ctor = self._message_factories.get(mid, None)
|
|
||||||
if ctor is None:
|
|
||||||
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
|
||||||
message = ctor()
|
|
||||||
message.unpack(raw)
|
|
||||||
self._log.debug(f"Message received: {message}")
|
|
||||||
return message
|
|
||||||
|
|
||||||
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
|
|
||||||
if not outlet.is_usuable:
|
|
||||||
self._log.debug("is_outlet_ready outlet unusable")
|
|
||||||
return False
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
for message in self._sent_messages:
|
|
||||||
if message.outlet == outlet and message.tracked and message.receipt \
|
|
||||||
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
|
|
||||||
self._log.debug("is_outlet_ready pending message found")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def send(self, outlet: MessageOutletBase, message: Message):
|
|
||||||
with self._lock:
|
|
||||||
if not self.is_outlet_ready(outlet):
|
|
||||||
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
|
|
||||||
|
|
||||||
if message in self._sent_messages:
|
|
||||||
raise MessagingException(METype.ME_ALREADY_SENT)
|
|
||||||
self._sent_messages.append(message)
|
|
||||||
message.tracked = True
|
|
||||||
|
|
||||||
if not message.raw:
|
|
||||||
message.raw = message.pack()
|
|
||||||
message.outlet = outlet
|
|
||||||
|
|
||||||
def send_inner(tag: any, tries: int):
|
|
||||||
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
|
|
||||||
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
|
||||||
try:
|
|
||||||
if message.receipt:
|
|
||||||
self._log.debug(f"Resending packet for {message}")
|
|
||||||
message.receipt = outlet.resend(message.receipt)
|
|
||||||
else:
|
|
||||||
self._log.debug(f"Sending packet for {message}")
|
|
||||||
message.receipt = outlet.send(message.raw)
|
|
||||||
except Exception as ex:
|
|
||||||
self._log.exception(f"Error sending message {message}")
|
|
||||||
elif state in [MessageState.MSGSTATE_SENT]:
|
|
||||||
self._log.debug(f"Retry skipped, message still pending {message}")
|
|
||||||
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
|
||||||
latency = round(time.time() - message.ts, 1)
|
|
||||||
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
|
||||||
with self._lock:
|
|
||||||
self._sent_messages.remove(message)
|
|
||||||
message.tracked = False
|
|
||||||
self._retry_timer.complete(outlet)
|
|
||||||
return outlet
|
|
||||||
|
|
||||||
def timeout(tag: any, tries: int):
|
|
||||||
latency = round(time.time() - message.ts, 1)
|
|
||||||
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
|
||||||
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
|
||||||
with self._lock:
|
|
||||||
self._sent_messages.remove(message)
|
|
||||||
message.tracked = False
|
|
||||||
outlet.timed_out()
|
|
||||||
|
|
||||||
rtt = outlet.rtt
|
|
||||||
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
|
|
||||||
|
|
||||||
|
def register_message_types(channel: RNS.Channel.Channel):
|
||||||
|
for message_type in message_types:
|
||||||
|
channel.register_message_type(message_type)
|
12
rnsh/rnsh.py
12
rnsh/rnsh.py
|
@ -102,13 +102,16 @@ def print_identity(configdir, identitypath, service_name, include_destination: b
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
verbose_set = False
|
||||||
|
|
||||||
|
|
||||||
async def _rnsh_cli_main():
|
async def _rnsh_cli_main():
|
||||||
#with contextlib.suppress(KeyboardInterrupt, SystemExit):
|
global verbose_set
|
||||||
import docopt
|
|
||||||
log = _get_logger("main")
|
log = _get_logger("main")
|
||||||
_loop = asyncio.get_running_loop()
|
_loop = asyncio.get_running_loop()
|
||||||
rnslogging.set_main_loop(_loop)
|
rnslogging.set_main_loop(_loop)
|
||||||
args = rnsh.args.Args(sys.argv)
|
args = rnsh.args.Args(sys.argv)
|
||||||
|
verbose_set = args.verbose > 0
|
||||||
|
|
||||||
if args.print_identity:
|
if args.print_identity:
|
||||||
print_identity(args.config, args.identity, args.service_name, args.listen)
|
print_identity(args.config, args.identity, args.service_name, args.listen)
|
||||||
|
@ -148,7 +151,9 @@ async def _rnsh_cli_main():
|
||||||
|
|
||||||
|
|
||||||
def rnsh_cli():
|
def rnsh_cli():
|
||||||
|
global verbose_set
|
||||||
return_code = 1
|
return_code = 1
|
||||||
|
exc = None
|
||||||
try:
|
try:
|
||||||
return_code = asyncio.run(_rnsh_cli_main())
|
return_code = asyncio.run(_rnsh_cli_main())
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
|
@ -157,7 +162,10 @@ def rnsh_cli():
|
||||||
pass
|
pass
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(f"Unhandled exception: {ex}")
|
print(f"Unhandled exception: {ex}")
|
||||||
|
exc = ex
|
||||||
process.tty_unset_reader_callbacks(0)
|
process.tty_unset_reader_callbacks(0)
|
||||||
|
if verbose_set and exc:
|
||||||
|
raise exc
|
||||||
sys.exit(return_code if return_code is not None else 255)
|
sys.exit(return_code if return_code is not None else 255)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,6 @@ import RNS
|
||||||
|
|
||||||
import logging as __logging
|
import logging as __logging
|
||||||
|
|
||||||
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
|
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
_TLink = TypeVar("_TLink")
|
_TLink = TypeVar("_TLink")
|
||||||
|
@ -44,7 +42,7 @@ class LSState(enum.IntEnum):
|
||||||
_TIdentity = TypeVar("_TIdentity")
|
_TIdentity = TypeVar("_TIdentity")
|
||||||
|
|
||||||
|
|
||||||
class LSOutletBase(protocol.MessageOutletBase):
|
class LSOutletBase(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase):
|
||||||
def unset_link_closed_callback(self):
|
def unset_link_closed_callback(self):
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def teardown(self):
|
def rtt(self):
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self):
|
def teardown(self):
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
|
|
||||||
class ListenerSession:
|
class ListenerSession:
|
||||||
sessions: List[ListenerSession] = []
|
sessions: List[ListenerSession] = []
|
||||||
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
|
|
||||||
allowed_identity_hashes: [any] = []
|
allowed_identity_hashes: [any] = []
|
||||||
allow_all: bool = False
|
allow_all: bool = False
|
||||||
allow_remote_command: bool = False
|
allow_remote_command: bool = False
|
||||||
default_command: [str] = []
|
default_command: [str] = []
|
||||||
remote_cmd_as_args = False
|
remote_cmd_as_args = False
|
||||||
|
|
||||||
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
|
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self._log.info(f"Session started for {outlet}")
|
self._log.info(f"Session started for {outlet}")
|
||||||
self.outlet = outlet
|
self.outlet = outlet
|
||||||
|
self.channel = channel
|
||||||
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||||
self.outlet.set_link_closed_callback(self._link_closed)
|
self.outlet.set_link_closed_callback(self._link_closed)
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
@ -106,7 +105,8 @@ class ListenerSession:
|
||||||
else:
|
else:
|
||||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||||
self.sessions.append(self)
|
self.sessions.append(self)
|
||||||
self.outlet.set_packet_received_callback(self._packet_received)
|
protocol.register_message_types(self.channel)
|
||||||
|
self.channel.add_message_handler(self._handle_message)
|
||||||
|
|
||||||
def _terminated(self, return_code: int):
|
def _terminated(self, return_code: int):
|
||||||
self.return_code = return_code
|
self.return_code = return_code
|
||||||
|
@ -128,8 +128,8 @@ class ListenerSession:
|
||||||
self.loop.call_later(delay, func)
|
self.loop.call_later(delay, func)
|
||||||
self.loop.call_soon_threadsafe(call_inner)
|
self.loop.call_soon_threadsafe(call_inner)
|
||||||
|
|
||||||
def send(self, message: protocol.Message):
|
def send(self, message: RNS.MessageBase):
|
||||||
self.messenger.send(self.outlet, message)
|
self.channel.send(message)
|
||||||
|
|
||||||
def _protocol_error(self, name: str):
|
def _protocol_error(self, name: str):
|
||||||
self.terminate(f"Protocol error ({name})")
|
self.terminate(f"Protocol error ({name})")
|
||||||
|
@ -171,7 +171,6 @@ class ListenerSession:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._log.debug(f"link_closed {outlet}")
|
self._log.debug(f"link_closed {outlet}")
|
||||||
self.messenger.clear_retries(self.outlet)
|
|
||||||
self.terminate()
|
self.terminate()
|
||||||
|
|
||||||
def _initiator_identified(self, outlet, identity):
|
def _initiator_identified(self, outlet, identity):
|
||||||
|
@ -208,10 +207,10 @@ class ListenerSession:
|
||||||
try:
|
try:
|
||||||
if self.state != LSState.LSSTATE_RUNNING:
|
if self.state != LSState.LSSTATE_RUNNING:
|
||||||
return False
|
return False
|
||||||
elif not self.messenger.is_outlet_ready(self.outlet):
|
elif not self.channel.is_ready_to_send():
|
||||||
return False
|
return False
|
||||||
elif len(self.stderr_buf) > 0:
|
elif len(self.stderr_buf) > 0:
|
||||||
mdu = self.outlet.mdu - 16
|
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
|
||||||
data = self.stderr_buf[:mdu]
|
data = self.stderr_buf[:mdu]
|
||||||
self.stderr_buf = self.stderr_buf[mdu:]
|
self.stderr_buf = self.stderr_buf[mdu:]
|
||||||
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||||
|
@ -223,7 +222,7 @@ class ListenerSession:
|
||||||
self.stderr_eof_sent = True
|
self.stderr_eof_sent = True
|
||||||
return True
|
return True
|
||||||
elif len(self.stdout_buf) > 0:
|
elif len(self.stdout_buf) > 0:
|
||||||
mdu = self.outlet.mdu - 16
|
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
|
||||||
data = self.stdout_buf[:mdu]
|
data = self.stdout_buf[:mdu]
|
||||||
self.stdout_buf = self.stdout_buf[mdu:]
|
self.stdout_buf = self.stdout_buf[mdu:]
|
||||||
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||||
|
@ -309,7 +308,7 @@ class ListenerSession:
|
||||||
if eof:
|
if eof:
|
||||||
self.process.close_stdin()
|
self.process.close_stdin()
|
||||||
|
|
||||||
def _handle_message(self, message: protocol.Message):
|
def _handle_message(self, message: RNS.MessageBase):
|
||||||
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
||||||
self._protocol_error("Identification required")
|
self._protocol_error("Identification required")
|
||||||
return
|
return
|
||||||
|
@ -352,17 +351,6 @@ class ListenerSession:
|
||||||
self._protocol_error("unexpected message")
|
self._protocol_error("unexpected message")
|
||||||
return
|
return
|
||||||
|
|
||||||
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
|
|
||||||
if outlet != self.outlet:
|
|
||||||
self._log.debug("Packet received from incorrect outlet")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
message = self.messenger.receive(raw)
|
|
||||||
self._handle_message(message)
|
|
||||||
except Exception as ex:
|
|
||||||
self._protocol_error(f"error receiving packet: {ex}")
|
|
||||||
|
|
||||||
|
|
||||||
class RNSOutlet(LSOutletBase):
|
class RNSOutlet(LSOutletBase):
|
||||||
|
|
||||||
|
@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase):
|
||||||
def teardown(self):
|
def teardown(self):
|
||||||
self.link.teardown()
|
self.link.teardown()
|
||||||
|
|
||||||
def send(self, raw: bytes) -> RNS.Packet:
|
|
||||||
packet = RNS.Packet(self.link, raw)
|
|
||||||
packet.send()
|
|
||||||
return packet
|
|
||||||
|
|
||||||
def resend(self, packet: RNS.Packet) -> RNS.Packet:
|
|
||||||
packet.resend()
|
|
||||||
return packet
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mdu(self) -> int:
|
|
||||||
return self.link.MDU
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rtt(self) -> float:
|
def rtt(self) -> float:
|
||||||
return self.link.rtt
|
return self.link.rtt
|
||||||
|
|
||||||
@property
|
|
||||||
def is_usuable(self):
|
|
||||||
return True #self.link.status in [RNS.Link.ACTIVE]
|
|
||||||
|
|
||||||
def get_receipt_state(self, packet: RNS.Packet) -> MessageState:
|
|
||||||
status = packet.receipt.get_status()
|
|
||||||
if status == RNS.PacketReceipt.SENT:
|
|
||||||
return protocol.MessageState.MSGSTATE_SENT
|
|
||||||
if status == RNS.PacketReceipt.DELIVERED:
|
|
||||||
return protocol.MessageState.MSGSTATE_DELIVERED
|
|
||||||
if status == RNS.PacketReceipt.FAILED:
|
|
||||||
return protocol.MessageState.MSGSTATE_FAILED
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unexpected receipt state: {status}")
|
|
||||||
|
|
||||||
def timed_out(self):
|
|
||||||
self.link.teardown()
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Outlet RNS Link {self.link}"
|
return f"Outlet RNS Link {self.link}"
|
||||||
|
|
||||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
|
||||||
def inner_cb(message, packet: RNS.Packet):
|
|
||||||
packet.prove()
|
|
||||||
cb(self, message)
|
|
||||||
|
|
||||||
self.link.set_packet_callback(inner_cb)
|
|
||||||
|
|
||||||
def __init__(self, link: RNS.Link):
|
def __init__(self, link: RNS.Link):
|
||||||
self.link = link
|
self.link = link
|
||||||
link.lsoutlet = self
|
link.lsoutlet = self
|
||||||
link.msgoutlet = self
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_outlet(link: RNS.Link):
|
def get_outlet(link: RNS.Link):
|
||||||
if hasattr(link, "lsoutlet"):
|
if hasattr(link, "lsoutlet"):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue