mirror of
https://github.com/markqvist/rnsh.git
synced 2025-01-23 04:31:01 -05:00
Protocol version 2: do not base64 encode stream data
This commit is contained in:
parent
2c61fdf391
commit
a3bd1f20fb
65
rnsh/rnsh.py
65
rnsh/rnsh.py
@ -219,7 +219,9 @@ def _protocol_make_version(version: int):
|
||||
|
||||
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
||||
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
|
||||
_PROTOCOL_VERSION_2 = _protocol_make_version(1)
|
||||
|
||||
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2
|
||||
|
||||
def _protocol_split_version(version: int):
|
||||
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
||||
@ -229,6 +231,20 @@ def _protocol_check_magic(value: int):
|
||||
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
||||
|
||||
|
||||
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
|
||||
if version >= _PROTOCOL_VERSION_2:
|
||||
return link_mdu - 64 # TODO: tune
|
||||
else:
|
||||
return link_mdu // 2
|
||||
|
||||
|
||||
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
|
||||
if version >= _PROTOCOL_VERSION_2:
|
||||
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
|
||||
else:
|
||||
return link_mdu // 2
|
||||
|
||||
|
||||
class Session:
|
||||
_processes: [(any, Session)] = []
|
||||
_lock = threading.RLock()
|
||||
@ -253,7 +269,6 @@ class Session:
|
||||
def __init__(self,
|
||||
tag: any,
|
||||
cmd: [str],
|
||||
mdu: int,
|
||||
data_available_callback: callable,
|
||||
terminated_callback: callable,
|
||||
term: str | None,
|
||||
@ -261,7 +276,6 @@ class Session:
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
|
||||
self._log = _get_logger(self.__class__.__name__)
|
||||
self._mdu = mdu
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
self._process = process.CallbackSubprocess(argv=cmd,
|
||||
env={"TERM": term or os.environ.get("TERM", None),
|
||||
@ -279,14 +293,6 @@ class Session:
|
||||
self._term_state: [int] = None
|
||||
Session.put_for_tag(tag, self)
|
||||
|
||||
@property
|
||||
def mdu(self) -> int:
|
||||
return self._mdu
|
||||
|
||||
@mdu.setter
|
||||
def mdu(self, val: int):
|
||||
self._mdu = val
|
||||
|
||||
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
||||
return self._pending_receipt
|
||||
|
||||
@ -358,9 +364,8 @@ class Session:
|
||||
@staticmethod
|
||||
def default_request(stdin_fd: int | None) -> [any]:
|
||||
global _tr
|
||||
global _PROTOCOL_VERSION_1
|
||||
request: list[any] = [
|
||||
_PROTOCOL_VERSION_1, # 0 Protocol Version
|
||||
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
@ -406,7 +411,8 @@ class Session:
|
||||
with contextlib.suppress(Exception):
|
||||
self.process.tcsetattr(termios.TCSANOW, term_state[0])
|
||||
if stdin is not None and len(stdin) > 0:
|
||||
stdin = base64.b64decode(stdin)
|
||||
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||
stdin = base64.b64decode(stdin)
|
||||
self.process.write(stdin)
|
||||
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||
|
||||
@ -415,7 +421,10 @@ class Session:
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||
|
||||
if stdout is not None and len(stdout) > 0:
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||
else:
|
||||
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
|
||||
return response
|
||||
|
||||
RESPONSE_IDX_VERSION = 0
|
||||
@ -426,7 +435,7 @@ class Session:
|
||||
RESPONSE_IDX_TMSTAMP = 5
|
||||
|
||||
@staticmethod
|
||||
def default_response(version: int = _PROTOCOL_VERSION_1) -> [any]:
|
||||
def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]:
|
||||
response: list[any] = [
|
||||
version, # 0: Protocol version
|
||||
False, # 1: Process running
|
||||
@ -439,9 +448,11 @@ class Session:
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_1) -> [any]:
|
||||
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]:
|
||||
response = cls.default_response(version)
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode( f"{msg}\r\n".encode("utf-8"))
|
||||
msg_bytes = f"{msg}\r\n".encode("utf-8")
|
||||
response[Session.RESPONSE_IDX_STDOUT] = \
|
||||
base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes)
|
||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||
return response
|
||||
@ -531,7 +542,6 @@ def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, c
|
||||
cmd=cmd,
|
||||
term=term,
|
||||
remote_identity=remote_identity,
|
||||
mdu=link.MDU,
|
||||
loop=loop,
|
||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||
@ -590,11 +600,11 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
if not _protocol_check_magic(remote_version):
|
||||
raise Exception("Request magic incorrect")
|
||||
|
||||
if not remote_version == _PROTOCOL_VERSION_0 and not remote_version == _PROTOCOL_VERSION_1:
|
||||
if not remote_version <= _PROTOCOL_VERSION_DEFAULT:
|
||||
return Session.error_response("Listener<->initiator version mismatch")
|
||||
|
||||
cmd = _cmd
|
||||
if remote_version == _PROTOCOL_VERSION_1:
|
||||
if remote_version >= _PROTOCOL_VERSION_1:
|
||||
remote_command = data[Session.REQUEST_IDX_CMD]
|
||||
if remote_command is not None and len(remote_command) > 0:
|
||||
if _no_remote_command:
|
||||
@ -620,7 +630,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
loop=_loop)
|
||||
|
||||
# leave significant headroom for metadata and encoding
|
||||
result = session.process_request(data, link.MDU * 1 // 2)
|
||||
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
|
||||
return result
|
||||
# return ProcessState.default_response()
|
||||
except Exception as e:
|
||||
@ -731,7 +741,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
request = Session.default_request(sys.stdin.fileno())
|
||||
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
|
||||
# log.debug(f"Sending {stdin} to listener")
|
||||
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
|
||||
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
|
||||
request[Session.REQUEST_IDX_CMD] = cmd
|
||||
|
||||
# TODO: Tune
|
||||
@ -775,15 +785,15 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
||||
if not _protocol_check_magic(version):
|
||||
raise RemoteExecutionError("Protocol error")
|
||||
elif version != _PROTOCOL_VERSION_0 and version != _PROTOCOL_VERSION_1:
|
||||
elif version != _PROTOCOL_VERSION_2:
|
||||
raise RemoteExecutionError("Protocol version mismatch")
|
||||
|
||||
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
|
||||
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
||||
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
||||
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
||||
if stdout is not None:
|
||||
stdout = base64.b64decode(stdout)
|
||||
# if stdout is not None:
|
||||
# stdout = base64.b64decode(stdout)
|
||||
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
||||
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
||||
except RemoteExecutionError:
|
||||
@ -871,7 +881,10 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
|
||||
if first_loop:
|
||||
first_loop = False
|
||||
mdu = _link.MDU // 2
|
||||
mdu = _protocol_request_chars_take(_link.MDU,
|
||||
_PROTOCOL_VERSION_DEFAULT,
|
||||
os.environ.get("TERM", ""),
|
||||
" ".join(command))
|
||||
_new_data.set()
|
||||
|
||||
if _link:
|
||||
|
Loading…
Reference in New Issue
Block a user