mirror of
https://github.com/markqvist/rnsh.git
synced 2025-06-20 20:14:15 -04:00
Switch to RNS-provided Channel
This commit is contained in:
parent
b6a22cd2a7
commit
5bca575a4b
6 changed files with 229 additions and 482 deletions
|
@ -16,8 +16,6 @@ import RNS
|
|||
|
||||
import logging as __logging
|
||||
|
||||
from rnsh.protocol import MessageOutletBase, _TReceipt, MessageState
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
_TLink = TypeVar("_TLink")
|
||||
|
@ -44,7 +42,7 @@ class LSState(enum.IntEnum):
|
|||
_TIdentity = TypeVar("_TIdentity")
|
||||
|
||||
|
||||
class LSOutletBase(protocol.MessageOutletBase):
|
||||
class LSOutletBase(ABC):
|
||||
@abstractmethod
|
||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||
raise NotImplemented()
|
||||
|
@ -57,28 +55,29 @@ class LSOutletBase(protocol.MessageOutletBase):
|
|||
def unset_link_closed_callback(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def teardown(self):
|
||||
def rtt(self):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
def teardown(self):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class ListenerSession:
|
||||
sessions: List[ListenerSession] = []
|
||||
messenger: protocol.Messenger = protocol.Messenger(retry_delay_min=5)
|
||||
allowed_identity_hashes: [any] = []
|
||||
allow_all: bool = False
|
||||
allow_remote_command: bool = False
|
||||
default_command: [str] = []
|
||||
remote_cmd_as_args = False
|
||||
|
||||
def __init__(self, outlet: LSOutletBase, loop: asyncio.AbstractEventLoop):
|
||||
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._log.info(f"Session started for {outlet}")
|
||||
self.outlet = outlet
|
||||
self.channel = channel
|
||||
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||
self.outlet.set_link_closed_callback(self._link_closed)
|
||||
self.loop = loop
|
||||
|
@ -106,7 +105,8 @@ class ListenerSession:
|
|||
else:
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
self.sessions.append(self)
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
protocol.register_message_types(self.channel)
|
||||
self.channel.add_message_handler(self._handle_message)
|
||||
|
||||
def _terminated(self, return_code: int):
|
||||
self.return_code = return_code
|
||||
|
@ -128,8 +128,8 @@ class ListenerSession:
|
|||
self.loop.call_later(delay, func)
|
||||
self.loop.call_soon_threadsafe(call_inner)
|
||||
|
||||
def send(self, message: protocol.Message):
|
||||
self.messenger.send(self.outlet, message)
|
||||
def send(self, message: RNS.MessageBase):
|
||||
self.channel.send(message)
|
||||
|
||||
def _protocol_error(self, name: str):
|
||||
self.terminate(f"Protocol error ({name})")
|
||||
|
@ -171,7 +171,6 @@ class ListenerSession:
|
|||
return
|
||||
|
||||
self._log.debug(f"link_closed {outlet}")
|
||||
self.messenger.clear_retries(self.outlet)
|
||||
self.terminate()
|
||||
|
||||
def _initiator_identified(self, outlet, identity):
|
||||
|
@ -208,10 +207,10 @@ class ListenerSession:
|
|||
try:
|
||||
if self.state != LSState.LSSTATE_RUNNING:
|
||||
return False
|
||||
elif not self.messenger.is_outlet_ready(self.outlet):
|
||||
elif not self.channel.is_ready_to_send():
|
||||
return False
|
||||
elif len(self.stderr_buf) > 0:
|
||||
mdu = self.outlet.mdu - 16
|
||||
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
|
||||
data = self.stderr_buf[:mdu]
|
||||
self.stderr_buf = self.stderr_buf[mdu:]
|
||||
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||
|
@ -223,7 +222,7 @@ class ListenerSession:
|
|||
self.stderr_eof_sent = True
|
||||
return True
|
||||
elif len(self.stdout_buf) > 0:
|
||||
mdu = self.outlet.mdu - 16
|
||||
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD
|
||||
data = self.stdout_buf[:mdu]
|
||||
self.stdout_buf = self.stdout_buf[mdu:]
|
||||
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||
|
@ -309,7 +308,7 @@ class ListenerSession:
|
|||
if eof:
|
||||
self.process.close_stdin()
|
||||
|
||||
def _handle_message(self, message: protocol.Message):
|
||||
def _handle_message(self, message: RNS.MessageBase):
|
||||
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
||||
self._protocol_error("Identification required")
|
||||
return
|
||||
|
@ -352,17 +351,6 @@ class ListenerSession:
|
|||
self._protocol_error("unexpected message")
|
||||
return
|
||||
|
||||
def _packet_received(self, outlet: protocol.MessageOutletBase, raw: bytes):
|
||||
if outlet != self.outlet:
|
||||
self._log.debug("Packet received from incorrect outlet")
|
||||
return
|
||||
|
||||
try:
|
||||
message = self.messenger.receive(raw)
|
||||
self._handle_message(message)
|
||||
except Exception as ex:
|
||||
self._protocol_error(f"error receiving packet: {ex}")
|
||||
|
||||
|
||||
class RNSOutlet(LSOutletBase):
|
||||
|
||||
|
@ -384,55 +372,17 @@ class RNSOutlet(LSOutletBase):
|
|||
def teardown(self):
|
||||
self.link.teardown()
|
||||
|
||||
def send(self, raw: bytes) -> RNS.Packet:
|
||||
packet = RNS.Packet(self.link, raw)
|
||||
packet.send()
|
||||
return packet
|
||||
|
||||
def resend(self, packet: RNS.Packet) -> RNS.Packet:
|
||||
packet.resend()
|
||||
return packet
|
||||
|
||||
@property
|
||||
def mdu(self) -> int:
|
||||
return self.link.MDU
|
||||
|
||||
@property
|
||||
def rtt(self) -> float:
|
||||
return self.link.rtt
|
||||
|
||||
@property
|
||||
def is_usuable(self):
|
||||
return True #self.link.status in [RNS.Link.ACTIVE]
|
||||
|
||||
def get_receipt_state(self, packet: RNS.Packet) -> MessageState:
|
||||
status = packet.receipt.get_status()
|
||||
if status == RNS.PacketReceipt.SENT:
|
||||
return protocol.MessageState.MSGSTATE_SENT
|
||||
if status == RNS.PacketReceipt.DELIVERED:
|
||||
return protocol.MessageState.MSGSTATE_DELIVERED
|
||||
if status == RNS.PacketReceipt.FAILED:
|
||||
return protocol.MessageState.MSGSTATE_FAILED
|
||||
else:
|
||||
raise Exception(f"Unexpected receipt state: {status}")
|
||||
|
||||
def timed_out(self):
|
||||
self.link.teardown()
|
||||
|
||||
def __str__(self):
|
||||
return f"Outlet RNS Link {self.link}"
|
||||
|
||||
def set_packet_received_callback(self, cb: Callable[[MessageOutletBase, bytes], None]):
|
||||
def inner_cb(message, packet: RNS.Packet):
|
||||
packet.prove()
|
||||
cb(self, message)
|
||||
|
||||
self.link.set_packet_callback(inner_cb)
|
||||
|
||||
def __init__(self, link: RNS.Link):
|
||||
self.link = link
|
||||
link.lsoutlet = self
|
||||
link.msgoutlet = self
|
||||
|
||||
@staticmethod
|
||||
def get_outlet(link: RNS.Link):
|
||||
if hasattr(link, "lsoutlet"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue