mirror of
https://github.com/markqvist/rnsh.git
synced 2025-02-01 08:24:54 -05:00
Switch to RNS-provided Channel
This commit is contained in:
parent
b6a22cd2a7
commit
5bca575a4b
@ -9,7 +9,7 @@ readme = "README.md"
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
docopt = "^0.6.2"
|
||||
rns = "^0.4.9"
|
||||
rns = { git = "https://github.com/acehoss/Reticulum.git", branch = "feature/channel" } #{ path = "../Reticulum/", develop = true } #
|
||||
tomli = "^2.0.1"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
@ -143,12 +143,12 @@ class InitiatorState(enum.IntEnum):
|
||||
|
||||
def _client_link_closed(link):
|
||||
log = _get_logger("_client_link_closed")
|
||||
_finished.set()
|
||||
if _finished:
|
||||
_finished.set()
|
||||
|
||||
|
||||
def _client_packet_handler(message, packet):
|
||||
log = _get_logger("_client_packet_handler")
|
||||
packet.prove()
|
||||
def _client_message_handler(message: RNS.MessageBase):
|
||||
log = _get_logger("_client_message_handler")
|
||||
_pq.put(message)
|
||||
|
||||
|
||||
@ -213,10 +213,8 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
|
||||
_link.identify(_identity)
|
||||
_link.did_identify = True
|
||||
|
||||
_link.set_packet_callback(_client_packet_handler)
|
||||
|
||||
|
||||
async def _handle_error(errmsg: protocol.Message):
|
||||
async def _handle_error(errmsg: RNS.MessageBase):
|
||||
if isinstance(errmsg, protocol.ErrorMessage):
|
||||
with contextlib.suppress(Exception):
|
||||
if _link and _link.status == RNS.Link.ACTIVE:
|
||||
@ -249,150 +247,148 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||
|
||||
state = InitiatorState.IS_LINKED
|
||||
outlet = session.RNSOutlet(_link)
|
||||
with protocol.Messenger(retry_delay_min=5) as messenger:
|
||||
channel = _link.get_channel()
|
||||
protocol.register_message_types(channel)
|
||||
channel.add_message_handler(_client_message_handler)
|
||||
|
||||
# Next step after linking and identifying: send version
|
||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
||||
# print("Error bringing up link")
|
||||
# return 253
|
||||
# Next step after linking and identifying: send version
|
||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
||||
# print("Error bringing up link")
|
||||
# return 253
|
||||
|
||||
messenger.send(outlet, protocol.VersionInfoMessage())
|
||||
channel.send(protocol.VersionInfoMessage())
|
||||
try:
|
||||
vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
||||
await _handle_error(vm)
|
||||
if not isinstance(vm, protocol.VersionInfoMessage):
|
||||
raise Exception("Invalid message received")
|
||||
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
|
||||
state = InitiatorState.IS_RUNNING
|
||||
except queue.Empty:
|
||||
print("Protocol error")
|
||||
return 254
|
||||
|
||||
winch = False
|
||||
def sigwinch_handler():
|
||||
nonlocal winch
|
||||
# log.debug("WindowChanged")
|
||||
winch = True
|
||||
|
||||
stdin_eof = False
|
||||
def stdin():
|
||||
nonlocal stdin_eof
|
||||
try:
|
||||
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
||||
vm = messenger.receive(vp)
|
||||
await _handle_error(vm)
|
||||
if not isinstance(vm, protocol.VersionInfoMessage):
|
||||
raise Exception("Invalid message received")
|
||||
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
|
||||
state = InitiatorState.IS_RUNNING
|
||||
except queue.Empty:
|
||||
print("Protocol error")
|
||||
return 254
|
||||
data = process.tty_read(sys.stdin.fileno())
|
||||
log.debug(f"stdin {data}")
|
||||
if data is not None:
|
||||
data_buffer.extend(data)
|
||||
except EOFError:
|
||||
if os.isatty(0):
|
||||
data_buffer.extend(process.CTRL_D)
|
||||
stdin_eof = True
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
winch = False
|
||||
def sigwinch_handler():
|
||||
nonlocal winch
|
||||
# log.debug("WindowChanged")
|
||||
winch = True
|
||||
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||
|
||||
stdin_eof = False
|
||||
def stdin():
|
||||
nonlocal stdin_eof
|
||||
try:
|
||||
data = process.tty_read(sys.stdin.fileno())
|
||||
log.debug(f"stdin {data}")
|
||||
if data is not None:
|
||||
data_buffer.extend(data)
|
||||
except EOFError:
|
||||
if os.isatty(0):
|
||||
data_buffer.extend(process.CTRL_D)
|
||||
stdin_eof = True
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||
|
||||
tcattr = None
|
||||
rows, cols, hpix, vpix = (None, None, None, None)
|
||||
tcattr = None
|
||||
rows, cols, hpix, vpix = (None, None, None, None)
|
||||
try:
|
||||
tcattr = termios.tcgetattr(0)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(0)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(0)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(0)
|
||||
tcattr = termios.tcgetattr(1)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(1)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(1)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(1)
|
||||
tcattr = termios.tcgetattr(2)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(2)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(2)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(2)
|
||||
except:
|
||||
pass
|
||||
pass
|
||||
|
||||
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
|
||||
pipe_stdin=not os.isatty(0),
|
||||
pipe_stdout=not os.isatty(1),
|
||||
pipe_stderr=not os.isatty(2),
|
||||
tcflags=tcattr,
|
||||
term=os.environ.get("TERM", None),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
hpix=hpix,
|
||||
vpix=vpix))
|
||||
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1)
|
||||
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
|
||||
pipe_stdin=not os.isatty(0),
|
||||
pipe_stdout=not os.isatty(1),
|
||||
pipe_stderr=not os.isatty(2),
|
||||
tcflags=tcattr,
|
||||
term=os.environ.get("TERM", None),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
hpix=hpix,
|
||||
vpix=vpix))
|
||||
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
_finished = asyncio.Event()
|
||||
loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
|
||||
loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))
|
||||
mdu = _link.MDU - 16
|
||||
sent_eof = False
|
||||
last_winch = time.time()
|
||||
sleeper = helpers.SleepRate(0.01)
|
||||
processed = False
|
||||
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
_finished = asyncio.Event()
|
||||
loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
|
||||
loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))
|
||||
mdu = _link.MDU - 16
|
||||
sent_eof = False
|
||||
last_winch = time.time()
|
||||
sleeper = helpers.SleepRate(0.01)
|
||||
processed = False
|
||||
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
packet = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
|
||||
message = messenger.receive(packet)
|
||||
await _handle_error(message)
|
||||
message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
|
||||
await _handle_error(message)
|
||||
processed = True
|
||||
if isinstance(message, protocol.StreamDataMessage):
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
|
||||
if message.data and len(message.data) > 0:
|
||||
ttyRestorer.raw()
|
||||
log.debug(f"stdout: {message.data}")
|
||||
os.write(1, message.data)
|
||||
sys.stdout.flush()
|
||||
if message.eof:
|
||||
os.close(1)
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
|
||||
if message.data and len(message.data) > 0:
|
||||
ttyRestorer.raw()
|
||||
log.debug(f"stdout: {message.data}")
|
||||
os.write(2, message.data)
|
||||
sys.stderr.flush()
|
||||
if message.eof:
|
||||
os.close(2)
|
||||
elif isinstance(message, protocol.CommandExitedMessage):
|
||||
log.debug(f"received return code {message.return_code}, exiting")
|
||||
return message.return_code
|
||||
elif isinstance(message, protocol.ErrorMessage):
|
||||
log.error(message.data)
|
||||
if message.fatal:
|
||||
_link.teardown()
|
||||
return 200
|
||||
|
||||
except queue.Empty:
|
||||
processed = False
|
||||
|
||||
if channel.is_ready_to_send():
|
||||
stdin = data_buffer[:mdu]
|
||||
data_buffer = data_buffer[mdu:]
|
||||
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
||||
if len(stdin) > 0 or eof:
|
||||
channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof))
|
||||
sent_eof = eof
|
||||
processed = True
|
||||
if isinstance(message, protocol.StreamDataMessage):
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
|
||||
if message.data and len(message.data) > 0:
|
||||
ttyRestorer.raw()
|
||||
log.debug(f"stdout: {message.data}")
|
||||
os.write(1, message.data)
|
||||
sys.stdout.flush()
|
||||
if message.eof:
|
||||
os.close(1)
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
|
||||
if message.data and len(message.data) > 0:
|
||||
ttyRestorer.raw()
|
||||
log.debug(f"stdout: {message.data}")
|
||||
os.write(2, message.data)
|
||||
sys.stderr.flush()
|
||||
if message.eof:
|
||||
os.close(2)
|
||||
elif isinstance(message, protocol.CommandExitedMessage):
|
||||
log.debug(f"received return code {message.return_code}, exiting")
|
||||
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||
_link.teardown()
|
||||
return message.return_code
|
||||
elif isinstance(message, protocol.ErrorMessage):
|
||||
log.error(message.data)
|
||||
if message.fatal:
|
||||
_link.teardown()
|
||||
return 200
|
||||
|
||||
except queue.Empty:
|
||||
processed = False
|
||||
# send window change, but rate limited
|
||||
if winch and time.time() - last_winch > _link.rtt * 25:
|
||||
last_winch = time.time()
|
||||
winch = False
|
||||
with contextlib.suppress(Exception):
|
||||
r, c, h, v = process.tty_get_winsize(0)
|
||||
channel.send(protocol.WindowSizeMessage(r, c, h, v))
|
||||
processed = True
|
||||
except RemoteExecutionError as e:
|
||||
print(e.msg)
|
||||
return 255
|
||||
except Exception as ex:
|
||||
print(f"Client exception: {ex}")
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
_link.teardown()
|
||||
return 127
|
||||
|
||||
if messenger.is_outlet_ready(outlet):
|
||||
stdin = data_buffer[:mdu]
|
||||
data_buffer = data_buffer[mdu:]
|
||||
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
||||
if len(stdin) > 0 or eof:
|
||||
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
|
||||
stdin, eof))
|
||||
sent_eof = eof
|
||||
processed = True
|
||||
|
||||
# send window change, but rate limited
|
||||
if winch and time.time() - last_winch > _link.rtt * 25:
|
||||
last_winch = time.time()
|
||||
winch = False
|
||||
with contextlib.suppress(Exception):
|
||||
r, c, h, v = process.tty_get_winsize(0)
|
||||
messenger.send(outlet, protocol.WindowSizeMessage(r, c, h, v))
|
||||
processed = True
|
||||
except RemoteExecutionError as e:
|
||||
print(e.msg)
|
||||
return 255
|
||||
except Exception as ex:
|
||||
print(f"Client exception: {ex}")
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
_link.teardown()
|
||||
return 127
|
||||
|
||||
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||
# await sleeper.sleep_async()
|
||||
log.debug("after main loop")
|
||||
return 0
|
||||
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||
# await sleeper.sleep_async()
|
||||
log.debug("after main loop")
|
||||
return 0
|
@ -159,7 +159,7 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
|
||||
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
||||
|
||||
def link_established(lnk: RNS.Link):
|
||||
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), loop)
|
||||
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
|
||||
_destination.set_link_established_callback(link_established)
|
||||
|
||||
_finished = asyncio.Event()
|
||||
@ -188,7 +188,6 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
|
||||
log.warning("Shutting down")
|
||||
await session.ListenerSession.terminate_all("Shutting down")
|
||||
await asyncio.sleep(1)
|
||||
session.ListenerSession.messenger.shutdown()
|
||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||
for link in links_still_active:
|
||||
if link.status not in [RNS.Link.CLOSED]:
|
||||
|
264
rnsh/protocol.py
264
rnsh/protocol.py
@ -19,9 +19,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TReceipt = TypeVar("_TReceipt")
|
||||
_TLink = TypeVar("_TLink")
|
||||
MSG_MAGIC = 0xac
|
||||
PROTOCOL_VERSION = 1
|
||||
|
||||
@ -30,120 +27,17 @@ def _make_MSGTYPE(val: int):
|
||||
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
||||
|
||||
|
||||
class MessageOutletBase(ABC):
|
||||
@abstractmethod
|
||||
def send(self, raw: bytes) -> _TReceipt:
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def resend(self, receipt: _TReceipt) -> _TReceipt:
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def mdu(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def rtt(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_usuable(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def timed_out(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class METype(enum.IntEnum):
|
||||
ME_NO_MSG_TYPE = 0
|
||||
ME_INVALID_MSG_TYPE = 1
|
||||
ME_NOT_REGISTERED = 2
|
||||
ME_LINK_NOT_READY = 3
|
||||
ME_ALREADY_SENT = 4
|
||||
|
||||
|
||||
class MessagingException(Exception):
|
||||
def __init__(self, type: METype, *args):
|
||||
super().__init__(args)
|
||||
self.type = type
|
||||
|
||||
|
||||
class MessageState(enum.IntEnum):
|
||||
MSGSTATE_NEW = 0
|
||||
MSGSTATE_SENT = 1
|
||||
MSGSTATE_DELIVERED = 2
|
||||
MSGSTATE_FAILED = 3
|
||||
|
||||
|
||||
class Message(abc.ABC):
|
||||
MSGTYPE = None
|
||||
|
||||
def __init__(self):
|
||||
self.ts = time.time()
|
||||
self.msgid = uuid.uuid4()
|
||||
self.raw: bytes | None = None
|
||||
self.receipt: _TReceipt = None
|
||||
self.outlet: _TLink = None
|
||||
self.tracked: bool = False
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__} {self.msgid}"
|
||||
|
||||
@abstractmethod
|
||||
def pack(self) -> bytes:
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def unpack(self, raw):
|
||||
raise NotImplemented()
|
||||
|
||||
def unwrap_MSGTYPE(self, raw: bytes) -> bytes:
|
||||
if self.MSGTYPE is None:
|
||||
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
|
||||
mid, raw = self.static_unwrap_MSGTYPE(raw)
|
||||
if mid != self.MSGTYPE:
|
||||
raise MessagingException(METype.ME_INVALID_MSG_TYPE,
|
||||
f"invalid msg id, expected {hex(self.MSGTYPE)} got {hex(mid)}")
|
||||
return raw
|
||||
|
||||
def wrap_MSGTYPE(self, raw: bytes) -> bytes:
|
||||
if self.__class__.MSGTYPE is None:
|
||||
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
|
||||
return struct.pack(">H", self.MSGTYPE) + raw
|
||||
|
||||
@staticmethod
|
||||
def static_unwrap_MSGTYPE(raw: bytes) -> (int, bytes):
|
||||
return struct.unpack(">H", raw[:2])[0], raw[2:]
|
||||
|
||||
|
||||
class NoopMessage(Message):
|
||||
class NoopMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(0)
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return self.wrap_MSGTYPE(bytes())
|
||||
return bytes()
|
||||
|
||||
def unpack(self, raw):
|
||||
self.unwrap_MSGTYPE(raw)
|
||||
pass
|
||||
|
||||
|
||||
class WindowSizeMessage(Message):
|
||||
class WindowSizeMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(2)
|
||||
|
||||
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
|
||||
@ -154,15 +48,13 @@ class WindowSizeMessage(Message):
|
||||
self.vpix = vpix
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
|
||||
|
||||
def unpack(self, raw):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class ExecuteCommandMesssage(Message):
|
||||
class ExecuteCommandMesssage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(3)
|
||||
|
||||
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
||||
@ -181,20 +73,20 @@ class ExecuteCommandMesssage(Message):
|
||||
self.vpix = vpix
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||
|
||||
def unpack(self, raw):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
||||
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||
|
||||
class StreamDataMessage(Message):
|
||||
|
||||
class StreamDataMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(4)
|
||||
STREAM_ID_STDIN = 0
|
||||
STREAM_ID_STDOUT = 1
|
||||
STREAM_ID_STDERR = 2
|
||||
OVERHEAD = 0
|
||||
|
||||
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
|
||||
super().__init__()
|
||||
@ -203,15 +95,19 @@ class StreamDataMessage(Message):
|
||||
self.eof = eof
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
return umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
|
||||
|
||||
def unpack(self, raw):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class VersionInfoMessage(Message):
|
||||
_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8")
|
||||
StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \
|
||||
- len(_link_sized_bytes)
|
||||
_link_sized_bytes = None
|
||||
|
||||
|
||||
class VersionInfoMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(5)
|
||||
|
||||
def __init__(self, sw_version: str = None):
|
||||
@ -220,15 +116,13 @@ class VersionInfoMessage(Message):
|
||||
self.protocol_version = PROTOCOL_VERSION
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.sw_version, self.protocol_version))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
return umsgpack.packb((self.sw_version, self.protocol_version))
|
||||
|
||||
def unpack(self, raw):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class ErrorMessage(Message):
|
||||
class ErrorMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(6)
|
||||
|
||||
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
|
||||
@ -238,15 +132,13 @@ class ErrorMessage(Message):
|
||||
self.data = data
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb((self.msg, self.fatal, self.data))
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
return umsgpack.packb((self.msg, self.fatal, self.data))
|
||||
|
||||
def unpack(self, raw: bytes):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class CommandExitedMessage(Message):
|
||||
class CommandExitedMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(7)
|
||||
|
||||
def __init__(self, return_code: int = None):
|
||||
@ -254,114 +146,16 @@ class CommandExitedMessage(Message):
|
||||
self.return_code = return_code
|
||||
|
||||
def pack(self) -> bytes:
|
||||
raw = umsgpack.packb(self.return_code)
|
||||
return self.wrap_MSGTYPE(raw)
|
||||
return umsgpack.packb(self.return_code)
|
||||
|
||||
def unpack(self, raw: bytes):
|
||||
raw = self.unwrap_MSGTYPE(raw)
|
||||
self.return_code = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class Messenger(contextlib.AbstractContextManager):
|
||||
message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage,
|
||||
CommandExitedMessage, ErrorMessage]
|
||||
|
||||
@staticmethod
|
||||
def _get_msg_constructors() -> (int, Type[Message]):
|
||||
subclass_tuples = []
|
||||
for subclass in Message.__subclasses__():
|
||||
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
||||
return subclass_tuples
|
||||
|
||||
def __init__(self, retry_delay_min: float = 10.0):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._sent_messages: list[Message] = []
|
||||
self._lock = threading.RLock()
|
||||
self._retry_timer = rnsh.retry.RetryThread()
|
||||
self._message_factories = dict(self.__class__._get_msg_constructors())
|
||||
self._retry_delay_min = retry_delay_min
|
||||
|
||||
def __enter__(self) -> Messenger:
|
||||
return self
|
||||
|
||||
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
||||
__traceback: TracebackType | None) -> bool | None:
|
||||
self.shutdown()
|
||||
return False
|
||||
|
||||
def shutdown(self):
|
||||
self._retry_timer.close()
|
||||
|
||||
def clear_retries(self, outlet):
|
||||
self._retry_timer.complete(outlet)
|
||||
|
||||
def receive(self, raw: bytes) -> Message:
|
||||
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
|
||||
ctor = self._message_factories.get(mid, None)
|
||||
if ctor is None:
|
||||
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
|
||||
message = ctor()
|
||||
message.unpack(raw)
|
||||
self._log.debug(f"Message received: {message}")
|
||||
return message
|
||||
|
||||
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
|
||||
if not outlet.is_usuable:
|
||||
self._log.debug("is_outlet_ready outlet unusable")
|
||||
return False
|
||||
|
||||
with self._lock:
|
||||
for message in self._sent_messages:
|
||||
if message.outlet == outlet and message.tracked and message.receipt \
|
||||
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
|
||||
self._log.debug("is_outlet_ready pending message found")
|
||||
return False
|
||||
return True
|
||||
|
||||
def send(self, outlet: MessageOutletBase, message: Message):
|
||||
with self._lock:
|
||||
if not self.is_outlet_ready(outlet):
|
||||
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
|
||||
|
||||
if message in self._sent_messages:
|
||||
raise MessagingException(METype.ME_ALREADY_SENT)
|
||||
self._sent_messages.append(message)
|
||||
message.tracked = True
|
||||
|
||||
if not message.raw:
|
||||
message.raw = message.pack()
|
||||
message.outlet = outlet
|
||||
|
||||
def send_inner(tag: any, tries: int):
|
||||
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
|
||||
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
|
||||
try:
|
||||
if message.receipt:
|
||||
self._log.debug(f"Resending packet for {message}")
|
||||
message.receipt = outlet.resend(message.receipt)
|
||||
else:
|
||||
self._log.debug(f"Sending packet for {message}")
|
||||
message.receipt = outlet.send(message.raw)
|
||||
except Exception as ex:
|
||||
self._log.exception(f"Error sending message {message}")
|
||||
elif state in [MessageState.MSGSTATE_SENT]:
|
||||
self._log.debug(f"Retry skipped, message still pending {message}")
|
||||
elif state in [MessageState.MSGSTATE_DELIVERED]:
|
||||
latency = round(time.time() - message.ts, 1)
|
||||
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
|
||||
with self._lock:
|
||||
self._sent_messages.remove(message)
|
||||
message.tracked = False
|
||||
self._retry_timer.complete(outlet)
|
||||
return outlet
|
||||
|
||||
def timeout(tag: any, tries: int):
|
||||
latency = round(time.time() - message.ts, 1)
|
||||
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
|
||||
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
|
||||
with self._lock:
|
||||
self._sent_messages.remove(message)
|
||||
message.tracked = False
|
||||
outlet.timed_out()
|
||||
|
||||
rtt = outlet.rtt
|
||||
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
|
||||
|
||||
def register_message_types(channel: RNS.Channel.Channel):
|
||||
for message_type in message_types:
|
||||
channel.register_message_type(message_type)
|
88
rnsh/rnsh.py
88
rnsh/rnsh.py
@ -102,53 +102,58 @@ def print_identity(configdir, identitypath, service_name, include_destination: b
|
||||
exit(0)
|
||||
|
||||
|
||||
verbose_set = False
|
||||
|
||||
|
||||
async def _rnsh_cli_main():
|
||||
#with contextlib.suppress(KeyboardInterrupt, SystemExit):
|
||||
import docopt
|
||||
log = _get_logger("main")
|
||||
_loop = asyncio.get_running_loop()
|
||||
rnslogging.set_main_loop(_loop)
|
||||
args = rnsh.args.Args(sys.argv)
|
||||
global verbose_set
|
||||
log = _get_logger("main")
|
||||
_loop = asyncio.get_running_loop()
|
||||
rnslogging.set_main_loop(_loop)
|
||||
args = rnsh.args.Args(sys.argv)
|
||||
verbose_set = args.verbose > 0
|
||||
|
||||
if args.print_identity:
|
||||
print_identity(args.config, args.identity, args.service_name, args.listen)
|
||||
return 0
|
||||
if args.print_identity:
|
||||
print_identity(args.config, args.identity, args.service_name, args.listen)
|
||||
return 0
|
||||
|
||||
if args.listen:
|
||||
# log.info("command " + args.command)
|
||||
await listener.listen(configdir=args.config,
|
||||
command=args.command_line,
|
||||
identitypath=args.identity,
|
||||
service_name=args.service_name,
|
||||
verbosity=args.verbose,
|
||||
quietness=args.quiet,
|
||||
allowed=args.allowed,
|
||||
disable_auth=args.no_auth,
|
||||
announce_period=args.announce,
|
||||
no_remote_command=args.no_remote_cmd,
|
||||
remote_cmd_as_args=args.remote_cmd_as_args)
|
||||
return 0
|
||||
if args.listen:
|
||||
# log.info("command " + args.command)
|
||||
await listener.listen(configdir=args.config,
|
||||
command=args.command_line,
|
||||
identitypath=args.identity,
|
||||
service_name=args.service_name,
|
||||
verbosity=args.verbose,
|
||||
quietness=args.quiet,
|
||||
allowed=args.allowed,
|
||||
disable_auth=args.no_auth,
|
||||
announce_period=args.announce,
|
||||
no_remote_command=args.no_remote_cmd,
|
||||
remote_cmd_as_args=args.remote_cmd_as_args)
|
||||
return 0
|
||||
|
||||
if args.destination is not None:
|
||||
return_code = await initiator.initiate(configdir=args.config,
|
||||
identitypath=args.identity,
|
||||
verbosity=args.verbose,
|
||||
quietness=args.quiet,
|
||||
noid=args.no_id,
|
||||
destination=args.destination,
|
||||
timeout=args.timeout,
|
||||
command=args.command_line
|
||||
)
|
||||
return return_code if args.mirror else 0
|
||||
else:
|
||||
print("")
|
||||
print(rnsh.args.usage)
|
||||
print("")
|
||||
return 1
|
||||
if args.destination is not None:
|
||||
return_code = await initiator.initiate(configdir=args.config,
|
||||
identitypath=args.identity,
|
||||
verbosity=args.verbose,
|
||||
quietness=args.quiet,
|
||||
noid=args.no_id,
|
||||
destination=args.destination,
|
||||
timeout=args.timeout,
|
||||
command=args.command_line
|
||||
)
|
||||
return return_code if args.mirror else 0
|
||||
else:
|
||||
print("")
|
||||
print(rnsh.args.usage)
|
||||
print("")
|
||||
return 1
|
||||
|
||||
|
||||
def rnsh_cli():
|
||||
global verbose_set
|
||||
return_code = 1
|
||||
exc = None
|
||||
try:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
except SystemExit:
|
||||
@ -156,8 +161,11 @@ def rnsh_cli():
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as ex:
|
||||
print(f"Unhandled exception: {ex}")
|
||||
print(f"Unhandled exception: {ex}")
|
||||
exc = ex
|
||||
process.tty_unset_reader_callbacks(0)
|
||||
if verbose_set and exc:
|
||||
raise exc
|
||||
sys.exit(return_code if return_code is not None else 255)
|
||||
|
||||
|
||||
|
@ -16,8 +16,6 @@ import RNS
|
||||
|
||||
import logging as __logging
|
||||
|
||||
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
_TLink = TypeVar("_TLink")
|
||||
@ -44,7 +42,7 @@ class LSState(enum.IntEnum):
|
||||
_TIdentity = TypeVar("_TIdentity")
|
||||
|
||||
|
||||
class LSOutletBase(protocol.MessageOutletBase):
|
||||
class LSOutletBase(ABC):
|
||||
@abstractmethod
|
||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||
raise NotImplemented()
|
||||
@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase):
|
||||
def unset_link_closed_callback(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def teardown(self):
|
||||
def rtt(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
def teardown(self):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class ListenerSession:
|
||||
sessions: List[ListenerSession] = []
|
||||
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
|
||||
allowed_identity_hashes: [any] = []
|
||||
allow_all: bool = False
|
||||
allow_remote_command: bool = False
|
||||
default_command: [str] = []
|
||||
remote_cmd_as_args = False
|
||||
|
||||
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
|
||||
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._log.info(f"Session started for {outlet}")
|
||||
self.outlet = outlet
|
||||
self.channel = channel
|
||||
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||
self.outlet.set_link_closed_callback(self._link_closed)
|
||||
self.loop = loop
|
||||
@ -106,7 +105,8 @@ class ListenerSession:
|
||||
else:
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
self.sessions.append(self)
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
protocol.register_message_types(self.channel)
|
||||
self.channel.add_message_handler(self._handle_message)
|
||||
|
||||
def _terminated(self, return_code: int):
|
||||
self.return_code = return_code
|
||||
@ -128,8 +128,8 @@ class ListenerSession:
|
||||
self.loop.call_later(delay, func)
|
||||
self.loop.call_soon_threadsafe(call_inner)
|
||||
|
||||
def send(self, message: protocol.Message):
|
||||
self.messenger.send(self.outlet, message)
|
||||
def send(self, message: RNS.MessageBase):
|
||||
self.channel.send(message)
|
||||
|
||||
def _protocol_error(self, name: str):
|
||||
self.terminate(f"Protocol error ({name})")
|
||||
@ -171,7 +171,6 @@ class ListenerSession:
|
||||
return
|
||||
|
||||
self._log.debug(f"link_closed {outlet}")
|
||||
self.messenger.clear_retries(self.outlet)
|
||||
self.terminate()
|
||||
|
||||
def _initiator_identified(self, outlet, identity):
|
||||
@ -208,10 +207,10 @@ class ListenerSession:
|
||||
try:
|
||||
if self.state != LSState.LSSTATE_RUNNING:
|
||||
return False
|
||||
elif not self.messenger.is_outlet_ready(self.outlet):
|
||||
elif not self.channel.is_ready_to_send():
|
||||
return False
|
||||
elif len(self.stderr_buf) > 0:
|
||||
mdu = self.outlet.mdu - 16
|
||||
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
|
||||
data = self.stderr_buf[:mdu]
|
||||
self.stderr_buf = self.stderr_buf[mdu:]
|
||||
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||
@ -223,7 +222,7 @@ class ListenerSession:
|
||||
self.stderr_eof_sent = True
|
||||
return True
|
||||
elif len(self.stdout_buf) > 0:
|
||||
mdu = self.outlet.mdu - 16
|
||||
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
|
||||
data = self.stdout_buf[:mdu]
|
||||
self.stdout_buf = self.stdout_buf[mdu:]
|
||||
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||
@ -309,7 +308,7 @@ class ListenerSession:
|
||||
if eof:
|
||||
self.process.close_stdin()
|
||||
|
||||
def _handle_message(self, message: protocol.Message):
|
||||
def _handle_message(self, message: RNS.MessageBase):
|
||||
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
||||
self._protocol_error("Identification required")
|
||||
return
|
||||
@ -352,17 +351,6 @@ class ListenerSession:
|
||||
self._protocol_error("unexpected message")
|
||||
return
|
||||
|
||||
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
|
||||
if outlet != self.outlet:
|
||||
self._log.debug("Packet received from incorrect outlet")
|
||||
return
|
||||
|
||||
try:
|
||||
message = self.messenger.receive(raw)
|
||||
self._handle_message(message)
|
||||
except Exception as ex:
|
||||
self._protocol_error(f"error receiving packet: {ex}")
|
||||
|
||||
|
||||
class RNSOutlet(LSOutletBase):
|
||||
|
||||
@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase):
|
||||
def teardown(self):
|
||||
self.link.teardown()
|
||||
|
||||
def send(self, raw: bytes) -> RNS.Packet:
|
||||
packet = RNS.Packet(self.link, raw)
|
||||
packet.send()
|
||||
return packet
|
||||
|
||||
def resend(self, packet: RNS.Packet) -> RNS.Packet:
|
||||
packet.resend()
|
||||
return packet
|
||||
|
||||
@property
|
||||
def mdu(self) -> int:
|
||||
return self.link.MDU
|
||||
|
||||
@property
|
||||
def rtt(self) -> float:
|
||||
return self.link.rtt
|
||||
|
||||
@property
|
||||
def is_usuable(self):
|
||||
return True #self.link.status in [RNS.Link.ACTIVE]
|
||||
|
||||
def get_receipt_state(self, packet: RNS.Packet) -> MessageState:
|
||||
status = packet.receipt.get_status()
|
||||
if status == RNS.PacketReceipt.SENT:
|
||||
return protocol.MessageState.MSGSTATE_SENT
|
||||
if status == RNS.PacketReceipt.DELIVERED:
|
||||
return protocol.MessageState.MSGSTATE_DELIVERED
|
||||
if status == RNS.PacketReceipt.FAILED:
|
||||
return protocol.MessageState.MSGSTATE_FAILED
|
||||
else:
|
||||
raise Exception(f"Unexpected receipt state: {status}")
|
||||
|
||||
def timed_out(self):
|
||||
self.link.teardown()
|
||||
|
||||
def __str__(self):
|
||||
return f"Outlet RNS Link {self.link}"
|
||||
|
||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||
def inner_cb(message, packet: RNS.Packet):
|
||||
packet.prove()
|
||||
cb(self, message)
|
||||
|
||||
self.link.set_packet_callback(inner_cb)
|
||||
|
||||
def __init__(self, link: RNS.Link):
|
||||
self.link = link
|
||||
link.lsoutlet = self
|
||||
link.msgoutlet = self
|
||||
|
||||
@staticmethod
|
||||
def get_outlet(link: RNS.Link):
|
||||
if hasattr(link, "lsoutlet"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user