Implemented adaptive compression on stdin, stdout and stderr streams

This commit is contained in:
Mark Qvist 2023-09-16 18:43:22 +02:00
parent e393857af8
commit a2a45a82af
2 changed files with 72 additions and 11 deletions

View file

@ -49,6 +49,7 @@ import re
import contextlib import contextlib
import rnsh.args import rnsh.args
import pwd import pwd
import bz2
import rnsh.protocol as protocol import rnsh.protocol as protocol
import rnsh.helpers as helpers import rnsh.helpers as helpers
import rnsh.rnsh import rnsh.rnsh
@ -362,11 +363,42 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
processed = False processed = False
if channel.is_ready_to_send(): if channel.is_ready_to_send():
stdin = data_buffer[:mdu] def compress_adaptive(buf: bytes):
data_buffer = data_buffer[mdu:] comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
comp_try = 1
comp_success = False
chunk_len = len(buf)
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
chunk_segment = None
chunk_segment = None
while chunk_len > 32 and comp_try < comp_tries:
chunk_segment_length = int(chunk_len/comp_try)
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
compressed_length = len(compressed_chunk)
if compressed_length < protocol.StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length:
comp_success = True
break
else:
comp_try += 1
if comp_success:
chunk = compressed_chunk
processed_length = chunk_segment_length
else:
chunk = bytes(buf[:protocol.StreamDataMessage.MAX_DATA_LEN])
processed_length = len(chunk)
return comp_success, processed_length, chunk
comp_success, processed_length, chunk = compress_adaptive(data_buffer)
stdin = chunk
data_buffer = data_buffer[processed_length:]
eof = not sent_eof and stdin_eof and len(stdin) == 0 eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof: if len(stdin) > 0 or eof:
channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof)) channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success))
sent_eof = eof sent_eof = eof
processed = True processed = True

View file

@ -12,6 +12,7 @@ from typing import TypeVar, Generic, Callable, List
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from multiprocessing import Manager from multiprocessing import Manager
import os import os
import bz2
import RNS import RNS
import logging as __logging import logging as __logging
@ -204,31 +205,59 @@ class ListenerSession:
await asyncio.sleep(0) await asyncio.sleep(0)
def pump(self) -> bool: def pump(self) -> bool:
def compress_adaptive(buf: bytes):
comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
comp_try = 1
comp_success = False
chunk_len = len(buf)
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
chunk_segment = None
chunk_segment = None
while chunk_len > 32 and comp_try < comp_tries:
chunk_segment_length = int(chunk_len/comp_try)
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
compressed_length = len(compressed_chunk)
if compressed_length < protocol.StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length:
comp_success = True
break
else:
comp_try += 1
if comp_success:
chunk = compressed_chunk
processed_length = chunk_segment_length
else:
chunk = bytes(buf[:protocol.StreamDataMessage.MAX_DATA_LEN])
processed_length = len(chunk)
return comp_success, processed_length, chunk
try: try:
if self.state != LSState.LSSTATE_RUNNING: if self.state != LSState.LSSTATE_RUNNING:
return False return False
elif not self.channel.is_ready_to_send(): elif not self.channel.is_ready_to_send():
return False return False
elif len(self.stderr_buf) > 0: elif len(self.stderr_buf) > 0:
mdu = protocol.StreamDataMessage.MAX_DATA_LEN comp_success, processed_length, data = compress_adaptive(self.stderr_buf)
data = self.stderr_buf[:mdu] self.stderr_buf = self.stderr_buf[processed_length:]
self.stderr_buf = self.stderr_buf[mdu:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
self.stderr_eof_sent = self.stderr_eof_sent or send_eof self.stderr_eof_sent = self.stderr_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR, msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
data, send_eof) data, send_eof, comp_success)
self.send(msg) self.send(msg)
if send_eof: if send_eof:
self.stderr_eof_sent = True self.stderr_eof_sent = True
return True return True
elif len(self.stdout_buf) > 0: elif len(self.stdout_buf) > 0:
mdu = protocol.StreamDataMessage.MAX_DATA_LEN comp_success, processed_length, data = compress_adaptive(self.stdout_buf)
data = self.stdout_buf[:mdu] self.stdout_buf = self.stdout_buf[processed_length:]
self.stdout_buf = self.stdout_buf[mdu:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
self.stdout_eof_sent = self.stdout_eof_sent or send_eof self.stdout_eof_sent = self.stdout_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT, msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
data, send_eof) data, send_eof, comp_success)
self.send(msg) self.send(msg)
if send_eof: if send_eof:
self.stdout_eof_sent = True self.stdout_eof_sent = True