Use StreamDataMessage from RNS.Buffer

But don't use Buffer yet
This commit is contained in:
Aaron Heise 2023-03-08 16:54:02 -06:00
parent 5e97dc372e
commit 53332bf1ac
No known key found for this signature in database
GPG Key ID: 6BA54088C41DE8BF
2 changed files with 5 additions and 22 deletions

View File

@ -10,6 +10,7 @@ from types import TracebackType
from typing import Type, Callable, TypeVar, Tuple from typing import Type, Callable, TypeVar, Tuple
import RNS import RNS
from RNS.vendor import umsgpack from RNS.vendor import umsgpack
from RNS.Buffer import StreamDataMessage as RNSStreamDataMessage
import rnsh.retry import rnsh.retry
import abc import abc
import contextlib import contextlib
@ -81,30 +82,12 @@ class ExecuteCommandMesssage(RNS.MessageBase):
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class StreamDataMessage(RNS.MessageBase): # Create a version of RNS.Buffer.StreamDataMessage that we control
class StreamDataMessage(RNSStreamDataMessage):
MSGTYPE = _make_MSGTYPE(4) MSGTYPE = _make_MSGTYPE(4)
STREAM_ID_STDIN = 0 STREAM_ID_STDIN = 0
STREAM_ID_STDOUT = 1 STREAM_ID_STDOUT = 1
STREAM_ID_STDERR = 2 STREAM_ID_STDERR = 2
OVERHEAD = 0
def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False):
super().__init__()
self.stream_id = stream_id
self.data = data
self.eof = eof
def pack(self) -> bytes:
return umsgpack.packb((self.stream_id, self.eof, bytes(self.data)))
def unpack(self, raw):
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw)
_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8")
StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=0, data=_link_sized_bytes, eof=True).pack()) \
- len(_link_sized_bytes)
_link_sized_bytes = None
class VersionInfoMessage(RNS.MessageBase): class VersionInfoMessage(RNS.MessageBase):

View File

@ -210,7 +210,7 @@ class ListenerSession:
elif not self.channel.is_ready_to_send(): elif not self.channel.is_ready_to_send():
return False return False
elif len(self.stderr_buf) > 0: elif len(self.stderr_buf) > 0:
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD mdu = protocol.StreamDataMessage.MAX_DATA_LEN
data = self.stderr_buf[:mdu] data = self.stderr_buf[:mdu]
self.stderr_buf = self.stderr_buf[mdu:] self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
@ -222,7 +222,7 @@ class ListenerSession:
self.stderr_eof_sent = True self.stderr_eof_sent = True
return True return True
elif len(self.stdout_buf) > 0: elif len(self.stdout_buf) > 0:
mdu = self.channel.MDU - protocol.StreamDataMessage.OVERHEAD mdu = protocol.StreamDataMessage.MAX_DATA_LEN
data = self.stdout_buf[:mdu] data = self.stdout_buf[:mdu]
self.stdout_buf = self.stdout_buf[mdu:] self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent