mirror of
https://github.com/markqvist/Reticulum.git
synced 2024-12-24 23:19:36 -05:00
Expose Channel on Link
Separates channel interface from link Also added: allow multiple message handlers
This commit is contained in:
parent
68cb4a6740
commit
fe3a3e22f7
@ -4,22 +4,22 @@ import enum
|
||||
import threading
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import Type, Callable, TypeVar
|
||||
from typing import Type, Callable, TypeVar, Generic, NewType
|
||||
import abc
|
||||
import contextlib
|
||||
import struct
|
||||
import RNS
|
||||
from abc import ABC, abstractmethod
|
||||
_TPacket = TypeVar("_TPacket")
|
||||
TPacket = TypeVar("TPacket")
|
||||
|
||||
|
||||
class ChannelOutletBase(ABC):
|
||||
class ChannelOutletBase(ABC, Generic[TPacket]):
|
||||
@abstractmethod
|
||||
def send(self, raw: bytes) -> _TPacket:
|
||||
def send(self, raw: bytes) -> TPacket:
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def resend(self, packet: _TPacket) -> _TPacket:
|
||||
def resend(self, packet: TPacket) -> TPacket:
|
||||
raise NotImplemented()
|
||||
|
||||
@property
|
||||
@ -38,7 +38,7 @@ class ChannelOutletBase(ABC):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def get_packet_state(self, packet: _TPacket) -> MessageState:
|
||||
def get_packet_state(self, packet: TPacket) -> MessageState:
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
@ -50,16 +50,16 @@ class ChannelOutletBase(ABC):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def set_packet_timeout_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None,
|
||||
def set_packet_timeout_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None,
|
||||
timeout: float | None = None):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def set_packet_delivered_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None):
|
||||
def set_packet_delivered_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None):
|
||||
raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def get_packet_id(self, packet: _TPacket) -> any:
|
||||
def get_packet_id(self, packet: TPacket) -> any:
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
@ -97,6 +97,9 @@ class MessageBase(abc.ABC):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], bool])
|
||||
|
||||
|
||||
class Envelope:
|
||||
def unpack(self, message_factories: dict[int, Type]) -> MessageBase:
|
||||
msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6])
|
||||
@ -120,7 +123,7 @@ class Envelope:
|
||||
self.id = id(self)
|
||||
self.message = message
|
||||
self.raw = raw
|
||||
self.packet: _TPacket = None
|
||||
self.packet: TPacket = None
|
||||
self.sequence = sequence
|
||||
self.outlet = outlet
|
||||
self.tries = 0
|
||||
@ -133,9 +136,9 @@ class Channel(contextlib.AbstractContextManager):
|
||||
self._lock = threading.RLock()
|
||||
self._tx_ring: collections.deque[Envelope] = collections.deque()
|
||||
self._rx_ring: collections.deque[Envelope] = collections.deque()
|
||||
self._message_callback: Callable[[MessageBase], None] | None = None
|
||||
self._message_callbacks: [MessageCallbackType] = []
|
||||
self._next_sequence = 0
|
||||
self._message_factories: dict[int, Type[MessageBase]] = self._get_msg_constructors()
|
||||
self._message_factories: dict[int, Type[MessageBase]] = {}
|
||||
self._max_tries = 5
|
||||
|
||||
def __enter__(self) -> Channel:
|
||||
@ -146,16 +149,6 @@ class Channel(contextlib.AbstractContextManager):
|
||||
self.shutdown()
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_msg_constructors() -> (int, Type[MessageBase]):
|
||||
subclass_tuples = []
|
||||
for subclass in MessageBase.__subclasses__():
|
||||
with contextlib.suppress(Exception):
|
||||
subclass() # verify constructor works with no arguments, needed for unpacking
|
||||
subclass_tuples.append((subclass.MSGTYPE, subclass))
|
||||
message_factories = dict(subclass_tuples)
|
||||
return message_factories
|
||||
|
||||
def register_message_type(self, message_class: Type[MessageBase]):
|
||||
if not issubclass(message_class, MessageBase):
|
||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.")
|
||||
@ -169,10 +162,15 @@ class Channel(contextlib.AbstractContextManager):
|
||||
|
||||
self._message_factories[message_class.MSGTYPE] = message_class
|
||||
|
||||
def set_message_callback(self, callback: Callable[[MessageBase], None]):
|
||||
self._message_callback = callback
|
||||
def add_message_callback(self, callback: MessageCallbackType):
|
||||
if callback not in self._message_callbacks:
|
||||
self._message_callbacks.append(callback)
|
||||
|
||||
def remove_message_callback(self, callback: MessageCallbackType):
|
||||
self._message_callbacks.remove(callback)
|
||||
|
||||
def shutdown(self):
|
||||
self._message_callbacks.clear()
|
||||
self.clear_rings()
|
||||
|
||||
def clear_rings(self):
|
||||
@ -218,9 +216,8 @@ class Channel(contextlib.AbstractContextManager):
|
||||
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG)
|
||||
return
|
||||
RNS.log(f"Message received: {message}", RNS.LOG_DEBUG)
|
||||
if self._message_callback:
|
||||
threading.Thread(target=self._message_callback, name="Message Callback", args=[message], daemon=True)\
|
||||
.start()
|
||||
for cb in self._message_callbacks:
|
||||
threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start()
|
||||
except Exception as ex:
|
||||
RNS.log(f"Channel: Error receiving data: {ex}")
|
||||
|
||||
@ -237,7 +234,7 @@ class Channel(contextlib.AbstractContextManager):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _packet_tx_op(self, packet: _TPacket, op: Callable[[_TPacket], bool]):
|
||||
def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]):
|
||||
with self._lock:
|
||||
envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet),
|
||||
self._tx_ring), None)
|
||||
@ -250,10 +247,10 @@ class Channel(contextlib.AbstractContextManager):
|
||||
if not envelope:
|
||||
RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME)
|
||||
|
||||
def _packet_delivered(self, packet: _TPacket):
|
||||
def _packet_delivered(self, packet: TPacket):
|
||||
self._packet_tx_op(packet, lambda env: True)
|
||||
|
||||
def _packet_timeout(self, packet: _TPacket):
|
||||
def _packet_timeout(self, packet: TPacket):
|
||||
def retry_envelope(envelope: Envelope) -> bool:
|
||||
if envelope.tries >= self._max_tries:
|
||||
RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR)
|
||||
@ -286,6 +283,10 @@ class Channel(contextlib.AbstractContextManager):
|
||||
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout)
|
||||
return envelope
|
||||
|
||||
@property
|
||||
def MDU(self):
|
||||
return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence)
|
||||
|
||||
|
||||
class LinkChannelOutlet(ChannelOutletBase):
|
||||
def __init__(self, link: RNS.Link):
|
||||
@ -313,7 +314,7 @@ class LinkChannelOutlet(ChannelOutletBase):
|
||||
def is_usable(self):
|
||||
return True # had issues looking at Link.status
|
||||
|
||||
def get_packet_state(self, packet: _TPacket) -> MessageState:
|
||||
def get_packet_state(self, packet: TPacket) -> MessageState:
|
||||
status = packet.receipt.get_status()
|
||||
if status == RNS.PacketReceipt.SENT:
|
||||
return MessageState.MSGSTATE_SENT
|
||||
|
18
RNS/Link.py
18
RNS/Link.py
@ -645,27 +645,11 @@ class Link:
|
||||
if pending_request.request_id == resource.request_id:
|
||||
pending_request.request_timed_out(None)
|
||||
|
||||
def _ensure_channel(self):
|
||||
def get_channel(self):
|
||||
if self._channel is None:
|
||||
self._channel = Channel(LinkChannelOutlet(self))
|
||||
return self._channel
|
||||
|
||||
def set_message_callback(self, callback, message_types=None):
|
||||
if not callback:
|
||||
if self._channel:
|
||||
self._channel.set_message_callback(None)
|
||||
return
|
||||
|
||||
self._ensure_channel()
|
||||
|
||||
if message_types:
|
||||
for msg_type in message_types:
|
||||
self._channel.register_message_type(msg_type)
|
||||
self._channel.set_message_callback(callback)
|
||||
|
||||
def send_message(self, message: RNS.Channel.MessageBase):
|
||||
self._ensure_channel().send(message)
|
||||
|
||||
def receive(self, packet):
|
||||
self.watchdog_lock = True
|
||||
if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])):
|
||||
|
@ -32,6 +32,7 @@ from ._version import __version__
|
||||
from .Reticulum import Reticulum
|
||||
from .Identity import Identity
|
||||
from .Link import Link, RequestReceipt
|
||||
from .Channel import MessageBase
|
||||
from .Transport import Transport
|
||||
from .Destination import Destination
|
||||
from .Packet import Packet
|
||||
|
@ -251,7 +251,7 @@ class TestChannel(unittest.TestCase):
|
||||
def handle_message(message: MessageBase):
|
||||
decoded.append(message)
|
||||
|
||||
self.h.channel.set_message_callback(handle_message)
|
||||
self.h.channel.add_message_callback(handle_message)
|
||||
self.assertEqual(len(self.h.outlet.packets), 0)
|
||||
|
||||
envelope = self.h.channel.send(message)
|
||||
|
@ -380,8 +380,10 @@ class TestLink(unittest.TestCase):
|
||||
test_message = MessageTest()
|
||||
test_message.data = "Hello"
|
||||
|
||||
l1.set_message_callback(handle_message)
|
||||
l1.send_message(test_message)
|
||||
channel = l1.get_channel()
|
||||
channel.register_message_type(MessageTest)
|
||||
channel.add_message_callback(handle_message)
|
||||
channel.send(test_message)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
@ -458,11 +460,13 @@ def targets(yp=False):
|
||||
link.set_resource_strategy(RNS.Link.ACCEPT_ALL)
|
||||
link.set_resource_started_callback(resource_started)
|
||||
link.set_resource_concluded_callback(resource_concluded)
|
||||
channel = link.get_channel()
|
||||
|
||||
def handle_message(message):
|
||||
message.data = message.data + " back"
|
||||
link.send_message(message)
|
||||
link.set_message_callback(handle_message, [MessageTest])
|
||||
channel.send(message)
|
||||
channel.register_message_type(MessageTest)
|
||||
channel.add_message_callback(handle_message)
|
||||
|
||||
m_rns = RNS.Reticulum("./tests/rnsconfig")
|
||||
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
||||
|
Loading…
Reference in New Issue
Block a user