From a2a45a82af6346301373f80c4bd422c467425f26 Mon Sep 17 00:00:00 2001 From: Mark Qvist Date: Sat, 16 Sep 2023 18:43:22 +0200 Subject: [PATCH] Implemented adaptive compression on stdin, stdout and stderr streams --- rnsh/initiator.py | 38 +++++++++++++++++++++++++++++++++++--- rnsh/session.py | 45 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/rnsh/initiator.py b/rnsh/initiator.py index 9f62ab8..74d30e3 100644 --- a/rnsh/initiator.py +++ b/rnsh/initiator.py @@ -49,6 +49,7 @@ import re import contextlib import rnsh.args import pwd +import bz2 import rnsh.protocol as protocol import rnsh.helpers as helpers import rnsh.rnsh @@ -362,11 +363,42 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: processed = False if channel.is_ready_to_send(): - stdin = data_buffer[:mdu] - data_buffer = data_buffer[mdu:] + def compress_adaptive(buf: bytes): + comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + + chunk_len = len(buf) + if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN + chunk_segment = None + + chunk_segment = None + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(buf[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < protocol.StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(buf[:protocol.StreamDataMessage.MAX_DATA_LEN]) + processed_length = len(chunk) + + return comp_success, processed_length, chunk + + comp_success, processed_length, chunk = compress_adaptive(data_buffer) + stdin = chunk + data_buffer = data_buffer[processed_length:] eof = not sent_eof and stdin_eof and len(stdin) == 0 if len(stdin) > 0 or eof: - channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof)) + channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success)) sent_eof = eof processed = True diff --git a/rnsh/session.py b/rnsh/session.py index 1c948e7..6a0dfee 100644 --- a/rnsh/session.py +++ b/rnsh/session.py @@ -12,6 +12,7 @@ from typing import TypeVar, Generic, Callable, List from abc import abstractmethod, ABC from multiprocessing import Manager import os +import bz2 import RNS import logging as __logging @@ -204,31 +205,59 @@ class ListenerSession: await asyncio.sleep(0) def pump(self) -> bool: + def compress_adaptive(buf: bytes): + comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + + chunk_len = len(buf) + if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN + chunk_segment = None + + chunk_segment = None + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(buf[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < protocol.StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(buf[:protocol.StreamDataMessage.MAX_DATA_LEN]) + processed_length = len(chunk) + + return comp_success, processed_length, chunk + try: if self.state != LSState.LSSTATE_RUNNING: return False elif not self.channel.is_ready_to_send(): return False elif len(self.stderr_buf) > 0: - mdu = protocol.StreamDataMessage.MAX_DATA_LEN - data = self.stderr_buf[:mdu] - self.stderr_buf = self.stderr_buf[mdu:] + comp_success, processed_length, data = compress_adaptive(self.stderr_buf) + self.stderr_buf = self.stderr_buf[processed_length:] send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent self.stderr_eof_sent = self.stderr_eof_sent or send_eof msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR, - data, send_eof) + data, send_eof, comp_success) self.send(msg) if send_eof: self.stderr_eof_sent = True return True elif len(self.stdout_buf) > 0: - mdu = protocol.StreamDataMessage.MAX_DATA_LEN - data = self.stdout_buf[:mdu] - self.stdout_buf = self.stdout_buf[mdu:] + comp_success, processed_length, data = compress_adaptive(self.stdout_buf) + self.stdout_buf = self.stdout_buf[processed_length:] send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent self.stdout_eof_sent = self.stdout_eof_sent or send_eof msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT, - data, send_eof) + data, send_eof, comp_success) self.send(msg) if send_eof: self.stdout_eof_sent = True