mirror of
https://github.com/markqvist/rnsh.git
synced 2025-02-02 17:04:40 -05:00
Protocol version 2: do not base64 encode stream data
This commit is contained in:
parent
2c61fdf391
commit
a3bd1f20fb
65
rnsh/rnsh.py
65
rnsh/rnsh.py
@ -219,7 +219,9 @@ def _protocol_make_version(version: int):
|
|||||||
|
|
||||||
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
||||||
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
|
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
|
||||||
|
_PROTOCOL_VERSION_2 = _protocol_make_version(1)
|
||||||
|
|
||||||
|
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2
|
||||||
|
|
||||||
def _protocol_split_version(version: int):
|
def _protocol_split_version(version: int):
|
||||||
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
||||||
@ -229,6 +231,20 @@ def _protocol_check_magic(value: int):
|
|||||||
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
||||||
|
|
||||||
|
|
||||||
|
def _protocol_response_chars_take(link_mdu: int, version: int) -> int:
|
||||||
|
if version >= _PROTOCOL_VERSION_2:
|
||||||
|
return link_mdu - 64 # TODO: tune
|
||||||
|
else:
|
||||||
|
return link_mdu // 2
|
||||||
|
|
||||||
|
|
||||||
|
def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: str) -> int:
|
||||||
|
if version >= _PROTOCOL_VERSION_2:
|
||||||
|
return link_mdu - 15 * 8 - len(term) - len(cmd) - 20 # TODO: tune
|
||||||
|
else:
|
||||||
|
return link_mdu // 2
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
_processes: [(any, Session)] = []
|
_processes: [(any, Session)] = []
|
||||||
_lock = threading.RLock()
|
_lock = threading.RLock()
|
||||||
@ -253,7 +269,6 @@ class Session:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
tag: any,
|
tag: any,
|
||||||
cmd: [str],
|
cmd: [str],
|
||||||
mdu: int,
|
|
||||||
data_available_callback: callable,
|
data_available_callback: callable,
|
||||||
terminated_callback: callable,
|
terminated_callback: callable,
|
||||||
term: str | None,
|
term: str | None,
|
||||||
@ -261,7 +276,6 @@ class Session:
|
|||||||
loop: asyncio.AbstractEventLoop = None):
|
loop: asyncio.AbstractEventLoop = None):
|
||||||
|
|
||||||
self._log = _get_logger(self.__class__.__name__)
|
self._log = _get_logger(self.__class__.__name__)
|
||||||
self._mdu = mdu
|
|
||||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||||
self._process = process.CallbackSubprocess(argv=cmd,
|
self._process = process.CallbackSubprocess(argv=cmd,
|
||||||
env={"TERM": term or os.environ.get("TERM", None),
|
env={"TERM": term or os.environ.get("TERM", None),
|
||||||
@ -279,14 +293,6 @@ class Session:
|
|||||||
self._term_state: [int] = None
|
self._term_state: [int] = None
|
||||||
Session.put_for_tag(tag, self)
|
Session.put_for_tag(tag, self)
|
||||||
|
|
||||||
@property
|
|
||||||
def mdu(self) -> int:
|
|
||||||
return self._mdu
|
|
||||||
|
|
||||||
@mdu.setter
|
|
||||||
def mdu(self, val: int):
|
|
||||||
self._mdu = val
|
|
||||||
|
|
||||||
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
||||||
return self._pending_receipt
|
return self._pending_receipt
|
||||||
|
|
||||||
@ -358,9 +364,8 @@ class Session:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def default_request(stdin_fd: int | None) -> [any]:
|
def default_request(stdin_fd: int | None) -> [any]:
|
||||||
global _tr
|
global _tr
|
||||||
global _PROTOCOL_VERSION_1
|
|
||||||
request: list[any] = [
|
request: list[any] = [
|
||||||
_PROTOCOL_VERSION_1, # 0 Protocol Version
|
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
|
||||||
None, # 1 Stdin
|
None, # 1 Stdin
|
||||||
None, # 2 TERM variable
|
None, # 2 TERM variable
|
||||||
None, # 3 termios attributes or something
|
None, # 3 termios attributes or something
|
||||||
@ -406,7 +411,8 @@ class Session:
|
|||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self.process.tcsetattr(termios.TCSANOW, term_state[0])
|
self.process.tcsetattr(termios.TCSANOW, term_state[0])
|
||||||
if stdin is not None and len(stdin) > 0:
|
if stdin is not None and len(stdin) > 0:
|
||||||
stdin = base64.b64decode(stdin)
|
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||||
|
stdin = base64.b64decode(stdin)
|
||||||
self.process.write(stdin)
|
self.process.write(stdin)
|
||||||
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||||
|
|
||||||
@ -415,7 +421,10 @@ class Session:
|
|||||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||||
|
|
||||||
if stdout is not None and len(stdout) > 0:
|
if stdout is not None and len(stdout) > 0:
|
||||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||||
|
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||||
|
else:
|
||||||
|
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
RESPONSE_IDX_VERSION = 0
|
RESPONSE_IDX_VERSION = 0
|
||||||
@ -426,7 +435,7 @@ class Session:
|
|||||||
RESPONSE_IDX_TMSTAMP = 5
|
RESPONSE_IDX_TMSTAMP = 5
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_response(version: int = _PROTOCOL_VERSION_1) -> [any]:
|
def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]:
|
||||||
response: list[any] = [
|
response: list[any] = [
|
||||||
version, # 0: Protocol version
|
version, # 0: Protocol version
|
||||||
False, # 1: Process running
|
False, # 1: Process running
|
||||||
@ -439,9 +448,11 @@ class Session:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_1) -> [any]:
|
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]:
|
||||||
response = cls.default_response(version)
|
response = cls.default_response(version)
|
||||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode( f"{msg}\r\n".encode("utf-8"))
|
msg_bytes = f"{msg}\r\n".encode("utf-8")
|
||||||
|
response[Session.RESPONSE_IDX_STDOUT] = \
|
||||||
|
base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes)
|
||||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||||
return response
|
return response
|
||||||
@ -531,7 +542,6 @@ def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, c
|
|||||||
cmd=cmd,
|
cmd=cmd,
|
||||||
term=term,
|
term=term,
|
||||||
remote_identity=remote_identity,
|
remote_identity=remote_identity,
|
||||||
mdu=link.MDU,
|
|
||||||
loop=loop,
|
loop=loop,
|
||||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||||
@ -590,11 +600,11 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||||||
if not _protocol_check_magic(remote_version):
|
if not _protocol_check_magic(remote_version):
|
||||||
raise Exception("Request magic incorrect")
|
raise Exception("Request magic incorrect")
|
||||||
|
|
||||||
if not remote_version == _PROTOCOL_VERSION_0 and not remote_version == _PROTOCOL_VERSION_1:
|
if not remote_version <= _PROTOCOL_VERSION_DEFAULT:
|
||||||
return Session.error_response("Listener<->initiator version mismatch")
|
return Session.error_response("Listener<->initiator version mismatch")
|
||||||
|
|
||||||
cmd = _cmd
|
cmd = _cmd
|
||||||
if remote_version == _PROTOCOL_VERSION_1:
|
if remote_version >= _PROTOCOL_VERSION_1:
|
||||||
remote_command = data[Session.REQUEST_IDX_CMD]
|
remote_command = data[Session.REQUEST_IDX_CMD]
|
||||||
if remote_command is not None and len(remote_command) > 0:
|
if remote_command is not None and len(remote_command) > 0:
|
||||||
if _no_remote_command:
|
if _no_remote_command:
|
||||||
@ -620,7 +630,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||||||
loop=_loop)
|
loop=_loop)
|
||||||
|
|
||||||
# leave significant headroom for metadata and encoding
|
# leave significant headroom for metadata and encoding
|
||||||
result = session.process_request(data, link.MDU * 1 // 2)
|
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
|
||||||
return result
|
return result
|
||||||
# return ProcessState.default_response()
|
# return ProcessState.default_response()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -731,7 +741,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
request = Session.default_request(sys.stdin.fileno())
|
request = Session.default_request(sys.stdin.fileno())
|
||||||
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
|
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
|
||||||
# log.debug(f"Sending {stdin} to listener")
|
# log.debug(f"Sending {stdin} to listener")
|
||||||
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
|
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
|
||||||
request[Session.REQUEST_IDX_CMD] = cmd
|
request[Session.REQUEST_IDX_CMD] = cmd
|
||||||
|
|
||||||
# TODO: Tune
|
# TODO: Tune
|
||||||
@ -775,15 +785,15 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
||||||
if not _protocol_check_magic(version):
|
if not _protocol_check_magic(version):
|
||||||
raise RemoteExecutionError("Protocol error")
|
raise RemoteExecutionError("Protocol error")
|
||||||
elif version != _PROTOCOL_VERSION_0 and version != _PROTOCOL_VERSION_1:
|
elif version != _PROTOCOL_VERSION_2:
|
||||||
raise RemoteExecutionError("Protocol version mismatch")
|
raise RemoteExecutionError("Protocol version mismatch")
|
||||||
|
|
||||||
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
|
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
|
||||||
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
||||||
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
||||||
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
||||||
if stdout is not None:
|
# if stdout is not None:
|
||||||
stdout = base64.b64decode(stdout)
|
# stdout = base64.b64decode(stdout)
|
||||||
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
||||||
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
||||||
except RemoteExecutionError:
|
except RemoteExecutionError:
|
||||||
@ -871,7 +881,10 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
|
|
||||||
if first_loop:
|
if first_loop:
|
||||||
first_loop = False
|
first_loop = False
|
||||||
mdu = _link.MDU // 2
|
mdu = _protocol_request_chars_take(_link.MDU,
|
||||||
|
_PROTOCOL_VERSION_DEFAULT,
|
||||||
|
os.environ.get("TERM", ""),
|
||||||
|
" ".join(command))
|
||||||
_new_data.set()
|
_new_data.set()
|
||||||
|
|
||||||
if _link:
|
if _link:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user