From a3bd1f20fbc9761f71bc90e535f3b70fe70b2a6f Mon Sep 17 00:00:00 2001 From: Aaron Heise Date: Tue, 14 Feb 2023 04:28:31 -0600 Subject: [PATCH] Protocol version 2: do not base64 encode stream data --- rnsh/rnsh.py | 65 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index 1e809ad..c175683 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -219,7 +219,9 @@ def _protocol_make_version(version: int): _PROTOCOL_VERSION_0 = _protocol_make_version(0) _PROTOCOL_VERSION_1 = _protocol_make_version(1) +_PROTOCOL_VERSION_2 = _protocol_make_version(1) +_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2 def _protocol_split_version(version: int): return (version >> 32) & 0xffffffff, version & 0xffffffff @@ -229,6 +231,20 @@ def _protocol_check_magic(value: int): return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC +def _protocol_response_chars_take(link_mdu: int, version: int) -> int: + if version >= _PROTOCOL_VERSION_2: + return link_mdu - 64 # TODO: tune + else: + return link_mdu // 2 + + +def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int: + if version >= _PROTOCOL_VERSION_2: + return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune + else: + return link_mdu // 2 + + class Session: _processes: [(any, Session)] = [] _lock = threading.RLock() @@ -253,7 +269,6 @@ class Session: def __init__(self, tag: any, cmd: [str], - mdu: int, data_available_callback: callable, terminated_callback: callable, term: str | None, @@ -261,7 +276,6 @@ class Session: loop: asyncio.AbstractEventLoop = None): self._log = _get_logger(self.__class__.__name__) - self._mdu = mdu self._loop = loop if loop is not None else asyncio.get_running_loop() self._process = process.CallbackSubprocess(argv=cmd, env={"TERM": term or os.environ.get("TERM", None), @@ -279,14 +293,6 @@ class Session: self._term_state: [int] = None Session.put_for_tag(tag, self) - @property - def mdu(self) -> int: - return self._mdu - - @mdu.setter - def mdu(self, val: int): - self._mdu = val - def pending_receipt_peek(self) -> RNS.PacketReceipt | None: return self._pending_receipt @@ -358,9 +364,8 @@ class Session: @staticmethod def default_request(stdin_fd: int | None) -> [any]: global _tr - global _PROTOCOL_VERSION_1 request: list[any] = [ - _PROTOCOL_VERSION_1, # 0 Protocol Version + _PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version None, # 1 Stdin None, # 2 TERM variable None, # 3 termios attributes or something @@ -406,7 +411,8 @@ class Session: with contextlib.suppress(Exception): self.process.tcsetattr(termios.TCSANOW, term_state[0]) if stdin is not None and len(stdin) > 0: - stdin = base64.b64decode(stdin) + if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2: + stdin = base64.b64decode(stdin) self.process.write(stdin) response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code @@ -415,7 +421,10 @@ class Session: response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer) if stdout is not None and len(stdout) > 0: - response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8") + if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2: + response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8") + else: + response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout) return response RESPONSE_IDX_VERSION = 0 @@ -426,7 +435,7 @@ class Session: RESPONSE_IDX_TMSTAMP = 5 @staticmethod - def default_response(version: int = _PROTOCOL_VERSION_1) -> [any]: + def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]: response: list[any] = [ version, # 0: Protocol version False, # 1: Process running @@ -439,9 +448,11 @@ class Session: return response @classmethod - def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_1) -> [any]: + def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]: response = cls.default_response(version) - response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode( f"{msg}\r\n".encode("utf-8")) + msg_bytes = f"{msg}\r\n".encode("utf-8") + response[Session.RESPONSE_IDX_STDOUT] = \ + base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes) response[Session.RESPONSE_IDX_RETCODE] = 255 response[Session.RESPONSE_IDX_RDYBYTE] = 0 return response @@ -531,7 +542,6 @@ def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, c cmd=cmd, term=term, remote_identity=remote_identity, - mdu=link.MDU, loop=loop, data_available_callback=functools.partial(_subproc_data_ready, link), terminated_callback=functools.partial(_subproc_terminated, link)) @@ -590,11 +600,11 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_ if not _protocol_check_magic(remote_version): raise Exception("Request magic incorrect") - if not remote_version == _PROTOCOL_VERSION_0 and not remote_version == _PROTOCOL_VERSION_1: + if not remote_version <= _PROTOCOL_VERSION_DEFAULT: return Session.error_response("Listener<->initiator version mismatch") cmd = _cmd - if remote_version == _PROTOCOL_VERSION_1: + if remote_version >= _PROTOCOL_VERSION_1: remote_command = data[Session.REQUEST_IDX_CMD] if remote_command is not None and len(remote_command) > 0: if _no_remote_command: @@ -620,7 +630,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_ loop=_loop) # leave significant headroom for metadata and encoding - result = session.process_request(data, link.MDU * 1 // 2) + result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version)) return result # return ProcessState.default_response() except Exception as e: @@ -731,7 +741,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= request = Session.default_request(sys.stdin.fileno()) log.debug(f"Sending {len(stdin) or 0} bytes to listener") # log.debug(f"Sending {stdin} to listener") - request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None) + request[Session.REQUEST_IDX_STDIN] = bytes(stdin) request[Session.REQUEST_IDX_CMD] = cmd # TODO: Tune @@ -775,15 +785,15 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid= version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0 if not _protocol_check_magic(version): raise RemoteExecutionError("Protocol error") - elif version != _PROTOCOL_VERSION_0 and version != _PROTOCOL_VERSION_1: + elif version != _PROTOCOL_VERSION_2: raise RemoteExecutionError("Protocol version mismatch") running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE] ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0 stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT] - if stdout is not None: - stdout = base64.b64decode(stdout) + # if stdout is not None: + # stdout = base64.b64decode(stdout) timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP] # log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else "")) except RemoteExecutionError: @@ -871,7 +881,10 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness if first_loop: first_loop = False - mdu = _link.MDU // 2 + mdu = _protocol_request_chars_take(_link.MDU, + _PROTOCOL_VERSION_DEFAULT, + os.environ.get("TERM", ""), + " ".join(command)) _new_data.set() if _link: