mirror of
https://github.com/markqvist/Reticulum.git
synced 2025-07-24 07:21:07 -04:00
Initial work on Channel
This commit is contained in:
parent
b381a61be8
commit
68cb4a6740
5 changed files with 759 additions and 1 deletions
316
tests/channel.py
Normal file
316
tests/channel.py
Normal file
|
@ -0,0 +1,316 @@
|
|||
from __future__ import annotations
|
||||
import threading
|
||||
import RNS
|
||||
from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase
|
||||
from RNS.vendor import umsgpack
|
||||
from typing import Callable
|
||||
import contextlib
|
||||
import typing
|
||||
import types
|
||||
import time
|
||||
import uuid
|
||||
import unittest
|
||||
|
||||
|
||||
class Packet:
|
||||
timeout = 1.0
|
||||
|
||||
def __init__(self, raw: bytes):
|
||||
self.state = MessageState.MSGSTATE_NEW
|
||||
self.raw = raw
|
||||
self.packet_id = uuid.uuid4()
|
||||
self.tries = 0
|
||||
self.timeout_id = None
|
||||
self.lock = threading.RLock()
|
||||
self.instances = 0
|
||||
self.timeout_callback: Callable[[Packet], None] | None = None
|
||||
self.delivered_callback: Callable[[Packet], None] | None = None
|
||||
|
||||
def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float):
|
||||
with self.lock:
|
||||
if timeout is not None:
|
||||
self.timeout = timeout
|
||||
self.timeout_callback = callback
|
||||
|
||||
|
||||
def send(self):
|
||||
self.tries += 1
|
||||
self.state = MessageState.MSGSTATE_SENT
|
||||
|
||||
def elapsed(timeout: float, timeout_id: uuid.uuid4):
|
||||
with self.lock:
|
||||
self.instances += 1
|
||||
try:
|
||||
time.sleep(timeout)
|
||||
with self.lock:
|
||||
if self.timeout_id == timeout_id:
|
||||
self.timeout_id = None
|
||||
self.state = MessageState.MSGSTATE_FAILED
|
||||
if self.timeout_callback:
|
||||
self.timeout_callback(self)
|
||||
finally:
|
||||
with self.lock:
|
||||
self.instances -= 1
|
||||
|
||||
self.timeout_id = uuid.uuid4()
|
||||
threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id],
|
||||
daemon=True).start()
|
||||
|
||||
def clear_timeout(self):
|
||||
self.timeout_id = None
|
||||
|
||||
def set_delivered_callback(self, callback: Callable[[Packet], None]):
|
||||
self.delivered_callback = callback
|
||||
|
||||
def delivered(self):
|
||||
with self.lock:
|
||||
self.state = MessageState.MSGSTATE_DELIVERED
|
||||
self.timeout_id = None
|
||||
if self.delivered_callback:
|
||||
self.delivered_callback(self)
|
||||
|
||||
|
||||
class ChannelOutletTest(ChannelOutletBase):
|
||||
def get_packet_state(self, packet: Packet) -> MessageState:
|
||||
return packet.state
|
||||
|
||||
def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None,
|
||||
timeout: float | None = None):
|
||||
packet.set_timeout(callback, timeout)
|
||||
|
||||
def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None):
|
||||
packet.set_delivered_callback(callback)
|
||||
|
||||
def get_packet_id(self, packet: Packet) -> any:
|
||||
return packet.packet_id
|
||||
|
||||
def __init__(self, mdu: int, rtt: float):
|
||||
self.link_id = uuid.uuid4()
|
||||
self.timeout_callbacks = 0
|
||||
self._mdu = mdu
|
||||
self._rtt = rtt
|
||||
self._usable = True
|
||||
self.packets = []
|
||||
self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None
|
||||
|
||||
def send(self, raw: bytes) -> Packet:
|
||||
packet = Packet(raw)
|
||||
packet.send()
|
||||
self.packets.append(packet)
|
||||
return packet
|
||||
|
||||
def resend(self, packet: Packet) -> Packet:
|
||||
packet.send()
|
||||
return packet
|
||||
|
||||
@property
|
||||
def mdu(self):
|
||||
return self._mdu
|
||||
|
||||
@property
|
||||
def rtt(self):
|
||||
return self._rtt
|
||||
|
||||
@property
|
||||
def is_usable(self):
|
||||
return self._usable
|
||||
|
||||
def timed_out(self):
|
||||
self.timeout_callbacks += 1
|
||||
|
||||
def __str__(self):
|
||||
return str(self.link_id)
|
||||
|
||||
|
||||
class MessageTest(MessageBase):
|
||||
MSGTYPE = 0xabcd
|
||||
|
||||
def __init__(self):
|
||||
self.id = str(uuid.uuid4())
|
||||
self.data = "test"
|
||||
self.not_serialized = str(uuid.uuid4())
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return umsgpack.packb((self.id, self.data))
|
||||
|
||||
def unpack(self, raw):
|
||||
self.id, self.data = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class ProtocolHarness(contextlib.AbstractContextManager):
|
||||
def __init__(self, rtt: float):
|
||||
self.outlet = ChannelOutletTest(mdu=500, rtt=rtt)
|
||||
self.channel = Channel(self.outlet)
|
||||
|
||||
def cleanup(self):
|
||||
self.channel.shutdown()
|
||||
|
||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||
__traceback: types.TracebackType) -> bool:
|
||||
# self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})")
|
||||
self.cleanup()
|
||||
return False
|
||||
|
||||
|
||||
class TestChannel(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.rtt = 0.001
|
||||
self.retry_interval = self.rtt * 150
|
||||
Packet.timeout = self.retry_interval
|
||||
self.h = ProtocolHarness(self.rtt)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.h.cleanup()
|
||||
|
||||
def test_send_one_retry(self):
|
||||
message = MessageTest()
|
||||
|
||||
self.assertEqual(0, len(self.h.outlet.packets))
|
||||
|
||||
envelope = self.h.channel.send(message)
|
||||
|
||||
self.assertIsNotNone(envelope)
|
||||
self.assertIsNotNone(envelope.raw)
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertIsNotNone(envelope.packet)
|
||||
self.assertTrue(envelope in self.h.channel._tx_ring)
|
||||
self.assertTrue(envelope.tracked)
|
||||
|
||||
packet = self.h.outlet.packets[0]
|
||||
|
||||
self.assertEqual(envelope.packet, packet)
|
||||
self.assertEqual(1, envelope.tries)
|
||||
self.assertEqual(1, packet.tries)
|
||||
self.assertEqual(1, packet.instances)
|
||||
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
|
||||
self.assertEqual(envelope.raw, packet.raw)
|
||||
|
||||
time.sleep(self.retry_interval * 1.5)
|
||||
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertEqual(2, envelope.tries)
|
||||
self.assertEqual(2, packet.tries)
|
||||
self.assertEqual(1, packet.instances)
|
||||
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertEqual(self.h.outlet.packets[0], packet)
|
||||
self.assertEqual(3, envelope.tries)
|
||||
self.assertEqual(3, packet.tries)
|
||||
self.assertEqual(1, packet.instances)
|
||||
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
|
||||
|
||||
packet.delivered()
|
||||
|
||||
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
|
||||
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertEqual(3, envelope.tries)
|
||||
self.assertEqual(3, packet.tries)
|
||||
self.assertEqual(0, packet.instances)
|
||||
self.assertFalse(envelope.tracked)
|
||||
|
||||
def test_send_timeout(self):
|
||||
message = MessageTest()
|
||||
|
||||
self.assertEqual(0, len(self.h.outlet.packets))
|
||||
|
||||
envelope = self.h.channel.send(message)
|
||||
|
||||
self.assertIsNotNone(envelope)
|
||||
self.assertIsNotNone(envelope.raw)
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertIsNotNone(envelope.packet)
|
||||
self.assertTrue(envelope in self.h.channel._tx_ring)
|
||||
self.assertTrue(envelope.tracked)
|
||||
|
||||
packet = self.h.outlet.packets[0]
|
||||
|
||||
self.assertEqual(envelope.packet, packet)
|
||||
self.assertEqual(1, envelope.tries)
|
||||
self.assertEqual(1, packet.tries)
|
||||
self.assertEqual(1, packet.instances)
|
||||
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
|
||||
self.assertEqual(envelope.raw, packet.raw)
|
||||
|
||||
time.sleep(self.retry_interval * 7.5)
|
||||
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertEqual(5, envelope.tries)
|
||||
self.assertEqual(5, packet.tries)
|
||||
self.assertEqual(0, packet.instances)
|
||||
self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
|
||||
self.assertFalse(envelope.tracked)
|
||||
|
||||
def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):
|
||||
decoded: [MessageBase] = []
|
||||
|
||||
def handle_message(message: MessageBase):
|
||||
decoded.append(message)
|
||||
|
||||
self.h.channel.set_message_callback(handle_message)
|
||||
self.assertEqual(len(self.h.outlet.packets), 0)
|
||||
|
||||
envelope = self.h.channel.send(message)
|
||||
time.sleep(self.retry_interval * 0.5)
|
||||
|
||||
self.assertIsNotNone(envelope)
|
||||
self.assertIsNotNone(envelope.raw)
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertIsNotNone(envelope.packet)
|
||||
self.assertTrue(envelope in self.h.channel._tx_ring)
|
||||
self.assertTrue(envelope.tracked)
|
||||
|
||||
packet = self.h.outlet.packets[0]
|
||||
|
||||
self.assertEqual(envelope.packet, packet)
|
||||
self.assertEqual(1, envelope.tries)
|
||||
self.assertEqual(1, packet.tries)
|
||||
self.assertEqual(1, packet.instances)
|
||||
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
|
||||
self.assertEqual(envelope.raw, packet.raw)
|
||||
|
||||
packet.delivered()
|
||||
|
||||
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
|
||||
|
||||
time.sleep(self.retry_interval * 2)
|
||||
|
||||
self.assertEqual(1, len(self.h.outlet.packets))
|
||||
self.assertEqual(1, envelope.tries)
|
||||
self.assertEqual(1, packet.tries)
|
||||
self.assertEqual(0, packet.instances)
|
||||
self.assertFalse(envelope.tracked)
|
||||
|
||||
self.assertEqual(len(self.h.outlet.packets), 1)
|
||||
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
|
||||
self.assertFalse(envelope.tracked)
|
||||
self.assertEqual(0, len(decoded))
|
||||
|
||||
self.h.channel.receive(packet.raw)
|
||||
|
||||
self.assertEqual(1, len(decoded))
|
||||
|
||||
rx_message = decoded[0]
|
||||
|
||||
self.assertIsNotNone(rx_message)
|
||||
self.assertIsInstance(rx_message, message.__class__)
|
||||
checker(rx_message)
|
||||
|
||||
def test_send_receive_message_test(self):
|
||||
message = MessageTest()
|
||||
|
||||
def check(rx_message: MessageBase):
|
||||
self.assertIsInstance(rx_message, message.__class__)
|
||||
self.assertEqual(message.id, rx_message.id)
|
||||
self.assertEqual(message.data, rx_message.data)
|
||||
self.assertNotEqual(message.not_serialized, rx_message.not_serialized)
|
||||
|
||||
self.eat_own_dog_food(message, check)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
|
@ -6,6 +6,9 @@ import threading
|
|||
import time
|
||||
import RNS
|
||||
import os
|
||||
from tests.channel import MessageTest
|
||||
from RNS.Channel import MessageBase
|
||||
|
||||
|
||||
APP_NAME = "rns_unit_tests"
|
||||
|
||||
|
@ -46,6 +49,11 @@ def close_rns():
|
|||
global c_rns
|
||||
if c_rns != None:
|
||||
c_rns.m_proc.kill()
|
||||
# stdout, stderr = c_rns.m_proc.communicate()
|
||||
# if stdout:
|
||||
# print(stdout.decode("utf-8"))
|
||||
# if stderr:
|
||||
# print(stderr.decode("utf-8"))
|
||||
|
||||
class TestLink(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
@ -346,6 +354,52 @@ class TestLink(unittest.TestCase):
|
|||
time.sleep(0.5)
|
||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||
|
||||
def test_10_channel_round_trip(self):
|
||||
global c_rns
|
||||
init_rns(self)
|
||||
print("")
|
||||
print("Channel round trip test")
|
||||
|
||||
# TODO: Load this from public bytes only
|
||||
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
||||
self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1]))
|
||||
|
||||
dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish")
|
||||
|
||||
self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d"))
|
||||
|
||||
l1 = RNS.Link(dest)
|
||||
time.sleep(1)
|
||||
self.assertEqual(l1.status, RNS.Link.ACTIVE)
|
||||
|
||||
received = []
|
||||
|
||||
def handle_message(message: MessageBase):
|
||||
received.append(message)
|
||||
|
||||
test_message = MessageTest()
|
||||
test_message.data = "Hello"
|
||||
|
||||
l1.set_message_callback(handle_message)
|
||||
l1.send_message(test_message)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
self.assertEqual(1, len(received))
|
||||
|
||||
rx_message = received[0]
|
||||
|
||||
self.assertIsInstance(rx_message, MessageTest)
|
||||
self.assertEqual("Hello back", rx_message.data)
|
||||
self.assertEqual(test_message.id, rx_message.id)
|
||||
self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized)
|
||||
self.assertEqual(1, len(l1._channel._rx_ring))
|
||||
|
||||
l1.teardown()
|
||||
time.sleep(0.5)
|
||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||
self.assertEqual(0, len(l1._channel._rx_ring))
|
||||
|
||||
|
||||
def size_str(self, num, suffix='B'):
|
||||
units = ['','K','M','G','T','P','E','Z']
|
||||
|
@ -405,6 +459,11 @@ def targets(yp=False):
|
|||
link.set_resource_started_callback(resource_started)
|
||||
link.set_resource_concluded_callback(resource_concluded)
|
||||
|
||||
def handle_message(message):
|
||||
message.data = message.data + " back"
|
||||
link.send_message(message)
|
||||
link.set_message_callback(handle_message, [MessageTest])
|
||||
|
||||
m_rns = RNS.Reticulum("./tests/rnsconfig")
|
||||
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
||||
d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue