Protocol version 2: do not base64 encode stream data

This commit is contained in:
Aaron Heise 2023-02-14 04:28:31 -06:00
parent 2c61fdf391
commit a3bd1f20fb

View File

@ -219,7 +219,9 @@ def _protocol_make_version(version: int):
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
_PROTOCOL_VERSION_2 = _protocol_make_version(1)
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
@ -229,6 +231,20 @@ def _protocol_check_magic(value: int):
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
if version >= _PROTOCOL_VERSION_2:
return link_mdu - 64 # TODO: tune
else:
return link_mdu // 2
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
if version >= _PROTOCOL_VERSION_2:
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
else:
return link_mdu // 2
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock()
@ -253,7 +269,6 @@ class Session:
def __init__(self,
tag: any,
cmd: [str],
mdu: int,
data_available_callback: callable,
terminated_callback: callable,
term: str | None,
@ -261,7 +276,6 @@ class Session:
loop: asyncio.AbstractEventLoop = None):
self._log = _get_logger(self.__class__.__name__)
self._mdu = mdu
self._loop = loop if loop is not None else asyncio.get_running_loop()
self._process = process.CallbackSubprocess(argv=cmd,
env={"TERM": term or os.environ.get("TERM", None),
@ -279,14 +293,6 @@ class Session:
self._term_state: [int] = None
Session.put_for_tag(tag, self)
@property
def mdu(self) -> int:
return self._mdu
@mdu.setter
def mdu(self, val: int):
self._mdu = val
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
return self._pending_receipt
@ -358,9 +364,8 @@ class Session:
@staticmethod
def default_request(stdin_fd: int | None) -> [any]:
global _tr
global _PROTOCOL_VERSION_1
request: list[any] = [
_PROTOCOL_VERSION_1, # 0 Protocol Version
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
@ -406,7 +411,8 @@ class Session:
with contextlib.suppress(Exception):
self.process.tcsetattr(termios.TCSANOW, term_state[0])
if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin)
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
@ -415,7 +421,10 @@ class Session:
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
if stdout is not None and len(stdout) > 0:
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
else:
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
return response
RESPONSE_IDX_VERSION = 0
@ -426,7 +435,7 @@ class Session:
RESPONSE_IDX_TMSTAMP = 5
@staticmethod
def default_response(version: int = _PROTOCOL_VERSION_1) -> [any]:
def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]:
response: list[any] = [
version, # 0: Protocol version
False, # 1: Process running
@ -439,9 +448,11 @@ class Session:
return response
@classmethod
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_1) -> [any]:
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]:
response = cls.default_response(version)
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode( f"{msg}\r\n".encode("utf-8"))
msg_bytes = f"{msg}\r\n".encode("utf-8")
response[Session.RESPONSE_IDX_STDOUT] = \
base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes)
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
@ -531,7 +542,6 @@ def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, c
cmd=cmd,
term=term,
remote_identity=remote_identity,
mdu=link.MDU,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
@ -590,11 +600,11 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
if not _protocol_check_magic(remote_version):
raise Exception("Request magic incorrect")
if not remote_version == _PROTOCOL_VERSION_0 and not remote_version == _PROTOCOL_VERSION_1:
if not remote_version <= _PROTOCOL_VERSION_DEFAULT:
return Session.error_response("Listener<->initiator version mismatch")
cmd = _cmd
if remote_version == _PROTOCOL_VERSION_1:
if remote_version >= _PROTOCOL_VERSION_1:
remote_command = data[Session.REQUEST_IDX_CMD]
if remote_command is not None and len(remote_command) > 0:
if _no_remote_command:
@ -620,7 +630,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
loop=_loop)
# leave significant headroom for metadata and encoding
result = session.process_request(data, link.MDU * 1 // 2)
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
return result
# return ProcessState.default_response()
except Exception as e:
@ -731,7 +741,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
request = Session.default_request(sys.stdin.fileno())
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
# log.debug(f"Sending {stdin} to listener")
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
request[Session.REQUEST_IDX_CMD] = cmd
# TODO: Tune
@ -775,15 +785,15 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_0 and version != _PROTOCOL_VERSION_1:
elif version != _PROTOCOL_VERSION_2:
raise RemoteExecutionError("Protocol version mismatch")
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
if stdout is not None:
stdout = base64.b64decode(stdout)
# if stdout is not None:
# stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except RemoteExecutionError:
@ -871,7 +881,10 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if first_loop:
first_loop = False
mdu = _link.MDU // 2
mdu = _protocol_request_chars_take(_link.MDU,
_PROTOCOL_VERSION_DEFAULT,
os.environ.get("TERM", ""),
" ".join(command))
_new_data.set()
if _link: