diff --git a/rnsh/protocol.py b/rnsh/protocol.py index a833ad8..486b9e3 100644 --- a/rnsh/protocol.py +++ b/rnsh/protocol.py @@ -10,6 +10,7 @@ from types import TracebackType from typing import Type, Callable, TypeVar, Tuple import RNS from RNS.vendor import umsgpack +from RNS.Buffer import StreamDataMessage as RNSStreamDataMessage import rnsh.retry import abc import contextlib @@ -81,30 +82,12 @@ class ExecuteCommandMesssage(RNS.MessageBase): self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) -class StreamDataMessage(RNS.MessageBase): +# Create a version of RNS.Buffer.StreamDataMessage that we control +class StreamDataMessage(RNSStreamDataMessage): MSGTYPE = _make_MSGTYPE(4) STREAM_ID_STDIN = 0 STREAM_ID_STDOUT = 1 STREAM_ID_STDERR = 2 - OVERHEAD = 0 - - def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False): - super().__init__() - self.stream_id = stream_id - self.data = data - self.eof = eof - - def pack(self) -> bytes: - return umsgpack.packb((self.stream_id, self.eof, bytes(self.data))) - - def unpack(self, raw): - self.stream_id, self.eof, self.data = umsgpack.unpackb(raw) - - -_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8") -StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \ - - len(_link_sized_bytes) -_link_sized_bytes = None class VersionInfoMessage(RNS.MessageBase): diff --git a/rnsh/session.py b/rnsh/session.py index b79e155..1c948e7 100644 --- a/rnsh/session.py +++ b/rnsh/session.py @@ -210,7 +210,7 @@ class ListenerSession: elif not self.channel.is_ready_to_send(): return False elif len(self.stderr_buf) > 0: - mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD + mdu = protocol.StreamDataMessage.MAX_DATA_LEN data = self.stderr_buf[:mdu] self.stderr_buf = self.stderr_buf[mdu:] send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent @@ -222,7 +222,7 @@ class ListenerSession: self.stderr_eof_sent = True return True elif len(self.stdout_buf) > 0: - mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD + mdu = protocol.StreamDataMessage.MAX_DATA_LEN data = self.stdout_buf[:mdu] self.stdout_buf = self.stdout_buf[mdu:] send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent