diff --git a/rnsh/process.py b/rnsh/process.py index 1544f45..630a620 100644 --- a/rnsh/process.py +++ b/rnsh/process.py @@ -35,7 +35,8 @@ import sys import termios import threading import tty - +import types +import typing import psutil import rnsh.exception as exception @@ -160,10 +161,26 @@ def process_exists(pid) -> bool: return True -class TtyRestorer: +class TTYRestorer(contextlib.AbstractContextManager): + # Indexes of flags within the attrs array + ATTR_IDX_IFLAG = 0 + ATTR_IDX_OFLAG = 1 + ATTR_IDX_CFLAG = 2 + ATTR_IDX_LFLAG = 4 + ATTR_IDX_CC = 5 + def __init__(self, fd: int): """ - Saves termios attributes for a tty for later restoration + Saves termios attributes for a tty for later restoration. + + The attributes are an array of values with the following meanings. + + tcflag_t c_iflag; /* input modes */ + tcflag_t c_oflag; /* output modes */ + tcflag_t c_cflag; /* control modes */ + tcflag_t c_lflag; /* local modes */ + cc_t c_cc[NCCS]; /* special characters */ + :param fd: file descriptor of tty """ self._fd = fd @@ -175,12 +192,31 @@ class TtyRestorer: """ tty.setraw(self._fd, termios.TCSADRAIN) + def current_attr(self) -> [any]: + """ + Get the current termios attributes for the wrapped fd. + :return: attribute array + """ + return termios.tcgetattr(self._fd).copy() + + def set_attr(self, attr: [any], when: int = termios.TCSANOW): + """ + Set termios attributes + :param attr: attribute list to set + """ + termios.tcsetattr(self._fd, when, attr) + def restore(self): """ Restore termios settings to state captured in constructor. """ termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr) + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + self.restore() + return False + async def event_wait(evt: asyncio.Event, timeout: float) -> bool: """ @@ -444,7 +480,7 @@ async def main(): if __name__ == "__main__": - tr = TtyRestorer(sys.stdin.fileno()) + tr = TTYRestorer(sys.stdin.fileno()) try: tr.raw() asyncio.run(main()) diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index f5f474b..f53262e 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -558,7 +558,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool: _link: RNS.Link | None = None _remote_exec_grace = 2.0 _new_data: asyncio.Event | None = None -_tr = process.TtyRestorer(sys.stdin.fileno()) +_tr = process.TTYRestorer(sys.stdin.fileno()) def _client_packet_handler(message, packet): diff --git a/rnsh/rnslogging.py b/rnsh/rnslogging.py index c3e4dde..f1ee3bd 100644 --- a/rnsh/rnslogging.py +++ b/rnsh/rnslogging.py @@ -24,6 +24,7 @@ import asyncio import logging import sys import termios +import rnsh.process as process from logging import Handler, getLevelName from types import GenericAlias from typing import Any @@ -104,31 +105,23 @@ def _rns_log(msg, level=3, _override_destination=False): if not RNS.compact_log_fmt: msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg - def inner(): - tattr_orig: list[Any] = None - with exception.permit(SystemExit): - tattr = termios.tcgetattr(sys.stdin.fileno()) - tattr_orig = tattr.copy() - # tcflag_t c_iflag; /* input modes */ - # tcflag_t c_oflag; /* output modes */ - # tcflag_t c_cflag; /* control modes */ - # tcflag_t c_lflag; /* local modes */ - # cc_t c_cc[NCCS]; /* special characters */ - tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST - termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) - - _rns_log_orig(msg, level, _override_destination) - - if tattr_orig is not None: - termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig) + def _rns_log_inner(): + nonlocal msg, level, _override_destination + with process.TTYRestorer(sys.stdin.fileno()) as tr: + with exception.permit(SystemExit): + attr = tr.current_attr() + attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \ + termios.ONLRET | termios.ONLCR | termios.OPOST + tr.set_attr(attr) + _rns_log_orig(msg, level, _override_destination) try: if _loop: - _loop.call_soon_threadsafe(inner) + _loop.call_soon_threadsafe(_rns_log_inner) else: - inner() + _rns_log_inner() except RuntimeError: - inner() + _rns_log_inner() RNS.log = _rns_log diff --git a/tests/test_process.py b/tests/test_process.py index b8ec824..5c64d93 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,7 +1,5 @@ import uuid import time -from types import TracebackType -from typing import Type import pytest import rnsh.process import contextlib @@ -9,6 +7,8 @@ import asyncio import logging import os import threading +import types +import typing logging.getLogger().setLevel(logging.DEBUG) @@ -46,8 +46,8 @@ class State(contextlib.AbstractContextManager): if self.process and self.process.running: self.process.terminate(kill_delay=0.1) - def __exit__(self, __exc_type: Type[BaseException], __exc_value: BaseException, - __traceback: TracebackType) -> bool: + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: self.cleanup() return False