Improvements to TTY attribute handling

This commit is contained in:
Aaron Heise 2023-02-11 11:50:13 -06:00
parent f80823344b
commit f86acc3f08
4 changed files with 58 additions and 29 deletions

View File

@ -35,7 +35,8 @@ import sys
import termios
import threading
import tty
import types
import typing
import psutil
import rnsh.exception as exception
@ -160,10 +161,26 @@ def process_exists(pid) -> bool:
return True
class TtyRestorer:
class TTYRestorer(contextlib.AbstractContextManager):
# Indexes of flags within the attrs array
ATTR_IDX_IFLAG = 0
ATTR_IDX_OFLAG = 1
ATTR_IDX_CFLAG = 2
ATTR_IDX_LFLAG = 4
ATTR_IDX_CC = 5
def __init__(self, fd: int):
"""
Saves termios attributes for a tty for later restoration
Saves termios attributes for a tty for later restoration.
The attributes are an array of values with the following meanings.
tcflag_t c_iflag; /* input modes */
tcflag_t c_oflag; /* output modes */
tcflag_t c_cflag; /* control modes */
tcflag_t c_lflag; /* local modes */
cc_t c_cc[NCCS]; /* special characters */
:param fd: file descriptor of tty
"""
self._fd = fd
@ -175,12 +192,31 @@ class TtyRestorer:
"""
tty.setraw(self._fd, termios.TCSADRAIN)
def current_attr(self) -> [any]:
"""
Get the current termios attributes for the wrapped fd.
:return: attribute array
"""
return termios.tcgetattr(self._fd).copy()
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
"""
Set termios attributes
:param attr: attribute list to set
"""
termios.tcsetattr(self._fd, when, attr)
def restore(self):
"""
Restore termios settings to state captured in constructor.
"""
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr)
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.restore()
return False
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
"""
@ -444,7 +480,7 @@ async def main():
if __name__ == "__main__":
tr = TtyRestorer(sys.stdin.fileno())
tr = TTYRestorer(sys.stdin.fileno())
try:
tr.raw()
asyncio.run(main())

View File

@ -558,7 +558,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
_link: RNS.Link | None = None
_remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None
_tr = process.TtyRestorer(sys.stdin.fileno())
_tr = process.TTYRestorer(sys.stdin.fileno())
def _client_packet_handler(message, packet):

View File

@ -24,6 +24,7 @@ import asyncio
import logging
import sys
import termios
import rnsh.process as process
from logging import Handler, getLevelName
from types import GenericAlias
from typing import Any
@ -104,31 +105,23 @@ def _rns_log(msg, level=3, _override_destination=False):
if not RNS.compact_log_fmt:
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
def inner():
tattr_orig: list[Any] = None
def _rns_log_inner():
nonlocal msg, level, _override_destination
with process.TTYRestorer(sys.stdin.fileno()) as tr:
with exception.permit(SystemExit):
tattr = termios.tcgetattr(sys.stdin.fileno())
tattr_orig = tattr.copy()
# tcflag_t c_iflag; /* input modes */
# tcflag_t c_oflag; /* output modes */
# tcflag_t c_cflag; /* control modes */
# tcflag_t c_lflag; /* local modes */
# cc_t c_cc[NCCS]; /* special characters */
tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr)
attr = tr.current_attr()
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
termios.ONLRET | termios.ONLCR | termios.OPOST
tr.set_attr(attr)
_rns_log_orig(msg, level, _override_destination)
if tattr_orig is not None:
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig)
try:
if _loop:
_loop.call_soon_threadsafe(inner)
_loop.call_soon_threadsafe(_rns_log_inner)
else:
inner()
_rns_log_inner()
except RuntimeError:
inner()
_rns_log_inner()
RNS.log = _rns_log

View File

@ -1,7 +1,5 @@
import uuid
import time
from types import TracebackType
from typing import Type
import pytest
import rnsh.process
import contextlib
@ -9,6 +7,8 @@ import asyncio
import logging
import os
import threading
import types
import typing
logging.getLogger().setLevel(logging.DEBUG)
@ -46,8 +46,8 @@ class State(contextlib.AbstractContextManager):
if self.process and self.process.running:
self.process.terminate(kill_delay=0.1)
def __exit__(self, __exc_type: Type[BaseException], __exc_value: BaseException,
__traceback: TracebackType) -> bool:
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.cleanup()
return False