mirror of
https://github.com/markqvist/rnsh.git
synced 2025-02-25 01:20:00 -05:00
453 lines
15 KiB
Python
453 lines
15 KiB
Python
# MIT License
|
|
#
|
|
# Copyright (c) 2023 Aaron Heise
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
# copies of the Software, and to permit persons to whom the Software is
|
|
# furnished to do so, subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be included in all
|
|
# copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import errno
|
|
import functools
|
|
import re
|
|
import signal
|
|
import struct
|
|
import threading
|
|
import tty
|
|
import pty
|
|
import os
|
|
import asyncio
|
|
import sys
|
|
import fcntl
|
|
import psutil
|
|
import select
|
|
import termios
|
|
import logging as __logging
|
|
module_logger = __logging.getLogger(__name__)
|
|
|
|
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
|
|
"""
|
|
Add an async reader callback for a tty file descriptor.
|
|
|
|
Example usage:
|
|
|
|
def reader():
|
|
data = tty_read(fd)
|
|
# do something with data
|
|
|
|
tty_add_reader_callback(self._child_fd, reader, self._loop)
|
|
|
|
:param fd: file descriptor
|
|
:param callback: callback function
|
|
:param loop: asyncio event loop to which the reader should be added. If None, use the currently-running loop.
|
|
"""
|
|
if loop is None:
|
|
loop = asyncio.get_running_loop()
|
|
loop.add_reader(fd, callback)
|
|
|
|
def tty_read(fd: int) -> bytes | None:
    """
    Read whatever bytes are currently available from a tty file descriptor without blocking.

    When used in a callback added to a file descriptor using tty_add_reader_callback(...), this function
    creates a solution for non-blocking reads from ttys.

    :param fd: tty file descriptor
    :return: the bytes read (possibly empty), or None if fd was already closed when called
    :raises OSError: for read errors other than EIO and EWOULDBLOCK
    """
    if fd_is_closed(fd):
        return None

    result = bytearray()
    while not fd_is_closed(fd):
        # A zero timeout turns select() into a pure readiness poll, so this never blocks the caller.
        ready, _, _ = select.select([fd], [], [], 0)
        if not ready:
            break
        try:
            data = os.read(fd, 512)
        except OSError as e:
            if e.errno not in (errno.EIO, errno.EWOULDBLOCK):
                raise
            # Nothing more can be read right now. Stop here instead of retrying: as long as select()
            # keeps reporting the descriptor as readable, retrying would spin without making progress.
            break
        if not data:  # EOF
            break
        result.extend(data)
    # Hand back an immutable bytes object, as promised by the return annotation.
    return bytes(result)
|
|
|
|
def fd_is_closed(fd: int) -> bool:
    """
    Check if file descriptor is closed.

    :param fd: file descriptor
    :return: True if the descriptor is not open (fcntl reports EBADF), False otherwise
    """
    try:
        # F_GETFL has no side effects; it is used purely as a probe for descriptor validity.
        fcntl.fcntl(fd, fcntl.F_GETFL)
    except OSError as ose:
        return ose.errno == errno.EBADF
    # The probe succeeded, so the descriptor is open. Return an explicit False rather than
    # falling off the end of the function and returning None.
    return False
|
|
|
|
def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop | None = None):
|
|
"""
|
|
Remove async reader callbacks for file descriptor.
|
|
:param fd: file descriptor
|
|
:param loop: asyncio event loop from which to remove callbacks
|
|
"""
|
|
try:
|
|
if loop is None:
|
|
loop = asyncio.get_running_loop()
|
|
loop.remove_reader(fd)
|
|
except:
|
|
pass
|
|
|
|
def tty_get_winsize(fd: int) -> tuple[int, int, int, int]:
    """
    Get the window size of a tty.

    :param fd: file descriptor of tty
    :return: (rows, cols, h_pixels, v_pixels)
    """
    # TIOCGWINSZ fills a buffer of four unsigned shorts; pass a zeroed buffer of that size.
    packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
    rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed)
    return rows, cols, h_pixels, v_pixels
|
|
|
|
def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int):
    """
    Apply a window size to a tty. A negative file descriptor is silently ignored.

    :param fd: file descriptor of tty
    :param rows: number of visible rows
    :param cols: number of visible columns
    :param h_pixels: number of visible horizontal pixels
    :param v_pixels: number of visible vertical pixels
    """
    if fd < 0:
        return None
    # TIOCSWINSZ takes the four dimensions as consecutive unsigned shorts.
    dimensions = (rows, cols, h_pixels, v_pixels)
    fcntl.ioctl(fd, termios.TIOCSWINSZ, struct.pack('HHHH', *dimensions))
|
|
|
|
def process_exists(pid) -> bool:
    """
    Check for the existence of a unix pid.

    :param pid: process id to check
    :return: True if process exists
    """
    try:
        # Signal 0 performs the existence and permission checks without delivering a signal.
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists, we just are not allowed to signal it.
        return True
    except OSError:
        return False
    return True
|
|
|
|
class TtyRestorer:
    """
    Remembers the termios attributes a tty had at construction time so they can be put back later.
    """

    def __init__(self, fd: int):
        """
        Capture the current termios attributes of a tty.

        :param fd: file descriptor of tty
        """
        self._fd = fd
        self._tattr = termios.tcgetattr(fd)

    def raw(self):
        """
        Switch the tty into raw mode.
        """
        tty.setraw(self._fd, when=termios.TCSADRAIN)

    def restore(self):
        """
        Put back the termios attributes captured by the constructor.
        """
        fd, saved_attributes = self._fd, self._tattr
        termios.tcsetattr(fd, termios.TCSADRAIN, saved_attributes)
|
|
|
|
|
|
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
|
"""
|
|
Wait for event to be set, or timeout to expire.
|
|
:param evt: asyncio.Event to wait on
|
|
:param timeout: maximum number of seconds to wait.
|
|
:return: True if event was set, False if timeout expired
|
|
"""
|
|
# suppress TimeoutError because we'll return False in case of timeout
|
|
with contextlib.suppress(asyncio.TimeoutError):
|
|
await asyncio.wait_for(evt.wait(), timeout)
|
|
return evt.is_set()
|
|
|
|
|
|
class CallbackSubprocess:
|
|
# time between checks of child process
|
|
PROCESS_POLL_TIME: float = 0.1
|
|
|
|
def __init__(self, argv: [str], env: dict | None, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
|
|
terminated_callback: callable):
|
|
"""
|
|
Fork a child process and generate callbacks with output from the process.
|
|
:param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
|
|
:param term: the value that should be set for TERM. If None, the value from the parent process will be used
|
|
:param loop: the asyncio event loop to use
|
|
:param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None
|
|
:param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None
|
|
"""
|
|
assert loop is not None, "loop should not be None"
|
|
assert stdout_callback is not None, "stdout_callback should not be None"
|
|
assert terminated_callback is not None, "terminated_callback should not be None"
|
|
|
|
self._log = module_logger.getChild(self.__class__.__name__)
|
|
# self._log.debug(f"__init__({argv},{term},...")
|
|
self._command: [str] = argv
|
|
self._env = env or {}
|
|
self._loop = loop
|
|
self._stdout_cb = stdout_callback
|
|
self._terminated_cb = terminated_callback
|
|
self._pid: int | None = None
|
|
self._child_fd: int | None = None
|
|
self._return_code: int | None = None
|
|
|
|
def terminate(self, kill_delay: float = 1.0):
|
|
"""
|
|
Terminate child process if running
|
|
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
|
"""
|
|
self._log.debug("terminate()")
|
|
if not self.running:
|
|
return
|
|
|
|
try:
|
|
os.kill(self._pid, signal.SIGTERM)
|
|
except:
|
|
pass
|
|
|
|
def kill():
|
|
if process_exists(self._pid):
|
|
self._log.debug("kill()")
|
|
try:
|
|
os.kill(self._pid, signal.SIGHUP)
|
|
os.kill(self._pid, signal.SIGKILL)
|
|
except:
|
|
pass
|
|
|
|
self._loop.call_later(kill_delay, kill)
|
|
|
|
def wait():
|
|
self._log.debug("wait()")
|
|
os.waitpid(self._pid, 0)
|
|
self._log.debug("wait() finish")
|
|
|
|
threading.Thread(target=wait).start()
|
|
|
|
@property
|
|
def started(self) -> bool:
|
|
"""
|
|
:return: True if child process has been started
|
|
"""
|
|
return self._pid is not None
|
|
|
|
@property
|
|
def running(self) -> bool:
|
|
"""
|
|
:return: True if child process is still running
|
|
"""
|
|
return self._pid is not None and process_exists(self._pid)
|
|
|
|
def write(self, data: bytes):
|
|
"""
|
|
Write bytes to the stdin of the child process.
|
|
:param data: bytes to write
|
|
"""
|
|
self._log.debug(f"write({data})")
|
|
os.write(self._child_fd, data)
|
|
|
|
def set_winsize(self, r: int, c: int, h: int, v: int):
|
|
"""
|
|
Set the window size on the tty of the child process.
|
|
:param r: rows visible
|
|
:param c: columns visible
|
|
:param h: horizontal pixels visible
|
|
:param v: vertical pixels visible
|
|
:return:
|
|
"""
|
|
self._log.debug(f"set_winsize({r},{c},{h},{v}")
|
|
tty_set_winsize(self._child_fd, r, c, h, v)
|
|
|
|
def copy_winsize(self, fromfd:int):
|
|
"""
|
|
Copy window size from one tty to another.
|
|
:param fromfd: source tty file descriptor
|
|
"""
|
|
r,c,h,v = tty_get_winsize(fromfd)
|
|
self.set_winsize(r,c,h,v)
|
|
|
|
def tcsetattr(self, when: int, attr: list[int | list[int | bytes]]):
|
|
"""
|
|
Set tty attributes.
|
|
:param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
|
|
:param attr: attributes to set
|
|
"""
|
|
termios.tcsetattr(self._child_fd, when, attr)
|
|
|
|
def tcgetattr(self) -> list[int | list[int | bytes]]:
|
|
"""
|
|
Get tty attributes.
|
|
:return: tty attributes value
|
|
"""
|
|
return termios.tcgetattr(self._child_fd)
|
|
|
|
def start(self):
|
|
"""
|
|
Start the child process.
|
|
"""
|
|
self._log.debug("start()")
|
|
|
|
# # Using the parent environment seems to do some weird stuff, at least on macOS
|
|
# parentenv = os.environ.copy()
|
|
# env = {"HOME": parentenv["HOME"],
|
|
# "PATH": parentenv["PATH"],
|
|
# "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
|
|
# "LANG": parentenv.get("LANG"),
|
|
# "SHELL": self._command[0]}
|
|
|
|
env = os.environ.copy()
|
|
for key in self._env:
|
|
env[key] = self._env[key]
|
|
|
|
program = self._command[0]
|
|
# match = re.search("^/bin/(.*sh)$", program)
|
|
# if match:
|
|
# self._command[0] = "-" + match.group(1)
|
|
# env["SHELL"] = program
|
|
# self._log.debug(f"set login shell {self._command}")
|
|
|
|
|
|
self._pid, self._child_fd = pty.fork()
|
|
|
|
if self._pid == 0:
|
|
try:
|
|
# this may not be strictly necessary, but there was
|
|
# is occasionally some funny business that goes on
|
|
# with networking. Anecdotally this fixed it, but
|
|
# more testing is needed as it might be a coincidence.
|
|
p = psutil.Process()
|
|
for c in p.connections(kind='all'):
|
|
if c == sys.stdin.fileno() or c == sys.stdout.fileno() or c == sys.stderr.fileno():
|
|
continue
|
|
try:
|
|
os.close(c.fd)
|
|
except:
|
|
pass
|
|
os.setpgrp()
|
|
os.execvpe(program, self._command, env)
|
|
except Exception as err:
|
|
print(f"Child process error: {err}, command: {self._command}")
|
|
sys.stdout.flush()
|
|
# don't let any other modules get in our way.
|
|
os._exit(0)
|
|
|
|
def poll():
|
|
# self.log.debug("poll")
|
|
try:
|
|
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
|
if self._return_code is not None and not process_exists(self._pid):
|
|
self._log.debug(f"polled return code {self._return_code}")
|
|
self._terminated_cb(self._return_code)
|
|
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
|
except Exception as e:
|
|
if not hasattr(e, "errno") or e.errno != errno.ECHILD:
|
|
self._log.debug(f"Error in process poll: {e}")
|
|
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
|
|
|
def reader(fd: int, callback: callable):
|
|
result = bytearray()
|
|
try:
|
|
c = tty_read(fd)
|
|
if c is not None and len(c) > 0:
|
|
callback(c)
|
|
except:
|
|
pass
|
|
|
|
tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop)
|
|
|
|
@property
|
|
def return_code(self) -> int | None:
|
|
return self._return_code
|
|
|
|
|
|
async def main():
    """
    A test driver for the CallbackSubprocess class.

    Runs the command given on the command line on a pty, forwards this process's stdin to it and
    its output to stdout, and finishes when the child terminates.

    python ./process.py /bin/zsh --login

    :return: the value passed to the termination callback
    """
    import rnsh.testlogging

    log = module_logger.getChild("main")
    if len(sys.argv) <= 1:
        print(f"Usage: {sys.argv[0]} <absolute_path_to_child_executable> [child_arg ...]")
        sys.exit(1)

    # This coroutine always executes inside a running loop.
    loop = asyncio.get_running_loop()
    retcode = loop.create_future()

    def stdout(data: bytes):
        os.write(sys.stdout.fileno(), data)

    def terminated(rc: int):
        # Guard against a second notification: resolving a finished future raises InvalidStateError.
        if not retcode.done():
            retcode.set_result(rc)

    process = CallbackSubprocess(argv=sys.argv[1:],
                                 env={"TERM": os.environ.get("TERM", "xterm")},
                                 loop=loop,
                                 stdout_callback=stdout,
                                 terminated_callback=terminated)

    def sigint_handler(signum, frame):
        if process is None or (process.started and not process.running):
            raise KeyboardInterrupt
        elif process.running:
            # Forward Ctrl-C to the child instead of interrupting this process.
            process.write("\x03".encode("utf-8"))

    def sigwinch_handler(signum, frame):
        process.copy_winsize(sys.stdin.fileno())

    signal.signal(signal.SIGINT, sigint_handler)
    signal.signal(signal.SIGWINCH, sigwinch_handler)

    def stdin():
        data = tty_read(sys.stdin.fileno())
        if data is not None:
            process.write(data)

    tty_add_reader_callback(sys.stdin.fileno(), stdin)
    process.start()
    # call_soon called it too soon, not sure why.
    loop.call_later(0.001, functools.partial(process.copy_winsize, sys.stdin.fileno()))

    val = await retcode
    log.debug(f"got retcode {val}")
    return val
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the test driver with the terminal in raw mode, and always put the
    # terminal settings back afterwards, even if the driver raises.
    stdin_fd = sys.stdin.fileno()
    tr = TtyRestorer(stdin_fd)
    try:
        tr.raw()
        asyncio.run(main())
    finally:
        tty_unset_reader_callbacks(stdin_fd)
        tr.restore()