Improvements to TTY attribute handling

This commit is contained in:
Aaron Heise 2023-02-11 11:50:13 -06:00
parent f80823344b
commit f86acc3f08
4 changed files with 58 additions and 29 deletions

View File

@ -35,7 +35,8 @@ import sys
import termios import termios
import threading import threading
import tty import tty
import types
import typing
import psutil import psutil
import rnsh.exception as exception import rnsh.exception as exception
@ -160,10 +161,26 @@ def process_exists(pid) -> bool:
return True return True
class TtyRestorer: class TTYRestorer(contextlib.AbstractContextManager):
# Indexes of flags within the attrs array
ATTR_IDX_IFLAG = 0
ATTR_IDX_OFLAG = 1
ATTR_IDX_CFLAG = 2
ATTR_IDX_LFLAG = 4
ATTR_IDX_CC = 5
def __init__(self, fd: int): def __init__(self, fd: int):
""" """
Saves termios attributes for a tty for later restoration Saves termios attributes for a tty for later restoration.
The attributes are an array of values with the following meanings.
tcflag_t c_iflag; /* input modes */
tcflag_t c_oflag; /* output modes */
tcflag_t c_cflag; /* control modes */
tcflag_t c_lflag; /* local modes */
cc_t c_cc[NCCS]; /* special characters */
:param fd: file descriptor of tty :param fd: file descriptor of tty
""" """
self._fd = fd self._fd = fd
@ -175,12 +192,31 @@ class TtyRestorer:
""" """
tty.setraw(self._fd, termios.TCSADRAIN) tty.setraw(self._fd, termios.TCSADRAIN)
def current_attr(self) -> [any]:
"""
Get the current termios attributes for the wrapped fd.
:return: attribute array
"""
return termios.tcgetattr(self._fd).copy()
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
"""
Set termios attributes
:param attr: attribute list to set
"""
termios.tcsetattr(self._fd, when, attr)
def restore(self): def restore(self):
""" """
Restore termios settings to state captured in constructor. Restore termios settings to state captured in constructor.
""" """
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr) termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr)
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.restore()
return False
async def event_wait(evt: asyncio.Event, timeout: float) -> bool: async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
""" """
@ -444,7 +480,7 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
tr = TtyRestorer(sys.stdin.fileno()) tr = TTYRestorer(sys.stdin.fileno())
try: try:
tr.raw() tr.raw()
asyncio.run(main()) asyncio.run(main())

View File

@ -558,7 +558,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
_link: RNS.Link | None = None _link: RNS.Link | None = None
_remote_exec_grace = 2.0 _remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None _new_data: asyncio.Event | None = None
_tr = process.TtyRestorer(sys.stdin.fileno()) _tr = process.TTYRestorer(sys.stdin.fileno())
def _client_packet_handler(message, packet): def _client_packet_handler(message, packet):

View File

@ -24,6 +24,7 @@ import asyncio
import logging import logging
import sys import sys
import termios import termios
import rnsh.process as process
from logging import Handler, getLevelName from logging import Handler, getLevelName
from types import GenericAlias from types import GenericAlias
from typing import Any from typing import Any
@ -104,31 +105,23 @@ def _rns_log(msg, level=3, _override_destination=False):
if not RNS.compact_log_fmt: if not RNS.compact_log_fmt:
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
def inner(): def _rns_log_inner():
tattr_orig: list[Any] = None nonlocal msg, level, _override_destination
with exception.permit(SystemExit): with process.TTYRestorer(sys.stdin.fileno()) as tr:
tattr = termios.tcgetattr(sys.stdin.fileno()) with exception.permit(SystemExit):
tattr_orig = tattr.copy() attr = tr.current_attr()
# tcflag_t c_iflag; /* input modes */ attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
# tcflag_t c_oflag; /* output modes */ termios.ONLRET | termios.ONLCR | termios.OPOST
# tcflag_t c_cflag; /* control modes */ tr.set_attr(attr)
# tcflag_t c_lflag; /* local modes */ _rns_log_orig(msg, level, _override_destination)
# cc_t c_cc[NCCS]; /* special characters */
tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr)
_rns_log_orig(msg, level, _override_destination)
if tattr_orig is not None:
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig)
try: try:
if _loop: if _loop:
_loop.call_soon_threadsafe(inner) _loop.call_soon_threadsafe(_rns_log_inner)
else: else:
inner() _rns_log_inner()
except RuntimeError: except RuntimeError:
inner() _rns_log_inner()
RNS.log = _rns_log RNS.log = _rns_log

View File

@ -1,7 +1,5 @@
import uuid import uuid
import time import time
from types import TracebackType
from typing import Type
import pytest import pytest
import rnsh.process import rnsh.process
import contextlib import contextlib
@ -9,6 +7,8 @@ import asyncio
import logging import logging
import os import os
import threading import threading
import types
import typing
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
@ -46,8 +46,8 @@ class State(contextlib.AbstractContextManager):
if self.process and self.process.running: if self.process and self.process.running:
self.process.terminate(kill_delay=0.1) self.process.terminate(kill_delay=0.1)
def __exit__(self, __exc_type: Type[BaseException], __exc_value: BaseException, def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: TracebackType) -> bool: __traceback: types.TracebackType) -> bool:
self.cleanup() self.cleanup()
return False return False