Fixed callback invocation on channel receive

This commit is contained in:
Mark Qvist 2023-05-19 01:58:28 +02:00
parent 1a860c6ffd
commit d7375bc4c3

View File

@ -176,6 +176,9 @@ class Envelope:
raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}") raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}")
message = ctor() message = ctor()
message.unpack(raw) message.unpack(raw)
self.unpacked = True
self.message = message
return message return message
def pack(self) -> bytes: def pack(self) -> bytes:
@ -183,6 +186,7 @@ class Envelope:
raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE") raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE")
data = self.message.pack() data = self.message.pack()
self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data
self.packed = True
return self.raw return self.raw
def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None): def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None):
@ -194,6 +198,8 @@ class Envelope:
self.sequence = sequence self.sequence = sequence
self.outlet = outlet self.outlet = outlet
self.tries = 0 self.tries = 0
self.unpacked = False
self.packed = False
self.tracked = False self.tracked = False
@ -371,22 +377,29 @@ class Channel(contextlib.AbstractContextManager):
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
with self._lock: with self._lock:
i = 0 i = 0
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
for existing in ring: for existing in ring:
if existing.sequence > envelope.sequence \
and not existing.sequence // 2 > envelope.sequence: # account for overflow if envelope.sequence == existing.sequence:
ring.insert(i, envelope)
return True
if existing.sequence == envelope.sequence:
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME) RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
return False return False
if envelope.sequence < existing.sequence and not envelope.sequence < window_overflow:
ring.insert(i, envelope)
RNS.log("Inserted seq "+str(envelope.sequence)+" at "+str(i), RNS.LOG_DEBUG)
envelope.tracked = True
return True
i += 1 i += 1
envelope.tracked = True envelope.tracked = True
ring.append(envelope) ring.append(envelope)
return True return True
def _run_callbacks(self, message: MessageBase): def _run_callbacks(self, message: MessageBase):
with self._lock: cbs = self._message_callbacks.copy()
cbs = self._message_callbacks.copy()
for cb in cbs: for cb in cbs:
try: try:
@ -405,12 +418,11 @@ class Channel(contextlib.AbstractContextManager):
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
if window_overflow < self._next_rx_sequence: if window_overflow < self._next_rx_sequence:
if envelope.sequence > window_overflow: if envelope.sequence > window_overflow:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG) RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
return return
else: else:
if envelope.sequence < self._next_rx_sequence: RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG) return
return
is_new = self._emplace_envelope(envelope, self._rx_ring) is_new = self._emplace_envelope(envelope, self._rx_ring)
@ -426,9 +438,13 @@ class Channel(contextlib.AbstractContextManager):
self._next_rx_sequence = (self._next_rx_sequence + 1) % Channel.SEQ_MODULUS self._next_rx_sequence = (self._next_rx_sequence + 1) % Channel.SEQ_MODULUS
for e in contigous: for e in contigous:
m = e.unpack(self._message_factories) if not e.unpacked:
m = e.unpack(self._message_factories)
else:
m = e.message
self._rx_ring.remove(e) self._rx_ring.remove(e)
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[m], daemon=True).start() self._run_callbacks(m)
except Exception as e: except Exception as e:
RNS.log("An error ocurred while receiving data on "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR) RNS.log("An error ocurred while receiving data on "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR)
@ -469,7 +485,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_min += 1 self.window_min += 1
# TODO: Remove at some point # TODO: Remove at some point
RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG) # RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
if self._outlet.rtt != 0: if self._outlet.rtt != 0:
if self._outlet.rtt > Channel.RTT_FAST: if self._outlet.rtt > Channel.RTT_FAST:
@ -483,19 +499,17 @@ class Channel(contextlib.AbstractContextManager):
if self.window_max < Channel.WINDOW_MAX_MEDIUM and self.medium_rate_rounds == Channel.FAST_RATE_THRESHOLD: if self.window_max < Channel.WINDOW_MAX_MEDIUM and self.medium_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_MEDIUM self.window_max = Channel.WINDOW_MAX_MEDIUM
# TODO: Remove at some point # TODO: Remove at some point
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME) # RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else: else:
self.fast_rate_rounds += 1 self.fast_rate_rounds += 1
if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD: if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_FAST self.window_max = Channel.WINDOW_MAX_FAST
# TODO: Remove at some point # TODO: Remove at some point
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME) # RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else: else:
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_DEBUG) RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_EXTREME)
if not envelope: if not envelope:
RNS.log("Spurious message received on "+str(self), RNS.LOG_EXTREME) RNS.log("Spurious message received on "+str(self), RNS.LOG_EXTREME)
@ -525,7 +539,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_max -= 1 self.window_max -= 1
# TODO: Remove at some point # TODO: Remove at some point
RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME) # RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
return False return False
@ -543,16 +557,18 @@ class Channel(contextlib.AbstractContextManager):
with self._lock: with self._lock:
if not self.is_ready_to_send(): if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS
self._emplace_envelope(envelope, self._tx_ring) self._emplace_envelope(envelope, self._tx_ring)
if envelope is None: if envelope is None:
raise BlockingIOError() raise BlockingIOError()
envelope.pack() envelope.pack()
if len(envelope.raw) > self._outlet.mdu: if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}") raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
envelope.packet = self._outlet.send(envelope.raw) envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1 envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
@ -591,7 +607,6 @@ class LinkChannelOutlet(ChannelOutletBase):
return packet return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet: def resend(self, packet: RNS.Packet) -> RNS.Packet:
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
receipt = packet.resend() receipt = packet.resend()
if not receipt: if not receipt:
RNS.log("Failed to resend packet", RNS.LOG_ERROR) RNS.log("Failed to resend packet", RNS.LOG_ERROR)