mirror of
https://github.com/markqvist/rnsh.git
synced 2024-10-01 01:15:37 -04:00
Support for pipes. #3
Connects pipes on the listener child process the same way they are connected on the initiator--any combination of stdin, stdout, and stderr could be connected to a pipe. Any one of those not connected to a pipe will be connected to a pty.
This commit is contained in:
parent
27664df0b3
commit
ebdcb50265
121
rnsh/process.py
121
rnsh/process.py
@ -323,12 +323,75 @@ async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
||||
return evt.is_set()
|
||||
|
||||
|
||||
def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, stdout_is_pipe: bool,
|
||||
stderr_is_pipe: bool) -> tuple[int, int, int, int]:
|
||||
# Set up PTY and/or pipes
|
||||
child_fd = parent_fd = None
|
||||
if not (stdin_is_pipe and stdout_is_pipe and stderr_is_pipe):
|
||||
parent_fd, child_fd = pty.openpty()
|
||||
child_stdin, parent_stdin = (os.pipe() if stdin_is_pipe else (child_fd, parent_fd))
|
||||
parent_stdout, child_stdout = (os.pipe() if stdout_is_pipe else (parent_fd, child_fd))
|
||||
parent_stderr, child_stderr = (os.pipe() if stderr_is_pipe else (parent_fd, child_fd))
|
||||
|
||||
# Fork
|
||||
pid = os.fork()
|
||||
|
||||
if pid == 0:
|
||||
try:
|
||||
# We are in the child process, so close all open sockets and pipes except for the PTY and/or pipes
|
||||
max_fd = os.sysconf("SC_OPEN_MAX")
|
||||
for fd in range(3, max_fd):
|
||||
if fd not in (child_stdin, child_stdout, child_stderr):
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Set up PTY and/or pipes
|
||||
os.dup2(child_stdin, 0)
|
||||
os.dup2(child_stdout, 1)
|
||||
os.dup2(child_stderr, 2)
|
||||
# Make PTY controlling if necessary
|
||||
if not stdin_is_pipe:
|
||||
os.setsid()
|
||||
tmp_fd = os.open(os.ttyname(0), os.O_RDWR)
|
||||
os.close(tmp_fd)
|
||||
# fcntl.ioctl(0, termios.TIOCSCTTY, 0)
|
||||
|
||||
# Execute the command
|
||||
os.execvpe(cmd_line[0], cmd_line, env)
|
||||
except Exception as err:
|
||||
exc_type, exc_obj, exc_tb = sys.exc_info()
|
||||
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
||||
print(f"Unable to start {cmd_line[0]}: {err} ({fname}:{exc_tb.tb_lineno})")
|
||||
sys.stdout.flush()
|
||||
# don't let any other modules get in our way, do an immediate silent exit.
|
||||
os._exit(0)
|
||||
|
||||
else:
|
||||
# We are in the parent process, so close the child-side of the PTY and/or pipes
|
||||
if child_fd is not None:
|
||||
os.close(child_fd)
|
||||
if child_stdin != child_fd:
|
||||
os.close(child_stdin)
|
||||
if child_stdout != child_fd:
|
||||
os.close(child_stdout)
|
||||
if child_stderr != child_fd:
|
||||
os.close(child_stderr)
|
||||
# # Close the write end of the pipe if a pipe is used for standard input
|
||||
# if not stdin_is_pipe:
|
||||
# os.close(parent_stdin)
|
||||
# Return the child PID and the file descriptors for the PTY and/or pipes
|
||||
return pid, parent_stdin, parent_stdout, parent_stderr
|
||||
|
||||
|
||||
class CallbackSubprocess:
|
||||
# time between checks of child process
|
||||
PROCESS_POLL_TIME: float = 0.1
|
||||
|
||||
def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
|
||||
terminated_callback: callable):
|
||||
stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
|
||||
stderr_is_pipe: bool):
|
||||
"""
|
||||
Fork a child process and generate callbacks with output from the process.
|
||||
:param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
|
||||
@ -347,11 +410,17 @@ class CallbackSubprocess:
|
||||
self._env = env or {}
|
||||
self._loop = loop
|
||||
self._stdout_cb = stdout_callback
|
||||
self._stderr_cb = stderr_callback
|
||||
self._terminated_cb = terminated_callback
|
||||
self._pid: int = None
|
||||
self._child_fd: int = None
|
||||
self._child_stdin: int = None
|
||||
self._child_stdout: int = None
|
||||
self._child_stderr: int = None
|
||||
self._return_code: int = None
|
||||
self._eof: bool = False
|
||||
self._stdin_is_pipe = stdin_is_pipe
|
||||
self._stdout_is_pipe = stdout_is_pipe
|
||||
self._stderr_is_pipe = stderr_is_pipe
|
||||
|
||||
def terminate(self, kill_delay: float = 1.0):
|
||||
"""
|
||||
@ -382,6 +451,10 @@ class CallbackSubprocess:
|
||||
|
||||
threading.Thread(target=wait).start()
|
||||
|
||||
def close_stdin(self):
|
||||
with contextlib.suppress(Exception):
|
||||
os.close(self._child_stdin)
|
||||
|
||||
@property
|
||||
def started(self) -> bool:
|
||||
"""
|
||||
@ -402,7 +475,7 @@ class CallbackSubprocess:
|
||||
:param data: bytes to write
|
||||
"""
|
||||
self._log.debug(f"write({data})")
|
||||
os.write(self._child_fd, data)
|
||||
os.write(self._child_stdin, data)
|
||||
|
||||
def set_winsize(self, r: int, c: int, h: int, v: int):
|
||||
"""
|
||||
@ -414,7 +487,7 @@ class CallbackSubprocess:
|
||||
:return:
|
||||
"""
|
||||
self._log.debug(f"set_winsize({r},{c},{h},{v}")
|
||||
tty_set_winsize(self._child_fd, r, c, h, v)
|
||||
tty_set_winsize(self._child_stdout, r, c, h, v)
|
||||
|
||||
def copy_winsize(self, fromfd: int):
|
||||
"""
|
||||
@ -430,17 +503,17 @@ class CallbackSubprocess:
|
||||
:param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
|
||||
:param attr: attributes to set
|
||||
"""
|
||||
termios.tcsetattr(self._child_fd, when, attr)
|
||||
termios.tcsetattr(self._child_stdin, when, attr)
|
||||
|
||||
def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]]
|
||||
"""
|
||||
Get tty attributes.
|
||||
:return: tty attributes value
|
||||
"""
|
||||
return termios.tcgetattr(self._child_fd)
|
||||
return termios.tcgetattr(self._child_stdout)
|
||||
|
||||
def ttysetraw(self):
|
||||
tty.setraw(self._child_fd, termios.TCSADRAIN)
|
||||
tty.setraw(self._child_stdout, termios.TCSADRAIN)
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@ -469,27 +542,11 @@ class CallbackSubprocess:
|
||||
# env["SHELL"] = program
|
||||
# self._log.debug(f"set login shell {self._command}")
|
||||
|
||||
self._pid, self._child_fd = pty.fork()
|
||||
|
||||
if self._pid == 0:
|
||||
try:
|
||||
# This may not be strictly necessary, but there is
|
||||
# occasionally some funny business that goes on with
|
||||
# networking after the fork. Anecdotally this fixed
|
||||
# it, but more testing is needed as it might be a
|
||||
# coincidence.
|
||||
p = psutil.Process()
|
||||
for c in p.connections(kind='all'):
|
||||
with exception.permit(SystemExit):
|
||||
os.close(c.fd)
|
||||
# TODO: verify that skipping setpgrp fixes Operation not permitted on Manjaro
|
||||
# os.setpgrp()
|
||||
os.execvpe(program, self._command, env)
|
||||
except Exception as err:
|
||||
print(f"Child process error: {err}, command: {self._command}")
|
||||
sys.stdout.flush()
|
||||
# don't let any other modules get in our way.
|
||||
os._exit(0)
|
||||
self._pid, \
|
||||
self._child_stdin, \
|
||||
self._child_stdout, \
|
||||
self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
|
||||
self._stderr_is_pipe)
|
||||
|
||||
def poll():
|
||||
# self.log.debug("poll")
|
||||
@ -515,10 +572,14 @@ class CallbackSubprocess:
|
||||
callback(data)
|
||||
except EOFError:
|
||||
self._eof = True
|
||||
tty_unset_reader_callbacks(self._child_fd)
|
||||
tty_unset_reader_callbacks(self._child_stdout)
|
||||
callback(bytearray())
|
||||
|
||||
tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop)
|
||||
tty_add_reader_callback(self._child_stdout, functools.partial(reader, self._child_stdout, self._stdout_cb),
|
||||
self._loop)
|
||||
if self._child_stderr != self._child_stdout:
|
||||
tty_add_reader_callback(self._child_stderr, functools.partial(reader, self._child_stderr, self._stderr_cb),
|
||||
self._loop)
|
||||
|
||||
@property
|
||||
def eof(self):
|
||||
|
189
rnsh/rnsh.py
189
rnsh/rnsh.py
@ -227,9 +227,10 @@ def _protocol_make_version(version: int):
|
||||
|
||||
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
||||
_PROTOCOL_VERSION_1 = _protocol_make_version(1)
|
||||
_PROTOCOL_VERSION_2 = _protocol_make_version(1)
|
||||
_PROTOCOL_VERSION_2 = _protocol_make_version(2)
|
||||
_PROTOCOL_VERSION_3 = _protocol_make_version(3)
|
||||
|
||||
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_2
|
||||
_PROTOCOL_VERSION_DEFAULT = _PROTOCOL_VERSION_3
|
||||
|
||||
def _protocol_split_version(version: int):
|
||||
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
||||
@ -253,6 +254,16 @@ def _protocol_request_chars_take(link_mdu: int, version: int, term: str, cmd: st
|
||||
return link_mdu // 2
|
||||
|
||||
|
||||
def _bitwise_or_if(value: int, condition: bool, orval: int):
|
||||
if not condition:
|
||||
return value
|
||||
return value | orval
|
||||
|
||||
|
||||
def _check_and(value: int, andval: int) -> bool:
|
||||
return (value & andval) > 0
|
||||
|
||||
|
||||
class Session:
|
||||
_processes: [(any, Session)] = []
|
||||
_lock = threading.RLock()
|
||||
@ -279,6 +290,7 @@ class Session:
|
||||
cmd: [str],
|
||||
data_available_callback: callable,
|
||||
terminated_callback: callable,
|
||||
session_flags: int,
|
||||
term: str | None,
|
||||
remote_identity: str | None,
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
@ -290,15 +302,24 @@ class Session:
|
||||
"RNS_REMOTE_IDENTITY": remote_identity or ""},
|
||||
loop=loop,
|
||||
stdout_callback=self._stdout_data,
|
||||
terminated_callback=terminated_callback)
|
||||
stderr_callback=self._stderr_data,
|
||||
terminated_callback=terminated_callback,
|
||||
stdin_is_pipe=_check_and(session_flags,
|
||||
Session.REQUEST_FLAGS_PIPE_STDIN),
|
||||
stdout_is_pipe=_check_and(session_flags,
|
||||
Session.REQUEST_FLAGS_PIPE_STDOUT),
|
||||
stderr_is_pipe=_check_and(session_flags,
|
||||
Session.REQUEST_FLAGS_PIPE_STDERR))
|
||||
self._log.debug(f"Starting {cmd}")
|
||||
self._data_buffer = bytearray()
|
||||
self._stdout_buffer = bytearray()
|
||||
self._stderr_buffer = bytearray()
|
||||
self._lock = threading.RLock()
|
||||
self._data_available_cb = data_available_callback
|
||||
self._terminated_cb = terminated_callback
|
||||
self._pending_receipt: RNS.PacketReceipt | None = None
|
||||
self._process.start()
|
||||
self._term_state: [int] = None
|
||||
self._session_flags: int = session_flags
|
||||
Session.put_for_tag(tag, self)
|
||||
|
||||
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
||||
@ -326,22 +347,39 @@ class Session:
|
||||
def lock(self) -> threading.RLock:
|
||||
return self._lock
|
||||
|
||||
def read(self, count: int) -> bytes:
|
||||
def read_stdout(self, count: int) -> bytes:
|
||||
with self.lock:
|
||||
initial_len = len(self._data_buffer)
|
||||
take = self._data_buffer[:count]
|
||||
self._data_buffer = self._data_buffer[count:].copy()
|
||||
self._log.debug(f"read {len(take)} bytes of {initial_len}, {len(self._data_buffer)} remaining")
|
||||
initial_len = len(self._stdout_buffer)
|
||||
take = self._stdout_buffer[:count]
|
||||
self._stdout_buffer = self._stdout_buffer[count:]
|
||||
self._log.debug(f"stdout: read {len(take)} bytes of {initial_len}, {len(self._stdout_buffer)} remaining")
|
||||
return take
|
||||
|
||||
def _stdout_data(self, data: bytes):
|
||||
with self.lock:
|
||||
self._data_buffer.extend(data)
|
||||
total_available = len(self._data_buffer)
|
||||
self._stdout_buffer.extend(data)
|
||||
total_available = len(self._stdout_buffer) + len(self._stderr_buffer)
|
||||
try:
|
||||
self._data_available_cb(total_available)
|
||||
except Exception as e:
|
||||
self._log.error(f"Error calling ProcessState data_available_callback {e}")
|
||||
self._log.error(f"stdout: error calling ProcessState data_available_callback {e}")
|
||||
|
||||
def read_stderr(self, count: int) -> bytes:
|
||||
with self.lock:
|
||||
initial_len = len(self._stderr_buffer)
|
||||
take = self._stderr_buffer[:count]
|
||||
self._stderr_buffer = self._stderr_buffer[count:]
|
||||
self._log.debug(f"stderr: read {len(take)} bytes of {initial_len}, {len(self._stderr_buffer)} remaining")
|
||||
return take
|
||||
|
||||
def _stderr_data(self, data: bytes):
|
||||
with self.lock:
|
||||
self._stderr_buffer.extend(data)
|
||||
total_available = len(self._stderr_buffer) + len(self._stdout_buffer)
|
||||
try:
|
||||
self._data_available_cb(total_available)
|
||||
except Exception as e:
|
||||
self._log.error(f"stderr: error calling ProcessState data_available_callback {e}")
|
||||
|
||||
TERMSTATE_IDX_TERM = 0
|
||||
TERMSTATE_IDX_TIOS = 1
|
||||
@ -368,9 +406,15 @@ class Session:
|
||||
REQUEST_IDX_HPIX = 6
|
||||
REQUEST_IDX_VPIX = 7
|
||||
REQUEST_IDX_CMD = 8
|
||||
REQUEST_IDX_FLAGS = 9
|
||||
REQUEST_IDX_BYTES_AVAILABLE = 10
|
||||
REQUEST_FLAGS_PIPE_STDIN = 0x01
|
||||
REQUEST_FLAGS_PIPE_STDOUT = 0x02
|
||||
REQUEST_FLAGS_PIPE_STDERR = 0x04
|
||||
REQUEST_FLAGS_EOF_STDIN = 0x08
|
||||
|
||||
@staticmethod
|
||||
def default_request(stdin_fd: int | None) -> [any]:
|
||||
def default_request() -> [any]:
|
||||
global _tr
|
||||
request: list[any] = [
|
||||
_PROTOCOL_VERSION_DEFAULT, # 0 Protocol Version
|
||||
@ -382,16 +426,24 @@ class Session:
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
None, # 8 Command to run
|
||||
0, # 9 Flags
|
||||
0, # 10 Bytes Available
|
||||
].copy()
|
||||
|
||||
if stdin_fd is not None:
|
||||
if os.isatty(0):
|
||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
|
||||
with contextlib.suppress(OSError):
|
||||
request[Session.REQUEST_IDX_ROWS], \
|
||||
request[Session.REQUEST_IDX_COLS], \
|
||||
request[Session.REQUEST_IDX_HPIX], \
|
||||
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
|
||||
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(0)
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(0),
|
||||
Session.REQUEST_FLAGS_PIPE_STDIN)
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(1),
|
||||
Session.REQUEST_FLAGS_PIPE_STDOUT)
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], not os.isatty(2),
|
||||
Session.REQUEST_FLAGS_PIPE_STDERR)
|
||||
return request
|
||||
|
||||
def process_request(self, data: [any], read_size: int) -> [any]:
|
||||
@ -403,12 +455,16 @@ class Session:
|
||||
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
||||
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
||||
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
|
||||
bytes_available = data[Session.REQUEST_IDX_BYTES_AVAILABLE]
|
||||
flags = data[Session.REQUEST_IDX_FLAGS]
|
||||
stdin_eof = _check_and(flags, Session.REQUEST_FLAGS_EOF_STDIN)
|
||||
response = Session.default_response()
|
||||
|
||||
first_term_state = self._term_state is None
|
||||
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
|
||||
|
||||
response[Session.RESPONSE_IDX_RUNNING] = self.process.running
|
||||
response[Session.RESPONSE_IDX_FLAGS] = _bitwise_or_if(response[Session.RESPONSE_IDX_FLAGS],
|
||||
self.process.running, Session.RESPONSE_FLAGS_RUNNING)
|
||||
if self.process.running:
|
||||
if term_state != self._term_state:
|
||||
self._term_state = term_state
|
||||
@ -422,45 +478,54 @@ class Session:
|
||||
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||
stdin = base64.b64decode(stdin)
|
||||
self.process.write(stdin)
|
||||
if stdin_eof and bytes_available == 0:
|
||||
module_logger.debug("Closing stdin")
|
||||
with contextlib.suppress(Exception):
|
||||
self.process.close_stdin()
|
||||
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||
|
||||
with self.lock:
|
||||
stdout = self.read(read_size)
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||
#prioritizing stderr
|
||||
stderr = self.read_stderr(read_size)
|
||||
stdout = self.read_stdout(read_size - len(stderr))
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._stdout_buffer) + len(self._stderr_buffer)
|
||||
|
||||
if stderr is not None and len(stderr) > 0:
|
||||
response[Session.RESPONSE_IDX_STDERR] = bytes(stderr)
|
||||
if stdout is not None and len(stdout) > 0:
|
||||
if data[Session.REQUEST_IDX_VERS] < _PROTOCOL_VERSION_2:
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||
else:
|
||||
response[Session.RESPONSE_IDX_STDOUT] = bytes(stdout)
|
||||
return response
|
||||
|
||||
RESPONSE_IDX_VERSION = 0
|
||||
RESPONSE_IDX_RUNNING = 1
|
||||
RESPONSE_IDX_FLAGS = 1
|
||||
RESPONSE_IDX_RETCODE = 2
|
||||
RESPONSE_IDX_RDYBYTE = 3
|
||||
RESPONSE_IDX_STDOUT = 4
|
||||
RESPONSE_IDX_TMSTAMP = 5
|
||||
RESPONSE_IDX_STDERR = 4
|
||||
RESPONSE_IDX_STDOUT = 5
|
||||
RESPONSE_IDX_TMSTAMP = 6
|
||||
RESPONSE_FLAGS_RUNNING = 0x01
|
||||
RESPONSE_FLAGS_EOF_STDOUT = 0x02
|
||||
RESPONSE_FLAGS_EOF_STDERR = 0x04
|
||||
|
||||
@staticmethod
|
||||
def default_response(version: int = _PROTOCOL_VERSION_2) -> [any]:
|
||||
def default_response() -> [any]:
|
||||
response: list[any] = [
|
||||
version, # 0: Protocol version
|
||||
_PROTOCOL_VERSION_DEFAULT, # 0: Protocol version
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stdout/Stderr
|
||||
None, # 5: Timestamp
|
||||
None, # 4: Stderr
|
||||
None, # 5: Stdout
|
||||
None, # 6: Timestamp
|
||||
].copy()
|
||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def error_response(cls, msg: str, version: int = _PROTOCOL_VERSION_2) -> [any]:
|
||||
response = cls.default_response(version)
|
||||
def error_response(cls, msg: str) -> [any]:
|
||||
response = cls.default_response()
|
||||
msg_bytes = f"{msg}\r\n".encode("utf-8")
|
||||
response[Session.RESPONSE_IDX_STDOUT] = \
|
||||
base64.b64encode(msg_bytes) if version < _PROTOCOL_VERSION_2 else bytes(msg_bytes)
|
||||
response[Session.RESPONSE_IDX_STDERR] = bytes(msg_bytes)
|
||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||
return response
|
||||
@ -543,11 +608,12 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||
|
||||
|
||||
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, cmd: [str],
|
||||
loop: asyncio.AbstractEventLoop) -> Session | None:
|
||||
loop: asyncio.AbstractEventLoop, session_flags: int) -> Session | None:
|
||||
log = _get_logger("_listen_start_proc")
|
||||
try:
|
||||
return Session(tag=link.link_id,
|
||||
cmd=cmd,
|
||||
session_flags=session_flags,
|
||||
term=term,
|
||||
remote_identity=remote_identity,
|
||||
loop=loop,
|
||||
@ -602,17 +668,19 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
_retry_timer.complete(link_id)
|
||||
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
|
||||
if link is None:
|
||||
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
|
||||
log.error(f"Invalid request {request_id}, no link found with id {link_id}")
|
||||
return None
|
||||
|
||||
remote_version = data[Session.REQUEST_IDX_VERS]
|
||||
if not _protocol_check_magic(remote_version):
|
||||
raise Exception("Request magic incorrect")
|
||||
log.error("Request magic incorrect")
|
||||
link.teardown()
|
||||
return None
|
||||
|
||||
if not remote_version <= _PROTOCOL_VERSION_DEFAULT:
|
||||
if not remote_version <= _PROTOCOL_VERSION_3:
|
||||
return Session.error_response("Listener<->initiator version mismatch")
|
||||
|
||||
cmd = _cmd.copy()
|
||||
if remote_version >= _PROTOCOL_VERSION_1:
|
||||
remote_command = data[Session.REQUEST_IDX_CMD]
|
||||
if remote_command is not None and len(remote_command) > 0:
|
||||
if _no_remote_command:
|
||||
@ -636,9 +704,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
log.debug(f"Process not found for link {link}")
|
||||
session = _listen_start_proc(link=link,
|
||||
term=term,
|
||||
session_flags=data[Session.REQUEST_IDX_FLAGS],
|
||||
cmd=cmd,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
if session is None:
|
||||
return Session.error_response("Unable to start subprocess")
|
||||
|
||||
# leave significant headroom for metadata and encoding
|
||||
result = session.process_request(data, _protocol_response_chars_take(link.MDU, remote_version))
|
||||
@ -695,7 +766,7 @@ def _response_handler(request_receipt: RNS.RequestReceipt):
|
||||
|
||||
async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
|
||||
service_name="default", stdin=None, timeout=RNS.Transport.PATH_REQUEST_TIMEOUT,
|
||||
cmd: [str] | None = None):
|
||||
cmd: [str] | None = None, stdin_eof=False, bytes_available=0):
|
||||
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
|
||||
log = _get_logger("_execute")
|
||||
|
||||
@ -749,11 +820,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
|
||||
_link.set_packet_callback(_client_packet_handler)
|
||||
|
||||
request = Session.default_request(sys.stdin.fileno())
|
||||
request = Session.default_request()
|
||||
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
|
||||
# log.debug(f"Sending {stdin} to listener")
|
||||
request[Session.REQUEST_IDX_STDIN] = bytes(stdin)
|
||||
request[Session.REQUEST_IDX_CMD] = cmd
|
||||
request[Session.REQUEST_IDX_FLAGS] = _bitwise_or_if(request[Session.REQUEST_IDX_FLAGS], stdin_eof,
|
||||
Session.REQUEST_FLAGS_EOF_STDIN)
|
||||
request[Session.REQUEST_IDX_BYTES_AVAILABLE] = bytes_available
|
||||
|
||||
# TODO: Tune
|
||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||
@ -796,13 +870,17 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
||||
if not _protocol_check_magic(version):
|
||||
raise RemoteExecutionError("Protocol error")
|
||||
elif version != _PROTOCOL_VERSION_2:
|
||||
elif version != _PROTOCOL_VERSION_3:
|
||||
raise RemoteExecutionError("Protocol version mismatch")
|
||||
|
||||
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
|
||||
flags = request_receipt.response[Session.RESPONSE_IDX_FLAGS]
|
||||
running = _check_and(flags, Session.RESPONSE_FLAGS_RUNNING)
|
||||
stdout_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDOUT)
|
||||
stderr_eof = _check_and(flags, Session.RESPONSE_FLAGS_EOF_STDERR)
|
||||
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
||||
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
||||
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
||||
stderr = request_receipt.response[Session.RESPONSE_IDX_STDERR]
|
||||
# if stdout is not None:
|
||||
# stdout = base64.b64decode(stdout)
|
||||
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
||||
@ -815,10 +893,24 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
if stdout is not None:
|
||||
_tr.raw()
|
||||
log.debug(f"stdout: {stdout}")
|
||||
os.write(sys.stdout.fileno(), stdout)
|
||||
os.write(1, stdout)
|
||||
sys.stdout.flush()
|
||||
|
||||
got_bytes = len(stdout) if stdout is not None else 0
|
||||
if stderr is not None:
|
||||
_tr.raw()
|
||||
log.debug(f"stderr: {stderr}")
|
||||
os.write(2, stderr)
|
||||
sys.stderr.flush()
|
||||
|
||||
if stderr_eof and ready_bytes == 0:
|
||||
log.debug("Closing stderr")
|
||||
os.close(2)
|
||||
|
||||
if stdout_eof and ready_bytes == 0:
|
||||
log.debug("Closing stdout")
|
||||
os.close(1)
|
||||
|
||||
got_bytes = (len(stdout) if stdout is not None else 0) + (len(stderr) if stderr is not None else 0)
|
||||
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
|
||||
|
||||
if ready_bytes > 0:
|
||||
@ -847,7 +939,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
if _new_data is not None:
|
||||
_new_data.set()
|
||||
|
||||
stdin_eof = False
|
||||
def stdin():
|
||||
nonlocal stdin_eof
|
||||
try:
|
||||
data = process.tty_read(sys.stdin.fileno())
|
||||
log.debug(f"stdin {data}")
|
||||
@ -855,7 +949,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
data_buffer.extend(data)
|
||||
_new_data.set()
|
||||
except EOFError:
|
||||
if os.isatty(0):
|
||||
data_buffer.extend(process.CTRL_D)
|
||||
stdin_eof = True
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||
@ -886,6 +982,8 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
stdin=stdin,
|
||||
timeout=timeout,
|
||||
cmd=command,
|
||||
stdin_eof=stdin_eof,
|
||||
bytes_available=len(data_buffer)
|
||||
)
|
||||
|
||||
if first_loop:
|
||||
@ -975,11 +1073,6 @@ async def _rnsh_cli_main():
|
||||
|
||||
def rnsh_cli():
|
||||
global _tr, _retry_timer, _pre_input
|
||||
with contextlib.suppress(Exception):
|
||||
if not os.isatty(sys.stdin.fileno()):
|
||||
time.sleep(0.1) # attempting to deal with an issue with missing input
|
||||
tty.setraw(sys.stdin.fileno(), termios.TCSADRAIN)
|
||||
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
|
||||
|
@ -37,7 +37,11 @@ class SubprocessReader(contextlib.AbstractContextManager):
|
||||
env=self.env,
|
||||
loop=self.loop,
|
||||
stdout_callback=self._stdout_cb,
|
||||
terminated_callback=self._terminated_cb)
|
||||
terminated_callback=self._terminated_cb,
|
||||
stderr_callback=self._stdout_cb,
|
||||
stdin_is_pipe=False,
|
||||
stdout_is_pipe=False,
|
||||
stderr_is_pipe=False)
|
||||
|
||||
def _stdout_cb(self, data):
|
||||
self._log.debug(f"_stdout_cb({data})")
|
||||
|
7
tty_test.py
Normal file
7
tty_test.py
Normal file
@ -0,0 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
for stream in [sys.stdin, sys.stdout, sys.stderr]:
|
||||
print(f"{stream.name:8s} " + ("tty" if os.isatty(stream.fileno()) else "not tty"))
|
||||
|
||||
print(f"args: {sys.argv}")
|
Loading…
Reference in New Issue
Block a user