Switch to RNS-provided Channel

This commit is contained in:
Aaron Heise 2023-02-28 08:48:29 -06:00
parent b6a22cd2a7
commit 5bca575a4b
No known key found for this signature in database
GPG key ID: 6BA54088C41DE8BF
6 changed files with 229 additions and 482 deletions

View file

@ -16,8 +16,6 @@ import RNS
import logging as __logging
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
module_logger = __logging.getLogger(__name__)
_TLink = TypeVar("_TLink")
@ -44,7 +42,7 @@ class LSState(enum.IntEnum):
_TIdentity = TypeVar("_TIdentity")
class LSOutletBase(protocol.MessageOutletBase):
class LSOutletBase(ABC):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
raise NotImplemented()
@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase):
def unset_link_closed_callback(self):
raise NotImplemented()
@property
@abstractmethod
def teardown(self):
def rtt(self):
raise NotImplemented()
@abstractmethod
def __init__(self):
def teardown(self):
raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
allowed_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._log.info(f"Session started for {outlet}")
self.outlet = outlet
self.channel = channel
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
@ -106,7 +105,8 @@ class ListenerSession:
else:
self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
self.outlet.set_packet_received_callback(self._packet_received)
protocol.register_message_types(self.channel)
self.channel.add_message_handler(self._handle_message)
def _terminated(self, return_code: int):
self.return_code = return_code
@ -128,8 +128,8 @@ class ListenerSession:
self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: protocol.Message):
self.messenger.send(self.outlet, message)
def send(self, message: RNS.MessageBase):
self.channel.send(message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
@ -171,7 +171,6 @@ class ListenerSession:
return
self._log.debug(f"link_closed {outlet}")
self.messenger.clear_retries(self.outlet)
self.terminate()
def _initiator_identified(self, outlet, identity):
@ -208,10 +207,10 @@ class ListenerSession:
try:
if self.state != LSState.LSSTATE_RUNNING:
return False
elif not self.messenger.is_outlet_ready(self.outlet):
elif not self.channel.is_ready_to_send():
return False
elif len(self.stderr_buf) > 0:
mdu = self.outlet.mdu - 16
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
@ -223,7 +222,7 @@ class ListenerSession:
self.stderr_eof_sent = True
return True
elif len(self.stdout_buf) > 0:
mdu = self.outlet.mdu - 16
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
@ -309,7 +308,7 @@ class ListenerSession:
if eof:
self.process.close_stdin()
def _handle_message(self, message: protocol.Message):
def _handle_message(self, message: RNS.MessageBase):
if self.state == LSState.LSSTATE_WAIT_IDENT:
self._protocol_error("Identification required")
return
@ -352,17 +351,6 @@ class ListenerSession:
self._protocol_error("unexpected message")
return
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
if outlet != self.outlet:
self._log.debug("Packet received from incorrect outlet")
return
try:
message = self.messenger.receive(raw)
self._handle_message(message)
except Exception as ex:
self._protocol_error(f"error receiving packet: {ex}")
class RNSOutlet(LSOutletBase):
@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase):
def teardown(self):
self.link.teardown()
def send(self, raw: bytes) -> RNS.Packet:
packet = RNS.Packet(self.link, raw)
packet.send()
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
packet.resend()
return packet
@property
def mdu(self) -> int:
return self.link.MDU
@property
def rtt(self) -> float:
return self.link.rtt
@property
def is_usuable(self):
return True #self.link.status in [RNS.Link.ACTIVE]
def get_receipt_state(self, packet: RNS.Packet) -> MessageState:
status = packet.receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return protocol.MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return protocol.MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return protocol.MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self):
return f"Outlet RNS Link {self.link}"
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
def inner_cb(message, packet: RNS.Packet):
packet.prove()
cb(self, message)
self.link.set_packet_callback(inner_cb)
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
link.msgoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):