Support for pipes. #3

Connects pipes on the listener child process the same way they are connected on the initiator--any combination of stdin, stdout, and stderr could be connected to a pipe. Any one of those not connected to a pipe will be connected to a pty.
This commit is contained in:
Aaron Heise 2023-02-16 00:16:04 -06:00
parent 27664df0b3
commit ebdcb50265
4 changed files with 262 additions and 97 deletions

View File

@ -323,12 +323,75 @@ async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
return evt.is_set() return evt.is_set()
def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, stdout_is_pipe: bool,
stderr_is_pipe: bool) -> tuple[int, int, int, int]:
# Set up PTY and/or pipes
child_fd = parent_fd = None
if not (stdin_is_pipe and stdout_is_pipe and stderr_is_pipe):
parent_fd, child_fd = pty.openpty()
child_stdin, parent_stdin = (os.pipe() if stdin_is_pipe else (child_fd, parent_fd))
parent_stdout, child_stdout = (os.pipe() if stdout_is_pipe else (parent_fd, child_fd))
parent_stderr, child_stderr = (os.pipe() if stderr_is_pipe else (parent_fd, child_fd))
# Fork
pid = os.fork()
if pid == 0:
try:
# We are in the child process, so close all open sockets and pipes except for the PTY and/or pipes
max_fd = os.sysconf("SC_OPEN_MAX")
for fd in range(3, max_fd):
if fd not in (child_stdin, child_stdout, child_stderr):
try:
os.close(fd)
except OSError:
pass
# Set up PTY and/or pipes
os.dup2(child_stdin, 0)
os.dup2(child_stdout, 1)
os.dup2(child_stderr, 2)
# Make PTY controlling if necessary
if not stdin_is_pipe:
os.setsid()
tmp_fd = os.open(os.ttyname(0), os.O_RDWR)
os.close(tmp_fd)
# fcntl.ioctl(0, termios.TIOCSCTTY, 0)
# Execute the command
os.execvpe(cmd_line[0], cmd_line, env)
except Exception as err:
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
print(f"Unable to start {cmd_line[0]}: {err} ({fname}:{exc_tb.tb_lineno})")
sys.stdout.flush()
# don't let any other modules get in our way, do an immediate silent exit.
os._exit(0)
else:
# We are in the parent process, so close the child-side of the PTY and/or pipes
if child_fd is not None:
os.close(child_fd)
if child_stdin != child_fd:
os.close(child_stdin)
if child_stdout != child_fd:
os.close(child_stdout)
if child_stderr != child_fd:
os.close(child_stderr)
# # Close the write end of the pipe if a pipe is used for standard input
# if not stdin_is_pipe:
# os.close(parent_stdin)
# Return the child PID and the file descriptors for the PTY and/or pipes
return pid, parent_stdin, parent_stdout, parent_stderr
class CallbackSubprocess: class CallbackSubprocess:
# time between checks of child process # time between checks of child process
PROCESS_POLL_TIME: float = 0.1 PROCESS_POLL_TIME: float = 0.1
def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable, def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
terminated_callback: callable): stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
stderr_is_pipe: bool):
""" """
Fork a child process and generate callbacks with output from the process. Fork a child process and generate callbacks with output from the process.
:param argv: the command line, tokenized. The first element must be the absolute path to an executable file. :param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
@ -347,11 +410,17 @@ class CallbackSubprocess:
self._env = env or {} self._env = env or {}
self._loop = loop self._loop = loop
self._stdout_cb = stdout_callback self._stdout_cb = stdout_callback
self._stderr_cb = stderr_callback
self._terminated_cb = terminated_callback self._terminated_cb = terminated_callback
self._pid: int = None self._pid: int = None
self._child_fd: int = None self._child_stdin: int = None
self._child_stdout: int = None
self._child_stderr: int = None
self._return_code: int = None self._return_code: int = None
self._eof: bool = False self._eof: bool = False
self._stdin_is_pipe = stdin_is_pipe
self._stdout_is_pipe = stdout_is_pipe
self._stderr_is_pipe = stderr_is_pipe
def terminate(self, kill_delay: float = 1.0): def terminate(self, kill_delay: float = 1.0):
""" """
@ -382,6 +451,10 @@ class CallbackSubprocess:
threading.Thread(target=wait).start() threading.Thread(target=wait).start()
def close_stdin(self):
with contextlib.suppress(Exception):
os.close(self._child_stdin)
@property @property
def started(self) -> bool: def started(self) -> bool:
""" """
@ -402,7 +475,7 @@ class CallbackSubprocess:
:param data: bytes to write :param data: bytes to write
""" """
self._log.debug(f"write({data})") self._log.debug(f"write({data})")
os.write(self._child_fd, data) os.write(self._child_stdin, data)
def set_winsize(self, r: int, c: int, h: int, v: int): def set_winsize(self, r: int, c: int, h: int, v: int):
""" """
@ -414,7 +487,7 @@ class CallbackSubprocess:
:return: :return:
""" """
self._log.debug(f"set_winsize({r},{c},{h},{v}") self._log.debug(f"set_winsize({r},{c},{h},{v}")
tty_set_winsize(self._child_fd, r, c, h, v) tty_set_winsize(self._child_stdout, r, c, h, v)
def copy_winsize(self, fromfd: int): def copy_winsize(self, fromfd: int):
""" """
@ -430,17 +503,17 @@ class CallbackSubprocess:
:param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH :param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
:param attr: attributes to set :param attr: attributes to set
""" """
termios.tcsetattr(self._child_fd, when, attr) termios.tcsetattr(self._child_stdin, when, attr)
def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]] def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]]
""" """
Get tty attributes. Get tty attributes.
:return: tty attributes value :return: tty attributes value
""" """
return termios.tcgetattr(self._child_fd) return termios.tcgetattr(self._child_stdout)
def ttysetraw(self): def ttysetraw(self):
tty.setraw(self._child_fd, termios.TCSADRAIN) tty.setraw(self._child_stdout, termios.TCSADRAIN)
def start(self): def start(self):
""" """
@ -469,27 +542,11 @@ class CallbackSubprocess:
# env["SHELL"] = program # env["SHELL"] = program
# self._log.debug(f"set login shell {self._command}") # self._log.debug(f"set login shell {self._command}")
self._pid, self._child_fd = pty.fork() self._pid, \
self._child_stdin, \
if self._pid == 0: self._child_stdout, \
try: self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
# This may not be strictly necessary, but there is self._stderr_is_pipe)
# occasionally some funny business that goes on with
# networking after the fork. Anecdotally this fixed
# it, but more testing is needed as it might be a
# coincidence.
p = psutil.Process()
for c in p.connections(kind='all'):
with exception.permit(SystemExit):
os.close(c.fd)
# TODO: verify that skipping setpgrp fixes Operation not permitted on Manjaro
# os.setpgrp()
os.execvpe(program, self._command, env)
except Exception as err:
print(f"Child process error: {err}, command: {self._command}")
sys.stdout.flush()
# don't let any other modules get in our way.
os._exit(0)
def poll(): def poll():
# self.log.debug("poll") # self.log.debug("poll")
@ -515,10 +572,14 @@ class CallbackSubprocess:
callback(data) callback(data)
except EOFError: except EOFError:
self._eof = True self._eof = True
tty_unset_reader_callbacks(self._child_fd) tty_unset_reader_callbacks(self._child_stdout)
callback(bytearray()) callback(bytearray())
tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop) tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb),
self._loop)
if self._child_stderr != self._child_stdout:
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb),
self._loop)
@property @property
def eof(self): def eof(self):

View File

@ -227,9 +227,10 @@ def _protocol_make_version(version: int):
_PROTOCOL_VERSION_0 = _protocol_make_version(0) _PROTOCOL_VERSION_0 = _protocol_make_version(0)
_PROTOCOL_VERSION_1 = _protocol_make_version(1) _PROTOCOL_VERSION_1 = _protocol_make_version(1)
_PROTOCOL_VERSION_2 = _protocol_make_version(1) _PROTOCOL_VERSION_2 = _protocol_make_version(2)
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2 _PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
def _protocol_split_version(version: int): def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff return (version >> 32) & 0xffffffff, version & 0xffffffff
@ -253,6 +254,16 @@ def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: st
return link_mdu // 2 return link_mdu // 2
def _bitwise_or_if(value: int, condition: bool, orval: int):
if not condition:
return value
return value | orval
def _check_and(value: int, andval: int) -> bool:
return (value & andval) > 0
class Session: class Session:
_processes: [(any, Session)] = [] _processes: [(any, Session)] = []
_lock = threading.RLock() _lock = threading.RLock()
@ -279,6 +290,7 @@ class Session:
cmd: [str], cmd: [str],
data_available_callback: callable, data_available_callback: callable,
terminated_callback: callable, terminated_callback: callable,
session_flags: int,
term: str | None, term: str | None,
remote_identity: str | None, remote_identity: str | None,
loop: asyncio.AbstractEventLoop = None): loop: asyncio.AbstractEventLoop = None):
@ -290,15 +302,24 @@ class Session:
"RNS_REMOTE_IDENTITY": remote_identity or ""}, "RNS_REMOTE_IDENTITY": remote_identity or ""},
loop=loop, loop=loop,
stdout_callback=self._stdout_data, stdout_callback=self._stdout_data,
terminated_callback=terminated_callback) stderr_callback=self._stderr_data,
terminated_callback=terminated_callback,
stdin_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDIN),
stdout_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDOUT),
stderr_is_pipe=_check_and(session_flags,
Session.REQUEST_FLAGS_PIPE_STDERR))
self._log.debug(f"Starting {cmd}") self._log.debug(f"Starting {cmd}")
self._data_buffer = bytearray() self._stdout_buffer = bytearray()
self._stderr_buffer = bytearray()
self._lock = threading.RLock() self._lock = threading.RLock()
self._data_available_cb = data_available_callback self._data_available_cb = data_available_callback
self._terminated_cb = terminated_callback self._terminated_cb = terminated_callback
self._pending_receipt: RNS.PacketReceipt | None = None self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start() self._process.start()
self._term_state: [int] = None self._term_state: [int] = None
self._session_flags: int = session_flags
Session.put_for_tag(tag, self) Session.put_for_tag(tag, self)
def pending_receipt_peek(self) -> RNS.PacketReceipt | None: def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
@ -326,22 +347,39 @@ class Session:
def lock(self) -> threading.RLock: def lock(self) -> threading.RLock:
return self._lock return self._lock
def read(self, count: int) -> bytes: def read_stdout(self, count: int) -> bytes:
with self.lock: with self.lock:
initial_len = len(self._data_buffer) initial_len = len(self._stdout_buffer)
take = self._data_buffer[:count] take = self._stdout_buffer[:count]
self._data_buffer = self._data_buffer[count:].copy() self._stdout_buffer = self._stdout_buffer[count:]
self._log.debug(f"read {len(take)} bytes of {initial_len}, {len(self._data_buffer)} remaining") self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
return take return take
def _stdout_data(self, data: bytes): def _stdout_data(self, data: bytes):
with self.lock: with self.lock:
self._data_buffer.extend(data) self._stdout_buffer.extend(data)
total_available = len(self._data_buffer) total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
try: try:
self._data_available_cb(total_available) self._data_available_cb(total_available)
except Exception as e: except Exception as e:
self._log.error(f"Error calling ProcessState data_available_callback {e}") self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
def read_stderr(self, count: int) -> bytes:
with self.lock:
initial_len = len(self._stderr_buffer)
take = self._stderr_buffer[:count]
self._stderr_buffer = self._stderr_buffer[count:]
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
return take
def _stderr_data(self, data: bytes):
with self.lock:
self._stderr_buffer.extend(data)
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
TERMSTATE_IDX_TERM = 0 TERMSTATE_IDX_TERM = 0
TERMSTATE_IDX_TIOS = 1 TERMSTATE_IDX_TIOS = 1
@ -368,30 +406,44 @@ class Session:
REQUEST_IDX_HPIX = 6 REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7 REQUEST_IDX_VPIX = 7
REQUEST_IDX_CMD = 8 REQUEST_IDX_CMD = 8
REQUEST_IDX_FLAGS = 9
REQUEST_IDX_BYTES_AVAILABLE = 10
REQUEST_FLAGS_PIPE_STDIN = 0x01
REQUEST_FLAGS_PIPE_STDOUT = 0x02
REQUEST_FLAGS_PIPE_STDERR = 0x04
REQUEST_FLAGS_EOF_STDIN = 0x08
@staticmethod @staticmethod
def default_request(stdin_fd: int | None) -> [any]: def default_request() -> [any]:
global _tr global _tr
request: list[any] = [ request: list[any] = [
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version _PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
None, # 1 Stdin None, # 1 Stdin
None, # 2 TERM variable None, # 2 TERM variable
None, # 3 termios attributes or something None, # 3 termios attributes or something
None, # 4 terminal rows None, # 4 terminal rows
None, # 5 terminal cols None, # 5 terminal cols
None, # 6 terminal horizontal pixels None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels None, # 7 terminal vertical pixels
None, # 8 Command to run None, # 8 Command to run
0, # 9 Flags
0, # 10 Bytes Available
].copy() ].copy()
if stdin_fd is not None: if os.isatty(0):
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None) request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
request[Session.REQUEST_IDX_ROWS], \ request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \ request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \ request[Session.REQUEST_IDX_HPIX], \
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd) request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
Session.REQUEST_FLAGS_PIPE_STDIN)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
Session.REQUEST_FLAGS_PIPE_STDOUT)
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
Session.REQUEST_FLAGS_PIPE_STDERR)
return request return request
def process_request(self, data: [any], read_size: int) -> [any]: def process_request(self, data: [any], read_size: int) -> [any]:
@ -403,12 +455,16 @@ class Session:
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels # hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels # vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1] # term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
flags = data[Session.REQUEST_IDX_FLAGS]
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
response = Session.default_response() response = Session.default_response()
first_term_state = self._term_state is None first_term_state = self._term_state is None
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1] term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[Session.RESPONSE_IDX_RUNNING] = self.process.running response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
if self.process.running: if self.process.running:
if term_state != self._term_state: if term_state != self._term_state:
self._term_state = term_state self._term_state = term_state
@ -422,45 +478,54 @@ class Session:
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2: if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
stdin = base64.b64decode(stdin) stdin = base64.b64decode(stdin)
self.process.write(stdin) self.process.write(stdin)
if stdin_eof and bytes_available == 0:
module_logger.debug("Closing stdin")
with contextlib.suppress(Exception):
self.process.close_stdin()
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock: with self.lock:
stdout = self.read(read_size) #prioritizing stderr
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer) stderr = self.read_stderr(read_size)
stdout = self.read_stdout(read_size - len(stderr))
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
if stderr is not None and len(stderr) > 0:
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
if stdout is not None and len(stdout) > 0: if stdout is not None and len(stdout) > 0:
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2: response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
else:
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
return response return response
RESPONSE_IDX_VERSION = 0 RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_RUNNING = 1 RESPONSE_IDX_FLAGS = 1
RESPONSE_IDX_RETCODE = 2 RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_RDYBYTE = 3 RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_STDOUT = 4 RESPONSE_IDX_STDERR = 4
RESPONSE_IDX_TMSTAMP = 5 RESPONSE_IDX_STDOUT = 5
RESPONSE_IDX_TMSTAMP = 6
RESPONSE_FLAGS_RUNNING = 0x01
RESPONSE_FLAGS_EOF_STDOUT = 0x02
RESPONSE_FLAGS_EOF_STDERR = 0x04
@staticmethod @staticmethod
def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]: def default_response() -> [any]:
response: list[any] = [ response: list[any] = [
version, # 0: Protocol version _PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
False, # 1: Process running False, # 1: Process running
None, # 2: Return value None, # 2: Return value
0, # 3: Number of outstanding bytes 0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr None, # 4: Stderr
None, # 5: Timestamp None, # 5: Stdout
None, # 6: Timestamp
].copy() ].copy()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time() response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response return response
@classmethod @classmethod
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]: def error_response(cls, msg: str) -> [any]:
response = cls.default_response(version) response = cls.default_response()
msg_bytes = f"{msg}\r\n".encode("utf-8") msg_bytes = f"{msg}\r\n".encode("utf-8")
response[Session.RESPONSE_IDX_STDOUT] = \ response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes)
response[Session.RESPONSE_IDX_RETCODE] = 255 response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0 response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response return response
@ -543,11 +608,12 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str], def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
loop: asyncio.AbstractEventLoop) -> Session | None: loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
log = _get_logger("_listen_start_proc") log = _get_logger("_listen_start_proc")
try: try:
return Session(tag=link.link_id, return Session(tag=link.link_id,
cmd=cmd, cmd=cmd,
session_flags=session_flags,
term=term, term=term,
remote_identity=remote_identity, remote_identity=remote_identity,
loop=loop, loop=loop,
@ -602,25 +668,27 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
_retry_timer.complete(link_id) _retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None) link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None: if link is None:
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}") log.error(f"Invalid request {request_id}, no link found with id {link_id}")
return None
remote_version = data[Session.REQUEST_IDX_VERS] remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version): if not _protocol_check_magic(remote_version):
raise Exception("Request magic incorrect") log.error("Request magic incorrect")
link.teardown()
return None
if not remote_version <= _PROTOCOL_VERSION_DEFAULT: if not remote_version <= _PROTOCOL_VERSION_3:
return Session.error_response("Listener<->initiator version mismatch") return Session.error_response("Listener<->initiator version mismatch")
cmd = _cmd.copy() cmd = _cmd.copy()
if remote_version >= _PROTOCOL_VERSION_1: remote_command = data[Session.REQUEST_IDX_CMD]
remote_command = data[Session.REQUEST_IDX_CMD] if remote_command is not None and len(remote_command) > 0:
if remote_command is not None and len(remote_command) > 0: if _no_remote_command:
if _no_remote_command: return Session.error_response("Listener does not permit initiator to provide command.")
return Session.error_response("Listener does not permit initiator to provide command.") elif _remote_cmd_as_args:
elif _remote_cmd_as_args: cmd.extend(remote_command)
cmd.extend(remote_command) else:
else: cmd = remote_command
cmd = remote_command
if not _no_remote_command and (cmd is None or len(cmd) == 0): if not _no_remote_command and (cmd is None or len(cmd) == 0):
return Session.error_response("No command supplied and no default command available.") return Session.error_response("No command supplied and no default command available.")
@ -636,9 +704,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
log.debug(f"Process not found for link {link}") log.debug(f"Process not found for link {link}")
session = _listen_start_proc(link=link, session = _listen_start_proc(link=link,
term=term, term=term,
session_flags=data[Session.REQUEST_IDX_FLAGS],
cmd=cmd, cmd=cmd,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""), remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop) loop=_loop)
if session is None:
return Session.error_response("Unable to start subprocess")
# leave significant headroom for metadata and encoding # leave significant headroom for metadata and encoding
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version)) result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
@ -695,7 +766,7 @@ def _response_handler(request_receipt: RNS.RequestReceipt):
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None, async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT, service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
cmd: [str] | None = None): cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
log = _get_logger("_execute") log = _get_logger("_execute")
@ -749,11 +820,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler) _link.set_packet_callback(_client_packet_handler)
request = Session.default_request(sys.stdin.fileno()) request = Session.default_request()
log.debug(f"Sending {len(stdin) or 0} bytes to listener") log.debug(f"Sending {len(stdin) or 0} bytes to listener")
# log.debug(f"Sending {stdin} to listener") # log.debug(f"Sending {stdin} to listener")
request[Session.REQUEST_IDX_STDIN] = bytes(stdin) request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
request[Session.REQUEST_IDX_CMD] = cmd request[Session.REQUEST_IDX_CMD] = cmd
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
Session.REQUEST_FLAGS_EOF_STDIN)
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
# TODO: Tune # TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace timeout = timeout + _link.rtt * 4 + _remote_exec_grace
@ -796,13 +870,17 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0 version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version): if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error") raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_2: elif version != _PROTOCOL_VERSION_3:
raise RemoteExecutionError("Protocol version mismatch") raise RemoteExecutionError("Protocol version mismatch")
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE] return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0 ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT] stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
# if stdout is not None: # if stdout is not None:
# stdout = base64.b64decode(stdout) # stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP] timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
@ -815,10 +893,24 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if stdout is not None: if stdout is not None:
_tr.raw() _tr.raw()
log.debug(f"stdout: {stdout}") log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout) os.write(1, stdout)
sys.stdout.flush() sys.stdout.flush()
got_bytes = len(stdout) if stdout is not None else 0 if stderr is not None:
_tr.raw()
log.debug(f"stderr: {stderr}")
os.write(2, stderr)
sys.stderr.flush()
if stderr_eof and ready_bytes == 0:
log.debug("Closing stderr")
os.close(2)
if stdout_eof and ready_bytes == 0:
log.debug("Closing stdout")
os.close(1)
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}") log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
if ready_bytes > 0: if ready_bytes > 0:
@ -847,7 +939,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if _new_data is not None: if _new_data is not None:
_new_data.set() _new_data.set()
stdin_eof = False
def stdin(): def stdin():
nonlocal stdin_eof
try: try:
data = process.tty_read(sys.stdin.fileno()) data = process.tty_read(sys.stdin.fileno())
log.debug(f"stdin {data}") log.debug(f"stdin {data}")
@ -855,7 +949,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
data_buffer.extend(data) data_buffer.extend(data)
_new_data.set() _new_data.set()
except EOFError: except EOFError:
data_buffer.extend(process.CTRL_D) if os.isatty(0):
data_buffer.extend(process.CTRL_D)
stdin_eof = True
process.tty_unset_reader_callbacks(sys.stdin.fileno()) process.tty_unset_reader_callbacks(sys.stdin.fileno())
process.tty_add_reader_callback(sys.stdin.fileno(), stdin) process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
@ -886,6 +982,8 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
stdin=stdin, stdin=stdin,
timeout=timeout, timeout=timeout,
cmd=command, cmd=command,
stdin_eof=stdin_eof,
bytes_available=len(data_buffer)
) )
if first_loop: if first_loop:
@ -975,11 +1073,6 @@ async def _rnsh_cli_main():
def rnsh_cli(): def rnsh_cli():
global _tr, _retry_timer, _pre_input global _tr, _retry_timer, _pre_input
with contextlib.suppress(Exception):
if not os.isatty(sys.stdin.fileno()):
time.sleep(0.1) # attempting to deal with an issue with missing input
tty.setraw(sys.stdin.fileno(), termios.TCSADRAIN)
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer: with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main()) return_code = asyncio.run(_rnsh_cli_main())

View File

@ -37,7 +37,11 @@ class SubprocessReader(contextlib.AbstractContextManager):
env=self.env, env=self.env,
loop=self.loop, loop=self.loop,
stdout_callback=self._stdout_cb, stdout_callback=self._stdout_cb,
terminated_callback=self._terminated_cb) terminated_callback=self._terminated_cb,
stderr_callback=self._stdout_cb,
stdin_is_pipe=False,
stdout_is_pipe=False,
stderr_is_pipe=False)
def _stdout_cb(self, data): def _stdout_cb(self, data):
self._log.debug(f"_stdout_cb({data})") self._log.debug(f"_stdout_cb({data})")

7
tty_test.py Normal file
View File

@ -0,0 +1,7 @@
import os
import sys
for stream in [sys.stdin, sys.stdout, sys.stderr]:
print(f"{stream.name:8s} " + ("tty" if os.isatty(stream.fileno()) else "not tty"))
print(f"args: {sys.argv}")