Switch to RNS-provided Channel

This commit is contained in:
Aaron Heise 2023-02-28 08:48:29 -06:00
parent b6a22cd2a7
commit 5bca575a4b
No known key found for this signature in database
GPG Key ID: 6BA54088C41DE8BF
6 changed files with 229 additions and 482 deletions

View File

@ -9,7 +9,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.9"
docopt = "^0.6.2"
rns = "^0.4.9"
rns = { git = "https://github.com/acehoss/Reticulum.git", branch = "feature/channel" } #{ path = "../Reticulum/", develop = true } #
tomli = "^2.0.1"
[tool.poetry.scripts]

View File

@ -143,12 +143,12 @@ class InitiatorState(enum.IntEnum):
def _client_link_closed(link):
log = _get_logger("_client_link_closed")
_finished.set()
if _finished:
_finished.set()
def _client_packet_handler(message, packet):
log = _get_logger("_client_packet_handler")
packet.prove()
def _client_message_handler(message: RNS.MessageBase):
log = _get_logger("_client_message_handler")
_pq.put(message)
@ -213,10 +213,8 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
_link.identify(_identity)
_link.did_identify = True
_link.set_packet_callback(_client_packet_handler)
async def _handle_error(errmsg: protocol.Message):
async def _handle_error(errmsg: RNS.MessageBase):
if isinstance(errmsg, protocol.ErrorMessage):
with contextlib.suppress(Exception):
if _link and _link.status == RNS.Link.ACTIVE:
@ -249,150 +247,148 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
state = InitiatorState.IS_LINKED
outlet = session.RNSOutlet(_link)
with protocol.Messenger(retry_delay_min=5) as messenger:
channel = _link.get_channel()
protocol.register_message_types(channel)
channel.add_message_handler(_client_message_handler)
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# print("Error bringing up link")
# return 253
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# print("Error bringing up link")
# return 253
messenger.send(outlet, protocol.VersionInfoMessage())
channel.send(protocol.VersionInfoMessage())
try:
vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
await _handle_error(vm)
if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received")
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
state = InitiatorState.IS_RUNNING
except queue.Empty:
print("Protocol error")
return 254
winch = False
def sigwinch_handler():
nonlocal winch
# log.debug("WindowChanged")
winch = True
stdin_eof = False
def stdin():
nonlocal stdin_eof
try:
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
vm = messenger.receive(vp)
await _handle_error(vm)
if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received")
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
state = InitiatorState.IS_RUNNING
except queue.Empty:
print("Protocol error")
return 254
data = process.tty_read(sys.stdin.fileno())
log.debug(f"stdin {data}")
if data is not None:
data_buffer.extend(data)
except EOFError:
if os.isatty(0):
data_buffer.extend(process.CTRL_D)
stdin_eof = True
process.tty_unset_reader_callbacks(sys.stdin.fileno())
winch = False
def sigwinch_handler():
nonlocal winch
# log.debug("WindowChanged")
winch = True
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
stdin_eof = False
def stdin():
nonlocal stdin_eof
try:
data = process.tty_read(sys.stdin.fileno())
log.debug(f"stdin {data}")
if data is not None:
data_buffer.extend(data)
except EOFError:
if os.isatty(0):
data_buffer.extend(process.CTRL_D)
stdin_eof = True
process.tty_unset_reader_callbacks(sys.stdin.fileno())
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
tcattr = None
rows, cols, hpix, vpix = (None, None, None, None)
tcattr = None
rows, cols, hpix, vpix = (None, None, None, None)
try:
tcattr = termios.tcgetattr(0)
rows, cols, hpix, vpix = process.tty_get_winsize(0)
except:
try:
tcattr = termios.tcgetattr(0)
rows, cols, hpix, vpix = process.tty_get_winsize(0)
tcattr = termios.tcgetattr(1)
rows, cols, hpix, vpix = process.tty_get_winsize(1)
except:
try:
tcattr = termios.tcgetattr(1)
rows, cols, hpix, vpix = process.tty_get_winsize(1)
tcattr = termios.tcgetattr(2)
rows, cols, hpix, vpix = process.tty_get_winsize(2)
except:
try:
tcattr = termios.tcgetattr(2)
rows, cols, hpix, vpix = process.tty_get_winsize(2)
except:
pass
pass
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1),
pipe_stderr=not os.isatty(2),
tcflags=tcattr,
term=os.environ.get("TERM", None),
rows=rows,
cols=cols,
hpix=hpix,
vpix=vpix))
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1)
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1),
pipe_stderr=not os.isatty(2),
tcflags=tcattr,
term=os.environ.get("TERM", None),
rows=rows,
cols=cols,
hpix=hpix,
vpix=vpix))
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
_finished = asyncio.Event()
loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))
mdu = _link.MDU - 16
sent_eof = False
last_winch = time.time()
sleeper = helpers.SleepRate(0.01)
processed = False
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
_finished = asyncio.Event()
loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))
mdu = _link.MDU - 16
sent_eof = False
last_winch = time.time()
sleeper = helpers.SleepRate(0.01)
processed = False
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
try:
try:
try:
packet = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
message = messenger.receive(packet)
await _handle_error(message)
message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
await _handle_error(message)
processed = True
if isinstance(message, protocol.StreamDataMessage):
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
if message.data and len(message.data) > 0:
ttyRestorer.raw()
log.debug(f"stdout: {message.data}")
os.write(1, message.data)
sys.stdout.flush()
if message.eof:
os.close(1)
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
if message.data and len(message.data) > 0:
ttyRestorer.raw()
log.debug(f"stdout: {message.data}")
os.write(2, message.data)
sys.stderr.flush()
if message.eof:
os.close(2)
elif isinstance(message, protocol.CommandExitedMessage):
log.debug(f"received return code {message.return_code}, exiting")
return message.return_code
elif isinstance(message, protocol.ErrorMessage):
log.error(message.data)
if message.fatal:
_link.teardown()
return 200
except queue.Empty:
processed = False
if channel.is_ready_to_send():
stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:]
eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof:
channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof))
sent_eof = eof
processed = True
if isinstance(message, protocol.StreamDataMessage):
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
if message.data and len(message.data) > 0:
ttyRestorer.raw()
log.debug(f"stdout: {message.data}")
os.write(1, message.data)
sys.stdout.flush()
if message.eof:
os.close(1)
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
if message.data and len(message.data) > 0:
ttyRestorer.raw()
log.debug(f"stdout: {message.data}")
os.write(2, message.data)
sys.stderr.flush()
if message.eof:
os.close(2)
elif isinstance(message, protocol.CommandExitedMessage):
log.debug(f"received return code {message.return_code}, exiting")
with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown()
return message.return_code
elif isinstance(message, protocol.ErrorMessage):
log.error(message.data)
if message.fatal:
_link.teardown()
return 200
except queue.Empty:
processed = False
# send window change, but rate limited
if winch and time.time() - last_winch > _link.rtt * 25:
last_winch = time.time()
winch = False
with contextlib.suppress(Exception):
r, c, h, v = process.tty_get_winsize(0)
channel.send(protocol.WindowSizeMessage(r, c, h, v))
processed = True
except RemoteExecutionError as e:
print(e.msg)
return 255
except Exception as ex:
print(f"Client exception: {ex}")
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 127
if messenger.is_outlet_ready(outlet):
stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:]
eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof:
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
stdin, eof))
sent_eof = eof
processed = True
# send window change, but rate limited
if winch and time.time() - last_winch > _link.rtt * 25:
last_winch = time.time()
winch = False
with contextlib.suppress(Exception):
r, c, h, v = process.tty_get_winsize(0)
messenger.send(outlet, protocol.WindowSizeMessage(r, c, h, v))
processed = True
except RemoteExecutionError as e:
print(e.msg)
return 255
except Exception as ex:
print(f"Client exception: {ex}")
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 127
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
# await sleeper.sleep_async()
log.debug("after main loop")
return 0
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
# await sleeper.sleep_async()
log.debug("after main loop")
return 0

View File

@ -159,7 +159,7 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
def link_established(lnk: RNS.Link):
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), loop)
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
_destination.set_link_established_callback(link_established)
_finished = asyncio.Event()
@ -188,7 +188,6 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
log.warning("Shutting down")
await session.ListenerSession.terminate_all("Shutting down")
await asyncio.sleep(1)
session.ListenerSession.messenger.shutdown()
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status not in [RNS.Link.CLOSED]:

View File

@ -19,9 +19,6 @@ from abc import ABC, abstractmethod
module_logger = __logging.getLogger(__name__)
_TReceipt = TypeVar("_TReceipt")
_TLink = TypeVar("_TLink")
MSG_MAGIC = 0xac
PROTOCOL_VERSION = 1
@ -30,120 +27,17 @@ def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class MessageOutletBase(ABC):
@abstractmethod
def send(self, raw: bytes) -> _TReceipt:
raise NotImplemented()
@abstractmethod
def resend(self, receipt: _TReceipt) -> _TReceipt:
raise NotImplemented()
@property
@abstractmethod
def mdu(self):
raise NotImplemented()
@property
@abstractmethod
def rtt(self):
raise NotImplemented()
@property
@abstractmethod
def is_usuable(self):
raise NotImplemented()
@abstractmethod
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
raise NotImplemented()
@abstractmethod
def timed_out(self):
raise NotImplemented()
@abstractmethod
def __str__(self):
raise NotImplemented()
@abstractmethod
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
raise NotImplemented()
class METype(enum.IntEnum):
ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1
ME_NOT_REGISTERED = 2
ME_LINK_NOT_READY = 3
ME_ALREADY_SENT = 4
class MessagingException(Exception):
def __init__(self, type: METype, *args):
super().__init__(args)
self.type = type
class MessageState(enum.IntEnum):
MSGSTATE_NEW = 0
MSGSTATE_SENT = 1
MSGSTATE_DELIVERED = 2
MSGSTATE_FAILED = 3
class Message(abc.ABC):
MSGTYPE = None
def __init__(self):
self.ts = time.time()
self.msgid = uuid.uuid4()
self.raw: bytes | None = None
self.receipt: _TReceipt = None
self.outlet: _TLink = None
self.tracked: bool = False
def __str__(self):
return f"{self.__class__.__name__} {self.msgid}"
@abstractmethod
def pack(self) -> bytes:
raise NotImplemented()
@abstractmethod
def unpack(self, raw):
raise NotImplemented()
def unwrap_MSGTYPE(self, raw: bytes) -> bytes:
if self.MSGTYPE is None:
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
mid, raw = self.static_unwrap_MSGTYPE(raw)
if mid != self.MSGTYPE:
raise MessagingException(METype.ME_INVALID_MSG_TYPE,
f"invalid msg id, expected {hex(self.MSGTYPE)} got {hex(mid)}")
return raw
def wrap_MSGTYPE(self, raw: bytes) -> bytes:
if self.__class__.MSGTYPE is None:
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
return struct.pack(">H", self.MSGTYPE) + raw
@staticmethod
def static_unwrap_MSGTYPE(raw: bytes) -> (int, bytes):
return struct.unpack(">H", raw[:2])[0], raw[2:]
class NoopMessage(Message):
class NoopMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(0)
def pack(self) -> bytes:
return self.wrap_MSGTYPE(bytes())
return bytes()
def unpack(self, raw):
self.unwrap_MSGTYPE(raw)
pass
class WindowSizeMessage(Message):
class WindowSizeMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(2)
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
@ -154,15 +48,13 @@ class WindowSizeMessage(Message):
self.vpix = vpix
def pack(self) -> bytes:
raw = umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw)
return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class ExecuteCommandMesssage(Message):
class ExecuteCommandMesssage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
@ -181,20 +73,20 @@ class ExecuteCommandMesssage(Message):
self.vpix = vpix
def pack(self) -> bytes:
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw)
return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class StreamDataMessage(Message):
class StreamDataMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(4)
STREAM_ID_STDIN = 0
STREAM_ID_STDOUT = 1
STREAM_ID_STDERR = 2
OVERHEAD = 0
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
super().__init__()
@ -203,15 +95,19 @@ class StreamDataMessage(Message):
self.eof = eof
def pack(self) -> bytes:
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
return self.wrap_MSGTYPE(raw)
return umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
class VersionInfoMessage(Message):
_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8")
StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \
- len(_link_sized_bytes)
_link_sized_bytes = None
class VersionInfoMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(5)
def __init__(self, sw_version: str = None):
@ -220,15 +116,13 @@ class VersionInfoMessage(Message):
self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes:
raw = umsgpack.packb((self.sw_version, self.protocol_version))
return self.wrap_MSGTYPE(raw)
return umsgpack.packb((self.sw_version, self.protocol_version))
def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
class ErrorMessage(Message):
class ErrorMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(6)
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
@ -238,15 +132,13 @@ class ErrorMessage(Message):
self.data = data
def pack(self) -> bytes:
raw = umsgpack.packb((self.msg, self.fatal, self.data))
return self.wrap_MSGTYPE(raw)
return umsgpack.packb((self.msg, self.fatal, self.data))
def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class CommandExitedMessage(Message):
class CommandExitedMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(7)
def __init__(self, return_code: int = None):
@ -254,114 +146,16 @@ class CommandExitedMessage(Message):
self.return_code = return_code
def pack(self) -> bytes:
raw = umsgpack.packb(self.return_code)
return self.wrap_MSGTYPE(raw)
return umsgpack.packb(self.return_code)
def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.return_code = umsgpack.unpackb(raw)
class Messenger(contextlib.AbstractContextManager):
message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage,
CommandExitedMessage, ErrorMessage]
@staticmethod
def _get_msg_constructors() -> (int, Type[Message]):
subclass_tuples = []
for subclass in Message.__subclasses__():
subclass_tuples.append((subclass.MSGTYPE, subclass))
return subclass_tuples
def __init__(self, retry_delay_min: float = 10.0):
self._log = module_logger.getChild(self.__class__.__name__)
self._sent_messages: list[Message] = []
self._lock = threading.RLock()
self._retry_timer = rnsh.retry.RetryThread()
self._message_factories = dict(self.__class__._get_msg_constructors())
self._retry_delay_min = retry_delay_min
def __enter__(self) -> Messenger:
return self
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None:
self.shutdown()
return False
def shutdown(self):
self._retry_timer.close()
def clear_retries(self, outlet):
self._retry_timer.complete(outlet)
def receive(self, raw: bytes) -> Message:
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
ctor = self._message_factories.get(mid, None)
if ctor is None:
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
message = ctor()
message.unpack(raw)
self._log.debug(f"Message received: {message}")
return message
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
if not outlet.is_usuable:
self._log.debug("is_outlet_ready outlet unusable")
return False
with self._lock:
for message in self._sent_messages:
if message.outlet == outlet and message.tracked and message.receipt \
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
self._log.debug("is_outlet_ready pending message found")
return False
return True
def send(self, outlet: MessageOutletBase, message: Message):
with self._lock:
if not self.is_outlet_ready(outlet):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
if message in self._sent_messages:
raise MessagingException(METype.ME_ALREADY_SENT)
self._sent_messages.append(message)
message.tracked = True
if not message.raw:
message.raw = message.pack()
message.outlet = outlet
def send_inner(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
try:
if message.receipt:
self._log.debug(f"Resending packet for {message}")
message.receipt = outlet.resend(message.receipt)
else:
self._log.debug(f"Sending packet for {message}")
message.receipt = outlet.send(message.raw)
except Exception as ex:
self._log.exception(f"Error sending message {message}")
elif state in [MessageState.MSGSTATE_SENT]:
self._log.debug(f"Retry skipped, message still pending {message}")
elif state in [MessageState.MSGSTATE_DELIVERED]:
latency = round(time.time() - message.ts, 1)
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
self._retry_timer.complete(outlet)
return outlet
def timeout(tag: any, tries: int):
latency = round(time.time() - message.ts, 1)
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
outlet.timed_out()
rtt = outlet.rtt
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
def register_message_types(channel: RNS.Channel.Channel):
for message_type in message_types:
channel.register_message_type(message_type)

View File

@ -102,53 +102,58 @@ def print_identity(configdir, identitypath, service_name, include_destination: b
exit(0)
verbose_set = False
async def _rnsh_cli_main():
#with contextlib.suppress(KeyboardInterrupt, SystemExit):
import docopt
log = _get_logger("main")
_loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop)
args = rnsh.args.Args(sys.argv)
global verbose_set
log = _get_logger("main")
_loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop)
args = rnsh.args.Args(sys.argv)
verbose_set = args.verbose > 0
if args.print_identity:
print_identity(args.config, args.identity, args.service_name, args.listen)
return 0
if args.print_identity:
print_identity(args.config, args.identity, args.service_name, args.listen)
return 0
if args.listen:
# log.info("command " + args.command)
await listener.listen(configdir=args.config,
command=args.command_line,
identitypath=args.identity,
service_name=args.service_name,
verbosity=args.verbose,
quietness=args.quiet,
allowed=args.allowed,
disable_auth=args.no_auth,
announce_period=args.announce,
no_remote_command=args.no_remote_cmd,
remote_cmd_as_args=args.remote_cmd_as_args)
return 0
if args.listen:
# log.info("command " + args.command)
await listener.listen(configdir=args.config,
command=args.command_line,
identitypath=args.identity,
service_name=args.service_name,
verbosity=args.verbose,
quietness=args.quiet,
allowed=args.allowed,
disable_auth=args.no_auth,
announce_period=args.announce,
no_remote_command=args.no_remote_cmd,
remote_cmd_as_args=args.remote_cmd_as_args)
return 0
if args.destination is not None:
return_code = await initiator.initiate(configdir=args.config,
identitypath=args.identity,
verbosity=args.verbose,
quietness=args.quiet,
noid=args.no_id,
destination=args.destination,
timeout=args.timeout,
command=args.command_line
)
return return_code if args.mirror else 0
else:
print("")
print(rnsh.args.usage)
print("")
return 1
if args.destination is not None:
return_code = await initiator.initiate(configdir=args.config,
identitypath=args.identity,
verbosity=args.verbose,
quietness=args.quiet,
noid=args.no_id,
destination=args.destination,
timeout=args.timeout,
command=args.command_line
)
return return_code if args.mirror else 0
else:
print("")
print(rnsh.args.usage)
print("")
return 1
def rnsh_cli():
global verbose_set
return_code = 1
exc = None
try:
return_code = asyncio.run(_rnsh_cli_main())
except SystemExit:
@ -156,8 +161,11 @@ def rnsh_cli():
except KeyboardInterrupt:
pass
except Exception as ex:
print(f"Unhandled exception: {ex}")
print(f"Unhandled exception: {ex}")
exc = ex
process.tty_unset_reader_callbacks(0)
if verbose_set and exc:
raise exc
sys.exit(return_code if return_code is not None else 255)

View File

@ -16,8 +16,6 @@ import RNS
import logging as __logging
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
module_logger = __logging.getLogger(__name__)
_TLink = TypeVar("_TLink")
@ -44,7 +42,7 @@ class LSState(enum.IntEnum):
_TIdentity = TypeVar("_TIdentity")
class LSOutletBase(protocol.MessageOutletBase):
class LSOutletBase(ABC):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
raise NotImplemented()
@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase):
def unset_link_closed_callback(self):
raise NotImplemented()
@property
@abstractmethod
def teardown(self):
def rtt(self):
raise NotImplemented()
@abstractmethod
def __init__(self):
def teardown(self):
raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
allowed_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._log.info(f"Session started for {outlet}")
self.outlet = outlet
self.channel = channel
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
@ -106,7 +105,8 @@ class ListenerSession:
else:
self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
self.outlet.set_packet_received_callback(self._packet_received)
protocol.register_message_types(self.channel)
self.channel.add_message_handler(self._handle_message)
def _terminated(self, return_code: int):
self.return_code = return_code
@ -128,8 +128,8 @@ class ListenerSession:
self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: protocol.Message):
self.messenger.send(self.outlet, message)
def send(self, message: RNS.MessageBase):
self.channel.send(message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
@ -171,7 +171,6 @@ class ListenerSession:
return
self._log.debug(f"link_closed {outlet}")
self.messenger.clear_retries(self.outlet)
self.terminate()
def _initiator_identified(self, outlet, identity):
@ -208,10 +207,10 @@ class ListenerSession:
try:
if self.state != LSState.LSSTATE_RUNNING:
return False
elif not self.messenger.is_outlet_ready(self.outlet):
elif not self.channel.is_ready_to_send():
return False
elif len(self.stderr_buf) > 0:
mdu = self.outlet.mdu - 16
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
@ -223,7 +222,7 @@ class ListenerSession:
self.stderr_eof_sent = True
return True
elif len(self.stdout_buf) > 0:
mdu = self.outlet.mdu - 16
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
@ -309,7 +308,7 @@ class ListenerSession:
if eof:
self.process.close_stdin()
def _handle_message(self, message: protocol.Message):
def _handle_message(self, message: RNS.MessageBase):
if self.state == LSState.LSSTATE_WAIT_IDENT:
self._protocol_error("Identification required")
return
@ -352,17 +351,6 @@ class ListenerSession:
self._protocol_error("unexpected message")
return
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
if outlet != self.outlet:
self._log.debug("Packet received from incorrect outlet")
return
try:
message = self.messenger.receive(raw)
self._handle_message(message)
except Exception as ex:
self._protocol_error(f"error receiving packet: {ex}")
class RNSOutlet(LSOutletBase):
@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase):
def teardown(self):
self.link.teardown()
def send(self, raw: bytes) -> RNS.Packet:
packet = RNS.Packet(self.link, raw)
packet.send()
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
packet.resend()
return packet
@property
def mdu(self) -> int:
return self.link.MDU
@property
def rtt(self) -> float:
return self.link.rtt
@property
def is_usuable(self):
return True #self.link.status in [RNS.Link.ACTIVE]
def get_receipt_state(self, packet: RNS.Packet) -> MessageState:
status = packet.receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return protocol.MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return protocol.MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return protocol.MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self):
return f"Outlet RNS Link {self.link}"
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
def inner_cb(message, packet: RNS.Packet):
packet.prove()
cb(self, message)
self.link.set_packet_callback(inner_cb)
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
link.msgoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):