Merge pull request #1 from acehoss/feature/pipe-tty-fix-rebased

Fix running initiator from a script/pipe
This commit is contained in:
acehoss 2023-02-13 15:46:14 -06:00 committed by GitHub
commit 5ce4c342bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 32 deletions

View File

@ -1,7 +1,7 @@
# `r n s h`  Shell over Reticulum # `r n s h`  Shell over Reticulum
[![CI](https://github.com/acehoss/rnsh/actions/workflows/python-package.yml/badge.svg)](https://github.com/acehoss/rnsh/actions/workflows/python-package.yml)  [![CI](https://github.com/acehoss/rnsh/actions/workflows/python-package.yml/badge.svg)](https://github.com/acehoss/rnsh/actions/workflows/python-package.yml) 
[![Release](https://github.com/acehoss/rnsh/actions/workflows/python-publish.yml/badge.svg)](https://github.com/acehoss/rnsh/actions/workflows/python-publish.yml)  [![Release](https://github.com/acehoss/rnsh/actions/workflows/python-publish.yml/badge.svg)](https://github.com/acehoss/rnsh/actions/workflows/python-publish.yml) 
[![PyPI version](https://badge.fury.io/py/rnsh.svg)](https://badge.fury.io/py/rnsh) [![PyPI version](https://badge.fury.io/py/rnsh.svg)](https://badge.fury.io/py/rnsh)  
![PyPI - Downloads](https://img.shields.io/pypi/dw/rnsh?color=informational&label=Installs&logo=pypi) ![PyPI - Downloads](https://img.shields.io/pypi/dw/rnsh?color=informational&label=Installs&logo=pypi)
`rnsh` is a utility written in Python that facilitates shell `rnsh` is a utility written in Python that facilitates shell
@ -165,5 +165,5 @@ The protocol is build on top of the Reticulum `Request` and
- [X] ~~Improve signal handling~~ - [X] ~~Improve signal handling~~
- [ ] Protocol improvements (throughput!) - [ ] Protocol improvements (throughput!)
- [ ] Test on several *nixes - [ ] Test on several *nixes
- [ ] Make it scriptable (currently requires a tty) - [X] ~~Make it scriptable (currently requires a tty)~~
- [ ] Documentation improvements - [ ] Documentation improvements

View File

@ -76,12 +76,12 @@ def tty_read(fd: int) -> bytes:
:return: bytes read :return: bytes read
""" """
if fd_is_closed(fd): if fd_is_closed(fd):
return None raise EOFError
try: try:
run = True run = True
result = bytearray() result = bytearray()
while run and not fd_is_closed(fd): while not fd_is_closed(fd):
ready, _, _ = select.select([fd], [], [], 0) ready, _, _ = select.select([fd], [], [], 0)
if len(ready) == 0: if len(ready) == 0:
break break
@ -93,10 +93,16 @@ def tty_read(fd: int) -> bytes:
raise raise
else: else:
if not data: # EOF if not data: # EOF
run = False
if data is not None and len(data) > 0: if data is not None and len(data) > 0:
result.extend(data) result.extend(data)
return result return result
else:
raise EOFError
if data is not None and len(data) > 0:
result.extend(data)
return result
except EOFError:
raise
except Exception as ex: except Exception as ex:
module_logger.error("tty_read error: {ex}") module_logger.error("tty_read error: {ex}")
@ -193,7 +199,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
self._suppress_logs = suppress_logs self._suppress_logs = suppress_logs
self._tattr = self.current_attr() self._tattr = self.current_attr()
if not self._tattr and not self._suppress_logs: if not self._tattr and not self._suppress_logs:
self._log.warning(f"Could not get attrs for fd {fd}") self._log.debug(f"Could not get attrs for fd {fd}")
def raw(self): def raw(self):
""" """
@ -231,6 +237,9 @@ class TTYRestorer(contextlib.AbstractContextManager):
with contextlib.suppress(termios.error): with contextlib.suppress(termios.error):
termios.tcsetattr(self._fd, when, attr) termios.tcsetattr(self._fd, when, attr)
def isatty(self):
return os.isatty(self._fd) if self._fd is not None else None
def restore(self): def restore(self):
""" """
Restore termios settings to state captured in constructor. Restore termios settings to state captured in constructor.
@ -363,6 +372,7 @@ class CallbackSubprocess:
self._loop.call_later(kill_delay, kill) self._loop.call_later(kill_delay, kill)
def wait(): def wait():
with contextlib.suppress(OSError):
self._log.debug("wait()") self._log.debug("wait()")
os.waitpid(self._pid, 0) os.waitpid(self._pid, 0)
self._log.debug("wait() finish") self._log.debug("wait() finish")
@ -426,6 +436,9 @@ class CallbackSubprocess:
""" """
return termios.tcgetattr(self._child_fd) return termios.tcgetattr(self._child_fd)
def ttysetraw(self):
tty.setraw(self._child_fd, termios.TCSANOW)
def start(self): def start(self):
""" """
Start the child process. Start the child process.
@ -489,10 +502,14 @@ class CallbackSubprocess:
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable): def reader(fd: int, callback: callable):
try:
with exception.permit(SystemExit): with exception.permit(SystemExit):
data = tty_read(fd) data = tty_read(fd)
if data is not None and len(data) > 0: if data is not None and len(data) > 0:
callback(data) callback(data)
except EOFError:
tty_unset_reader_callbacks(self._child_fd)
callback(CTRL_D)
tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop) tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop)
@ -546,11 +563,15 @@ async def main():
signal.signal(signal.SIGWINCH, sigwinch_handler) signal.signal(signal.SIGWINCH, sigwinch_handler)
def stdin(): def stdin():
try:
data = tty_read(sys.stdin.fileno()) data = tty_read(sys.stdin.fileno())
# log.debug(f"stdin {data}") # log.debug(f"stdin {data}")
if data is not None: if data is not None:
process.write(data) process.write(data)
# sys.stdout.buffer.write(data) # sys.stdout.buffer.write(data)
except EOFError:
tty_unset_reader_callbacks(sys.stdin.fileno())
process.write(CTRL_D)
tty_add_reader_callback(sys.stdin.fileno(), stdin) tty_add_reader_callback(sys.stdin.fileno(), stdin)
process.start() process.start()

View File

@ -43,6 +43,8 @@ import rnsh.process as process
import rnsh.retry as retry import rnsh.retry as retry
import rnsh.rnslogging as rnslogging import rnsh.rnslogging as rnslogging
import rnsh.hacks as hacks import rnsh.hacks as hacks
import re
import contextlib
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -351,7 +353,8 @@ class Session:
if stdin_fd is not None: if stdin_fd is not None:
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None) request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd) request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else None
with contextlib.suppress(OSError):
request[Session.REQUEST_IDX_ROWS], \ request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \ request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \ request[Session.REQUEST_IDX_HPIX], \
@ -368,14 +371,20 @@ class Session:
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels # vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1] # term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
response = Session.default_response() response = Session.default_response()
first_term_state = self._term_state is None
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1] term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[Session.RESPONSE_IDX_RUNNING] = self.process.running response[Session.RESPONSE_IDX_RUNNING] = self.process.running
if self.process.running: if self.process.running:
if term_state != self._term_state: if term_state != self._term_state:
self._term_state = term_state self._term_state = term_state
if term_state is not None:
self._update_winsz() self._update_winsz()
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0]) if first_term_state is not None:
# TODO: use a more specific error
with contextlib.suppress(Exception):
self.process.tcsetattr(termios.TCSANOW, term_state[0])
if stdin is not None and len(stdin) > 0: if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin) stdin = base64.b64decode(stdin)
self.process.write(stdin) self.process.write(stdin)
@ -567,6 +576,9 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
session: Session | None = None session: Session | None = None
try: try:
term = data[Session.REQUEST_IDX_TERM] term = data[Session.REQUEST_IDX_TERM]
# sanitize
if term is not None:
term = re.sub('[^A-Za-z-0-9\-\_]','', term)
session = Session.get_for_tag(link.link_id) session = Session.get_for_tag(link.link_id)
if session is None: if session is None:
log.debug(f"Process not found for link {link}") log.debug(f"Process not found for link {link}")
@ -684,6 +696,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler) _link.set_packet_callback(_client_packet_handler)
request = Session.default_request(sys.stdin.fileno()) request = Session.default_request(sys.stdin.fileno())
log.debug(f"Sending {len(stdin) or 0} bytes to listener")
# log.debug(f"Sending {stdin} to listener")
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None) request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
# TODO: Tune # TODO: Tune
@ -769,7 +783,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
_new_data = asyncio.Event() _new_data = asyncio.Event()
data_buffer = bytearray() data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
def sigwinch_handler(): def sigwinch_handler():
# log.debug("WindowChanged") # log.debug("WindowChanged")
@ -777,11 +791,15 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
_new_data.set() _new_data.set()
def stdin(): def stdin():
try:
data = process.tty_read(sys.stdin.fileno()) data = process.tty_read(sys.stdin.fileno())
log.debug(f"stdin {data}") log.debug(f"stdin {data}")
if data is not None: if data is not None:
data_buffer.extend(data) data_buffer.extend(data)
_new_data.set() _new_data.set()
except EOFError:
data_buffer.extend(process.CTRL_D)
process.tty_unset_reader_callbacks(sys.stdin.fileno())
process.tty_add_reader_callback(sys.stdin.fileno(), stdin) process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
@ -978,6 +996,9 @@ def _noop():
def rnsh_cli(): def rnsh_cli():
global _tr, _retry_timer global _tr, _retry_timer
with contextlib.suppress(Exception):
if not os.isatty(sys.stdin.fileno()):
tty.setraw(sys.stdin.fileno(), termios.TCSANOW)
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer: with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main()) return_code = asyncio.run(_rnsh_cli_main())

View File

@ -142,7 +142,7 @@ def _rns_log(msg, level=3, _override_destination=False):
termios.ONLRET | termios.ONLCR | termios.OPOST termios.ONLRET | termios.ONLCR | termios.OPOST
tr.set_attr(attr) tr.set_attr(attr)
_rns_log_orig(msg, level, _override_destination) _rns_log_orig(msg, level, _override_destination)
except ValueError: except:
_rns_log_orig(msg, level, _override_destination) _rns_log_orig(msg, level, _override_destination)
# TODO: figure out if forcing this to the main thread actually helps. # TODO: figure out if forcing this to the main thread actually helps.