Support for pipes.

Connects pipes on the listener child process the same way they are connected on the initiator--any combination of stdin, stdout, and stderr could be connected to a pipe. Any one of those not connected to a pipe will be connected to a pty.
This commit is contained in:
Aaron Heise 2023-02-16 00:16:04 -06:00
parent 27664df0b3
commit ebdcb50265
4 changed files with 262 additions and 97 deletions

View file

@ -227,9 +227,10 @@ def _protocol_make_version(version: int):
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
_PROTOCOL_VERSION_2 = _protocol_make_version(1)
_PROTOCOL_VERSION_2 = _protocol_make_version(2)
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
@ -253,6 +254,16 @@ def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: st
return link_mdu // 2
def _bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def _check_and(value: int, andval: int) -> bool:
return (value & andval) > 0
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock()
@ -279,6 +290,7 @@ class Session:
cmd: [str],
data_available_callback: callable,
terminated_callback: callable,
session_flags: int,
term: str | None,
remote_identity: str | None,
loop: asyncio.AbstractEventLoop = None):
@ -290,15 +302,24 @@ class Session:
"RNS_REMOTE_IDENTITY": remote_identity or ""},
loop=loop,
stdout_callback=self._stdout_data,
terminated_callback=terminated_callback)
stderr_callback=self._stderr_data,
terminated_callback=terminated_callback,
stdin_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDIN),
stdout_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDOUT),
stderr_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDERR))
self._log.debug(f"Starting {cmd}")
self._data_buffer = bytearray()
self._stdout_buffer = bytearray()
self._stderr_buffer = bytearray()
self._lock = threading.RLock()
self._data_available_cb = data_available_callback
self._terminated_cb = terminated_callback
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] = None
self._session_flags: int = session_flags
Session.put_for_tag(tag, self)
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
@ -326,22 +347,39 @@ class Session:
def lock(self) -> threading.RLock:
return self._lock
def read(self, count: int) -> bytes:
def read_stdout(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._data_buffer)
take = self._data_buffer[:count]
self._data_buffer = self._data_buffer[count:].copy()
self._log.debug(f"read {len(take)} bytes of {initial_len}, {len(self._data_buffer)} remaining")
initial_len = len(self._stdout_buffer)
take = self._stdout_buffer[:count]
self._stdout_buffer = self._stdout_buffer[count:]
self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
return take
def _stdout_data(self, data: bytes):
with self.lock:
self._data_buffer.extend(data)
total_available = len(self._data_buffer)
self._stdout_buffer.extend(data)
total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"Error calling ProcessState data_available_callback {e}")
self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
def read_stderr(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._stderr_buffer)
take = self._stderr_buffer[:count]
self._stderr_buffer = self._stderr_buffer[count:]
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
return take
def _stderr_data(self, data: bytes):
with self.lock:
self._stderr_buffer.extend(data)
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
TERMSTATE_IDX_TERM = 0
TERMSTATE_IDX_TIOS = 1
@ -368,30 +406,44 @@ class Session:
REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7
REQUEST_IDX_CMD = 8
REQUEST_IDX_FLAGS = 9
REQUEST_IDX_BYTES_AVAILABLE = 10
REQUEST_FLAGS_PIPE_STDIN = 0x01
REQUEST_FLAGS_PIPE_STDOUT = 0x02
REQUEST_FLAGS_PIPE_STDERR = 0x04
REQUEST_FLAGS_EOF_STDIN = 0x08
@staticmethod
def default_request(stdin_fd: int | None) -> [any]:
def default_request() -> [any]:
global _tr
request: list[any] = [
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
None, # 8 Command to run
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
None, # 8 Command to run
0, # 9 Flags
0, # 10 Bytes Available
].copy()
if stdin_fd is not None:
if os.isatty(0):
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
with contextlib.suppress(OSError):
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
Session.REQUEST_FLAGS_PIPE_STDIN)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
Session.REQUEST_FLAGS_PIPE_STDOUT)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
Session.REQUEST_FLAGS_PIPE_STDERR)
return request
def process_request(self, data: [any], read_size: int) -> [any]:
@ -403,12 +455,16 @@ class Session:
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
flags = data[Session.REQUEST_IDX_FLAGS]
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
response = Session.default_response()
first_term_state = self._term_state is None
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[Session.RESPONSE_IDX_RUNNING] = self.process.running
response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
if self.process.running:
if term_state != self._term_state:
self._term_state = term_state
@ -422,45 +478,54 @@ class Session:
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
if stdin_eof and bytes_available == 0:
module_logger.debug("Closing stdin")
with contextlib.suppress(Exception):
self.process.close_stdin()
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock:
stdout = self.read(read_size)
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
#prioritizing stderr
stderr = self.read_stderr(read_size)
stdout = self.read_stdout(read_size - len(stderr))
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
if stderr is not None and len(stderr) > 0:
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
if stdout is not None and len(stdout) > 0:
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
else:
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
return response
RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_RUNNING = 1
RESPONSE_IDX_FLAGS = 1
RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_STDOUT = 4
RESPONSE_IDX_TMSTAMP = 5
RESPONSE_IDX_STDERR = 4
RESPONSE_IDX_STDOUT = 5
RESPONSE_IDX_TMSTAMP = 6
RESPONSE_FLAGS_RUNNING = 0x01
RESPONSE_FLAGS_EOF_STDOUT = 0x02
RESPONSE_FLAGS_EOF_STDERR = 0x04
@staticmethod
def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]:
def default_response() -> [any]:
response: list[any] = [
version, # 0: Protocol version
_PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr
None, # 5: Timestamp
None, # 4: Stderr
None, # 5: Stdout
None, # 6: Timestamp
].copy()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response
@classmethod
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]:
response = cls.default_response(version)
def error_response(cls, msg: str) -> [any]:
response = cls.default_response()
msg_bytes = f"{msg}\r\n".encode("utf-8")
response[Session.RESPONSE_IDX_STDOUT] = \
base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes)
response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
@ -543,11 +608,12 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
loop: asyncio.AbstractEventLoop) -> Session | None:
loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
log = _get_logger("_listen_start_proc")
try:
return Session(tag=link.link_id,
cmd=cmd,
session_flags=session_flags,
term=term,
remote_identity=remote_identity,
loop=loop,
@ -602,25 +668,27 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None:
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
log.error(f"Invalid request {request_id}, no link found with id {link_id}")
return None
remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version):
raise Exception("Request magic incorrect")
log.error("Request magic incorrect")
link.teardown()
return None
if not remote_version <= _PROTOCOL_VERSION_DEFAULT:
if not remote_version <= _PROTOCOL_VERSION_3:
return Session.error_response("Listener<->initiator version mismatch")
cmd = _cmd.copy()
if remote_version >= _PROTOCOL_VERSION_1:
remote_command = data[Session.REQUEST_IDX_CMD]
if remote_command is not None and len(remote_command) > 0:
if _no_remote_command:
return Session.error_response("Listener does not permit initiator to provide command.")
elif _remote_cmd_as_args:
cmd.extend(remote_command)
else:
cmd = remote_command
remote_command = data[Session.REQUEST_IDX_CMD]
if remote_command is not None and len(remote_command) > 0:
if _no_remote_command:
return Session.error_response("Listener does not permit initiator to provide command.")
elif _remote_cmd_as_args:
cmd.extend(remote_command)
else:
cmd = remote_command
if not _no_remote_command and (cmd is None or len(cmd) == 0):
return Session.error_response("No command supplied and no default command available.")
@ -636,9 +704,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
log.debug(f"Process not found for link {link}")
session = _listen_start_proc(link=link,
term=term,
session_flags=data[Session.REQUEST_IDX_FLAGS],
cmd=cmd,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
if session is None:
return Session.error_response("Unable to start subprocess")
# leave significant headroom for metadata and encoding
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
@ -695,7 +766,7 @@ def _response_handler(request_receipt: RNS.RequestReceipt):
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
cmd: [str] | None = None):
cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
log = _get_logger("_execute")
@ -749,11 +820,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler)
request = Session.default_request(sys.stdin.fileno())
request = Session.default_request()
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
# log.debug(f"Sending {stdin} to listener")
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
request[Session.REQUEST_IDX_CMD] = cmd
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
Session.REQUEST_FLAGS_EOF_STDIN)
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
@ -796,13 +870,17 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_2:
elif version != _PROTOCOL_VERSION_3:
raise RemoteExecutionError("Protocol version mismatch")
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
# if stdout is not None:
# stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
@ -815,10 +893,24 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if stdout is not None:
_tr.raw()
log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout)
os.write(1, stdout)
sys.stdout.flush()
got_bytes = len(stdout) if stdout is not None else 0
if stderr is not None:
_tr.raw()
log.debug(f"stderr: {stderr}")
os.write(2, stderr)
sys.stderr.flush()
if stderr_eof and ready_bytes == 0:
log.debug("Closing stderr")
os.close(2)
if stdout_eof and ready_bytes == 0:
log.debug("Closing stdout")
os.close(1)
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0:
@ -847,7 +939,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if _new_data is not None:
_new_data.set()
stdin_eof = False
def stdin():
nonlocal stdin_eof
try:
data = process.tty_read(sys.stdin.fileno())
log.debug(f"stdin {data}")
@ -855,7 +949,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
data_buffer.extend(data)
_new_data.set()
except EOFError:
data_buffer.extend(process.CTRL_D)
if os.isatty(0):
data_buffer.extend(process.CTRL_D)
stdin_eof = True
process.tty_unset_reader_callbacks(sys.stdin.fileno())
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
@ -886,6 +982,8 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
stdin=stdin,
timeout=timeout,
cmd=command,
stdin_eof=stdin_eof,
bytes_available=len(data_buffer)
)
if first_loop:
@ -975,11 +1073,6 @@ async def _rnsh_cli_main():
def rnsh_cli():
global _tr, _retry_timer, _pre_input
with contextlib.suppress(Exception):
if not os.isatty(sys.stdin.fileno()):
time.sleep(0.1) # attempting to deal with an issue with missing input
tty.setraw(sys.stdin.fileno(), termios.TCSADRAIN)
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())