Switch to RNS-provided Channel

This commit is contained in:
Aaron Heise 2023-02-28 08:48:29 -06:00
parent b6a22cd2a7
commit 5bca575a4b
No known key found for this signature in database
GPG key ID: 6BA54088C41DE8BF
6 changed files with 229 additions and 482 deletions

View file

@ -9,7 +9,7 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
docopt = "^0.6.2" docopt = "^0.6.2"
rns = "^0.4.9" rns = { git = "https://github.com/acehoss/Reticulum.git", branch = "feature/channel" } #{ path = "../Reticulum/", develop = true } #
tomli = "^2.0.1" tomli = "^2.0.1"
[tool.poetry.scripts] [tool.poetry.scripts]

View file

@ -143,12 +143,12 @@ class InitiatorState(enum.IntEnum):
def _client_link_closed(link): def _client_link_closed(link):
log = _get_logger("_client_link_closed") log = _get_logger("_client_link_closed")
if _finished:
_finished.set() _finished.set()
def _client_packet_handler(message, packet): def _client_message_handler(message: RNS.MessageBase):
log = _get_logger("_client_packet_handler") log = _get_logger("_client_message_handler")
packet.prove()
_pq.put(message) _pq.put(message)
@ -213,10 +213,8 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
_link.identify(_identity) _link.identify(_identity)
_link.did_identify = True _link.did_identify = True
_link.set_packet_callback(_client_packet_handler)
async def _handle_error(errmsg: RNS.MessageBase):
async def _handle_error(errmsg: protocol.Message):
if isinstance(errmsg, protocol.ErrorMessage): if isinstance(errmsg, protocol.ErrorMessage):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
if _link and _link.status == RNS.Link.ACTIVE: if _link and _link.status == RNS.Link.ACTIVE:
@ -249,17 +247,18 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
state = InitiatorState.IS_LINKED state = InitiatorState.IS_LINKED
outlet = session.RNSOutlet(_link) outlet = session.RNSOutlet(_link)
with protocol.Messenger(retry_delay_min=5) as messenger: channel = _link.get_channel()
protocol.register_message_types(channel)
channel.add_message_handler(_client_message_handler)
# Next step after linking and identifying: send version # Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5): # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# print("Error bringing up link") # print("Error bringing up link")
# return 253 # return 253
messenger.send(outlet, protocol.VersionInfoMessage()) channel.send(protocol.VersionInfoMessage())
try: try:
vp = _pq.get(timeout=max(outlet.rtt * 20, 5)) vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
vm = messenger.receive(vp)
await _handle_error(vm) await _handle_error(vm)
if not isinstance(vm, protocol.VersionInfoMessage): if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received") raise Exception("Invalid message received")
@ -307,7 +306,8 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
except: except:
pass pass
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command, await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1)
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0), pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1), pipe_stdout=not os.isatty(1),
pipe_stderr=not os.isatty(2), pipe_stderr=not os.isatty(2),
@ -330,8 +330,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]: while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
try: try:
try: try:
packet = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005) message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
message = messenger.receive(packet)
await _handle_error(message) await _handle_error(message)
processed = True processed = True
if isinstance(message, protocol.StreamDataMessage): if isinstance(message, protocol.StreamDataMessage):
@ -353,8 +352,6 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
os.close(2) os.close(2)
elif isinstance(message, protocol.CommandExitedMessage): elif isinstance(message, protocol.CommandExitedMessage):
log.debug(f"received return code {message.return_code}, exiting") log.debug(f"received return code {message.return_code}, exiting")
with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown()
return message.return_code return message.return_code
elif isinstance(message, protocol.ErrorMessage): elif isinstance(message, protocol.ErrorMessage):
log.error(message.data) log.error(message.data)
@ -365,13 +362,12 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
except queue.Empty: except queue.Empty:
processed = False processed = False
if messenger.is_outlet_ready(outlet): if channel.is_ready_to_send():
stdin = data_buffer[:mdu] stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:] data_buffer = data_buffer[mdu:]
eof = not sent_eof and stdin_eof and len(stdin) == 0 eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof: if len(stdin) > 0 or eof:
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof))
stdin, eof))
sent_eof = eof sent_eof = eof
processed = True processed = True
@ -381,7 +377,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
winch = False winch = False
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
r, c, h, v = process.tty_get_winsize(0) r, c, h, v = process.tty_get_winsize(0)
messenger.send(outlet, protocol.WindowSizeMessage(r, c, h, v)) channel.send(protocol.WindowSizeMessage(r, c, h, v))
processed = True processed = True
except RemoteExecutionError as e: except RemoteExecutionError as e:
print(e.msg) print(e.msg)

View file

@ -159,7 +159,7 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!") log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
def link_established(lnk: RNS.Link): def link_established(lnk: RNS.Link):
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), loop) session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
_destination.set_link_established_callback(link_established) _destination.set_link_established_callback(link_established)
_finished = asyncio.Event() _finished = asyncio.Event()
@ -188,7 +188,6 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo
log.warning("Shutting down") log.warning("Shutting down")
await session.ListenerSession.terminate_all("Shutting down") await session.ListenerSession.terminate_all("Shutting down")
await asyncio.sleep(1) await asyncio.sleep(1)
session.ListenerSession.messenger.shutdown()
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links)) links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active: for link in links_still_active:
if link.status not in [RNS.Link.CLOSED]: if link.status not in [RNS.Link.CLOSED]:

View file

@ -19,9 +19,6 @@ from abc import ABC, abstractmethod
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
_TReceipt = TypeVar("_TReceipt")
_TLink = TypeVar("_TLink")
MSG_MAGIC = 0xac MSG_MAGIC = 0xac
PROTOCOL_VERSION = 1 PROTOCOL_VERSION = 1
@ -30,120 +27,17 @@ def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff) return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class MessageOutletBase(ABC): class NoopMessage(RNS.MessageBase):
@abstractmethod
def send(self, raw: bytes) -> _TReceipt:
raise NotImplemented()
@abstractmethod
def resend(self, receipt: _TReceipt) -> _TReceipt:
raise NotImplemented()
@property
@abstractmethod
def mdu(self):
raise NotImplemented()
@property
@abstractmethod
def rtt(self):
raise NotImplemented()
@property
@abstractmethod
def is_usuable(self):
raise NotImplemented()
@abstractmethod
def get_receipt_state(self, receipt: _TReceipt) -> MessageState:
raise NotImplemented()
@abstractmethod
def timed_out(self):
raise NotImplemented()
@abstractmethod
def __str__(self):
raise NotImplemented()
@abstractmethod
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
raise NotImplemented()
class METype(enum.IntEnum):
ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1
ME_NOT_REGISTERED = 2
ME_LINK_NOT_READY = 3
ME_ALREADY_SENT = 4
class MessagingException(Exception):
def __init__(self, type: METype, *args):
super().__init__(args)
self.type = type
class MessageState(enum.IntEnum):
MSGSTATE_NEW = 0
MSGSTATE_SENT = 1
MSGSTATE_DELIVERED = 2
MSGSTATE_FAILED = 3
class Message(abc.ABC):
MSGTYPE = None
def __init__(self):
self.ts = time.time()
self.msgid = uuid.uuid4()
self.raw: bytes | None = None
self.receipt: _TReceipt = None
self.outlet: _TLink = None
self.tracked: bool = False
def __str__(self):
return f"{self.__class__.__name__} {self.msgid}"
@abstractmethod
def pack(self) -> bytes:
raise NotImplemented()
@abstractmethod
def unpack(self, raw):
raise NotImplemented()
def unwrap_MSGTYPE(self, raw: bytes) -> bytes:
if self.MSGTYPE is None:
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
mid, raw = self.static_unwrap_MSGTYPE(raw)
if mid != self.MSGTYPE:
raise MessagingException(METype.ME_INVALID_MSG_TYPE,
f"invalid msg id, expected {hex(self.MSGTYPE)} got {hex(mid)}")
return raw
def wrap_MSGTYPE(self, raw: bytes) -> bytes:
if self.__class__.MSGTYPE is None:
raise MessagingException(METype.ME_NO_MSG_TYPE, f"{self.__class__} lacks MSGTYPE")
return struct.pack(">H", self.MSGTYPE) + raw
@staticmethod
def static_unwrap_MSGTYPE(raw: bytes) -> (int, bytes):
return struct.unpack(">H", raw[:2])[0], raw[2:]
class NoopMessage(Message):
MSGTYPE = _make_MSGTYPE(0) MSGTYPE = _make_MSGTYPE(0)
def pack(self) -> bytes: def pack(self) -> bytes:
return self.wrap_MSGTYPE(bytes()) return bytes()
def unpack(self, raw): def unpack(self, raw):
self.unwrap_MSGTYPE(raw) pass
class WindowSizeMessage(Message): class WindowSizeMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(2) MSGTYPE = _make_MSGTYPE(2)
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None): def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
@ -154,15 +48,13 @@ class WindowSizeMessage(Message):
self.vpix = vpix self.vpix = vpix
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix)) return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class ExecuteCommandMesssage(Message): class ExecuteCommandMesssage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(3) MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False, def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
@ -181,20 +73,20 @@ class ExecuteCommandMesssage(Message):
self.vpix = vpix self.vpix = vpix
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix)) self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \ self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class StreamDataMessage(Message):
class StreamDataMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(4) MSGTYPE = _make_MSGTYPE(4)
STREAM_ID_STDIN = 0 STREAM_ID_STDIN = 0
STREAM_ID_STDOUT = 1 STREAM_ID_STDOUT = 1
STREAM_ID_STDERR = 2 STREAM_ID_STDERR = 2
OVERHEAD = 0
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False): def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
super().__init__() super().__init__()
@ -203,15 +95,19 @@ class StreamDataMessage(Message):
self.eof = eof self.eof = eof
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.stream_id, self.eof, bytes(self.data))) return umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw) self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
class VersionInfoMessage(Message): _link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8")
StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \
- len(_link_sized_bytes)
_link_sized_bytes = None
class VersionInfoMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(5) MSGTYPE = _make_MSGTYPE(5)
def __init__(self, sw_version: str = None): def __init__(self, sw_version: str = None):
@ -220,15 +116,13 @@ class VersionInfoMessage(Message):
self.protocol_version = PROTOCOL_VERSION self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.sw_version, self.protocol_version)) return umsgpack.packb((self.sw_version, self.protocol_version))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw): def unpack(self, raw):
raw = self.unwrap_MSGTYPE(raw)
self.sw_version, self.protocol_version = umsgpack.unpackb(raw) self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
class ErrorMessage(Message): class ErrorMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(6) MSGTYPE = _make_MSGTYPE(6)
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None): def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
@ -238,15 +132,13 @@ class ErrorMessage(Message):
self.data = data self.data = data
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb((self.msg, self.fatal, self.data)) return umsgpack.packb((self.msg, self.fatal, self.data))
return self.wrap_MSGTYPE(raw)
def unpack(self, raw: bytes): def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.msg, self.fatal, self.data = umsgpack.unpackb(raw) self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class CommandExitedMessage(Message): class CommandExitedMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(7) MSGTYPE = _make_MSGTYPE(7)
def __init__(self, return_code: int = None): def __init__(self, return_code: int = None):
@ -254,114 +146,16 @@ class CommandExitedMessage(Message):
self.return_code = return_code self.return_code = return_code
def pack(self) -> bytes: def pack(self) -> bytes:
raw = umsgpack.packb(self.return_code) return umsgpack.packb(self.return_code)
return self.wrap_MSGTYPE(raw)
def unpack(self, raw: bytes): def unpack(self, raw: bytes):
raw = self.unwrap_MSGTYPE(raw)
self.return_code = umsgpack.unpackb(raw) self.return_code = umsgpack.unpackb(raw)
class Messenger(contextlib.AbstractContextManager): message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage,
CommandExitedMessage, ErrorMessage]
@staticmethod
def _get_msg_constructors() -> (int, Type[Message]):
subclass_tuples = []
for subclass in Message.__subclasses__():
subclass_tuples.append((subclass.MSGTYPE, subclass))
return subclass_tuples
def __init__(self, retry_delay_min: float = 10.0):
self._log = module_logger.getChild(self.__class__.__name__)
self._sent_messages: list[Message] = []
self._lock = threading.RLock()
self._retry_timer = rnsh.retry.RetryThread()
self._message_factories = dict(self.__class__._get_msg_constructors())
self._retry_delay_min = retry_delay_min
def __enter__(self) -> Messenger:
return self
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None:
self.shutdown()
return False
def shutdown(self):
self._retry_timer.close()
def clear_retries(self, outlet):
self._retry_timer.complete(outlet)
def receive(self, raw: bytes) -> Message:
(mid, contents) = Message.static_unwrap_MSGTYPE(raw)
ctor = self._message_factories.get(mid, None)
if ctor is None:
raise MessagingException(METype.ME_NOT_REGISTERED, f"unable to find constructor for message type {hex(mid)}")
message = ctor()
message.unpack(raw)
self._log.debug(f"Message received: {message}")
return message
def is_outlet_ready(self, outlet: MessageOutletBase) -> bool:
if not outlet.is_usuable:
self._log.debug("is_outlet_ready outlet unusable")
return False
with self._lock:
for message in self._sent_messages:
if message.outlet == outlet and message.tracked and message.receipt \
and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_SENT:
self._log.debug("is_outlet_ready pending message found")
return False
return True
def send(self, outlet: MessageOutletBase, message: Message):
with self._lock:
if not self.is_outlet_ready(outlet):
raise MessagingException(METype.ME_LINK_NOT_READY, f"link {outlet} not ready")
if message in self._sent_messages:
raise MessagingException(METype.ME_ALREADY_SENT)
self._sent_messages.append(message)
message.tracked = True
if not message.raw:
message.raw = message.pack()
message.outlet = outlet
def send_inner(tag: any, tries: int):
state = MessageState.MSGSTATE_NEW if not message.receipt else outlet.get_receipt_state(message.receipt)
if state in [MessageState.MSGSTATE_NEW, MessageState.MSGSTATE_FAILED]:
try:
if message.receipt:
self._log.debug(f"Resending packet for {message}")
message.receipt = outlet.resend(message.receipt)
else:
self._log.debug(f"Sending packet for {message}")
message.receipt = outlet.send(message.raw)
except Exception as ex:
self._log.exception(f"Error sending message {message}")
elif state in [MessageState.MSGSTATE_SENT]:
self._log.debug(f"Retry skipped, message still pending {message}")
elif state in [MessageState.MSGSTATE_DELIVERED]:
latency = round(time.time() - message.ts, 1)
self._log.debug(f"{message} delivered {message.msgid} after {tries-1} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
self._retry_timer.complete(outlet)
return outlet
def timeout(tag: any, tries: int):
latency = round(time.time() - message.ts, 1)
msg = "delivered" if message.receipt and outlet.get_receipt_state(message.receipt) == MessageState.MSGSTATE_DELIVERED else "retry timeout"
self._log.debug(f"Message {msg} {message} after {tries} tries/{latency} seconds")
with self._lock:
self._sent_messages.remove(message)
message.tracked = False
outlet.timed_out()
rtt = outlet.rtt
self._retry_timer.begin(5, max(rtt * 5, self._retry_delay_min), send_inner, timeout)
def register_message_types(channel: RNS.Channel.Channel):
for message_type in message_types:
channel.register_message_type(message_type)

View file

@ -102,13 +102,16 @@ def print_identity(configdir, identitypath, service_name, include_destination: b
exit(0) exit(0)
verbose_set = False
async def _rnsh_cli_main(): async def _rnsh_cli_main():
#with contextlib.suppress(KeyboardInterrupt, SystemExit): global verbose_set
import docopt
log = _get_logger("main") log = _get_logger("main")
_loop = asyncio.get_running_loop() _loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop) rnslogging.set_main_loop(_loop)
args = rnsh.args.Args(sys.argv) args = rnsh.args.Args(sys.argv)
verbose_set = args.verbose > 0
if args.print_identity: if args.print_identity:
print_identity(args.config, args.identity, args.service_name, args.listen) print_identity(args.config, args.identity, args.service_name, args.listen)
@ -148,7 +151,9 @@ async def _rnsh_cli_main():
def rnsh_cli(): def rnsh_cli():
global verbose_set
return_code = 1 return_code = 1
exc = None
try: try:
return_code = asyncio.run(_rnsh_cli_main()) return_code = asyncio.run(_rnsh_cli_main())
except SystemExit: except SystemExit:
@ -157,7 +162,10 @@ def rnsh_cli():
pass pass
except Exception as ex: except Exception as ex:
print(f"Unhandled exception: {ex}") print(f"Unhandled exception: {ex}")
exc = ex
process.tty_unset_reader_callbacks(0) process.tty_unset_reader_callbacks(0)
if verbose_set and exc:
raise exc
sys.exit(return_code if return_code is not None else 255) sys.exit(return_code if return_code is not None else 255)

View file

@ -16,8 +16,6 @@ import RNS
import logging as __logging import logging as __logging
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
_TLink = TypeVar("_TLink") _TLink = TypeVar("_TLink")
@ -44,7 +42,7 @@ class LSState(enum.IntEnum):
_TIdentity = TypeVar("_TIdentity") _TIdentity = TypeVar("_TIdentity")
class LSOutletBase(protocol.MessageOutletBase): class LSOutletBase(ABC):
@abstractmethod @abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
raise NotImplemented() raise NotImplemented()
@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase):
def unset_link_closed_callback(self): def unset_link_closed_callback(self):
raise NotImplemented() raise NotImplemented()
@property
@abstractmethod @abstractmethod
def teardown(self): def rtt(self):
raise NotImplemented() raise NotImplemented()
@abstractmethod @abstractmethod
def __init__(self): def teardown(self):
raise NotImplemented() raise NotImplemented()
class ListenerSession: class ListenerSession:
sessions: List[ListenerSession] = [] sessions: List[ListenerSession] = []
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
allowed_identity_hashes: [any] = [] allowed_identity_hashes: [any] = []
allow_all: bool = False allow_all: bool = False
allow_remote_command: bool = False allow_remote_command: bool = False
default_command: [str] = [] default_command: [str] = []
remote_cmd_as_args = False remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop): def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self._log.info(f"Session started for {outlet}") self._log.info(f"Session started for {outlet}")
self.outlet = outlet self.outlet = outlet
self.channel = channel
self.outlet.set_initiator_identified_callback(self._initiator_identified) self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed) self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop self.loop = loop
@ -106,7 +105,8 @@ class ListenerSession:
else: else:
self._set_state(LSState.LSSTATE_WAIT_IDENT) self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self) self.sessions.append(self)
self.outlet.set_packet_received_callback(self._packet_received) protocol.register_message_types(self.channel)
self.channel.add_message_handler(self._handle_message)
def _terminated(self, return_code: int): def _terminated(self, return_code: int):
self.return_code = return_code self.return_code = return_code
@ -128,8 +128,8 @@ class ListenerSession:
self.loop.call_later(delay, func) self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner) self.loop.call_soon_threadsafe(call_inner)
def send(self, message: protocol.Message): def send(self, message: RNS.MessageBase):
self.messenger.send(self.outlet, message) self.channel.send(message)
def _protocol_error(self, name: str): def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})") self.terminate(f"Protocol error ({name})")
@ -171,7 +171,6 @@ class ListenerSession:
return return
self._log.debug(f"link_closed {outlet}") self._log.debug(f"link_closed {outlet}")
self.messenger.clear_retries(self.outlet)
self.terminate() self.terminate()
def _initiator_identified(self, outlet, identity): def _initiator_identified(self, outlet, identity):
@ -208,10 +207,10 @@ class ListenerSession:
try: try:
if self.state != LSState.LSSTATE_RUNNING: if self.state != LSState.LSSTATE_RUNNING:
return False return False
elif not self.messenger.is_outlet_ready(self.outlet): elif not self.channel.is_ready_to_send():
return False return False
elif len(self.stderr_buf) > 0: elif len(self.stderr_buf) > 0:
mdu = self.outlet.mdu - 16 mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
data = self.stderr_buf[:mdu] data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:] self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
@ -223,7 +222,7 @@ class ListenerSession:
self.stderr_eof_sent = True self.stderr_eof_sent = True
return True return True
elif len(self.stdout_buf) > 0: elif len(self.stdout_buf) > 0:
mdu = self.outlet.mdu - 16 mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
data = self.stdout_buf[:mdu] data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:] self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
@ -309,7 +308,7 @@ class ListenerSession:
if eof: if eof:
self.process.close_stdin() self.process.close_stdin()
def _handle_message(self, message: protocol.Message): def _handle_message(self, message: RNS.MessageBase):
if self.state == LSState.LSSTATE_WAIT_IDENT: if self.state == LSState.LSSTATE_WAIT_IDENT:
self._protocol_error("Identification required") self._protocol_error("Identification required")
return return
@ -352,17 +351,6 @@ class ListenerSession:
self._protocol_error("unexpected message") self._protocol_error("unexpected message")
return return
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
if outlet != self.outlet:
self._log.debug("Packet received from incorrect outlet")
return
try:
message = self.messenger.receive(raw)
self._handle_message(message)
except Exception as ex:
self._protocol_error(f"error receiving packet: {ex}")
class RNSOutlet(LSOutletBase): class RNSOutlet(LSOutletBase):
@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase):
def teardown(self): def teardown(self):
self.link.teardown() self.link.teardown()
def send(self, raw: bytes) -> RNS.Packet:
packet = RNS.Packet(self.link, raw)
packet.send()
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
packet.resend()
return packet
@property
def mdu(self) -> int:
return self.link.MDU
@property @property
def rtt(self) -> float: def rtt(self) -> float:
return self.link.rtt return self.link.rtt
@property
def is_usuable(self):
return True #self.link.status in [RNS.Link.ACTIVE]
def get_receipt_state(self, packet: RNS.Packet) -> MessageState:
status = packet.receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return protocol.MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return protocol.MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return protocol.MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self): def __str__(self):
return f"Outlet RNS Link {self.link}" return f"Outlet RNS Link {self.link}"
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
def inner_cb(message, packet: RNS.Packet):
packet.prove()
cb(self, message)
self.link.set_packet_callback(inner_cb)
def __init__(self, link: RNS.Link): def __init__(self, link: RNS.Link):
self.link = link self.link = link
link.lsoutlet = self link.lsoutlet = self
link.msgoutlet = self
@staticmethod @staticmethod
def get_outlet(link: RNS.Link): def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"): if hasattr(link, "lsoutlet"):