mirror of
https://github.com/markqvist/rnsh.git
synced 2025-01-07 05:07:57 -05:00
Improvements to TTY attribute handling
This commit is contained in:
parent
f80823344b
commit
f86acc3f08
@ -35,7 +35,8 @@ import sys
|
|||||||
import termios
|
import termios
|
||||||
import threading
|
import threading
|
||||||
import tty
|
import tty
|
||||||
|
import types
|
||||||
|
import typing
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
import rnsh.exception as exception
|
import rnsh.exception as exception
|
||||||
@ -160,10 +161,26 @@ def process_exists(pid) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class TtyRestorer:
|
class TTYRestorer(contextlib.AbstractContextManager):
|
||||||
|
# Indexes of flags within the attrs array
|
||||||
|
ATTR_IDX_IFLAG = 0
|
||||||
|
ATTR_IDX_OFLAG = 1
|
||||||
|
ATTR_IDX_CFLAG = 2
|
||||||
|
ATTR_IDX_LFLAG = 4
|
||||||
|
ATTR_IDX_CC = 5
|
||||||
|
|
||||||
def __init__(self, fd: int):
|
def __init__(self, fd: int):
|
||||||
"""
|
"""
|
||||||
Saves termios attributes for a tty for later restoration
|
Saves termios attributes for a tty for later restoration.
|
||||||
|
|
||||||
|
The attributes are an array of values with the following meanings.
|
||||||
|
|
||||||
|
tcflag_t c_iflag; /* input modes */
|
||||||
|
tcflag_t c_oflag; /* output modes */
|
||||||
|
tcflag_t c_cflag; /* control modes */
|
||||||
|
tcflag_t c_lflag; /* local modes */
|
||||||
|
cc_t c_cc[NCCS]; /* special characters */
|
||||||
|
|
||||||
:param fd: file descriptor of tty
|
:param fd: file descriptor of tty
|
||||||
"""
|
"""
|
||||||
self._fd = fd
|
self._fd = fd
|
||||||
@ -175,12 +192,31 @@ class TtyRestorer:
|
|||||||
"""
|
"""
|
||||||
tty.setraw(self._fd, termios.TCSADRAIN)
|
tty.setraw(self._fd, termios.TCSADRAIN)
|
||||||
|
|
||||||
|
def current_attr(self) -> [any]:
|
||||||
|
"""
|
||||||
|
Get the current termios attributes for the wrapped fd.
|
||||||
|
:return: attribute array
|
||||||
|
"""
|
||||||
|
return termios.tcgetattr(self._fd).copy()
|
||||||
|
|
||||||
|
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
|
||||||
|
"""
|
||||||
|
Set termios attributes
|
||||||
|
:param attr: attribute list to set
|
||||||
|
"""
|
||||||
|
termios.tcsetattr(self._fd, when, attr)
|
||||||
|
|
||||||
def restore(self):
|
def restore(self):
|
||||||
"""
|
"""
|
||||||
Restore termios settings to state captured in constructor.
|
Restore termios settings to state captured in constructor.
|
||||||
"""
|
"""
|
||||||
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr)
|
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._tattr)
|
||||||
|
|
||||||
|
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||||
|
__traceback: types.TracebackType) -> bool:
|
||||||
|
self.restore()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -444,7 +480,7 @@ async def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tr = TtyRestorer(sys.stdin.fileno())
|
tr = TTYRestorer(sys.stdin.fileno())
|
||||||
try:
|
try:
|
||||||
tr.raw()
|
tr.raw()
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
@ -558,7 +558,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
|||||||
_link: RNS.Link | None = None
|
_link: RNS.Link | None = None
|
||||||
_remote_exec_grace = 2.0
|
_remote_exec_grace = 2.0
|
||||||
_new_data: asyncio.Event | None = None
|
_new_data: asyncio.Event | None = None
|
||||||
_tr = process.TtyRestorer(sys.stdin.fileno())
|
_tr = process.TTYRestorer(sys.stdin.fileno())
|
||||||
|
|
||||||
|
|
||||||
def _client_packet_handler(message, packet):
|
def _client_packet_handler(message, packet):
|
||||||
|
@ -24,6 +24,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import termios
|
import termios
|
||||||
|
import rnsh.process as process
|
||||||
from logging import Handler, getLevelName
|
from logging import Handler, getLevelName
|
||||||
from types import GenericAlias
|
from types import GenericAlias
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -104,31 +105,23 @@ def _rns_log(msg, level=3, _override_destination=False):
|
|||||||
if not RNS.compact_log_fmt:
|
if not RNS.compact_log_fmt:
|
||||||
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
|
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
|
||||||
|
|
||||||
def inner():
|
def _rns_log_inner():
|
||||||
tattr_orig: list[Any] = None
|
nonlocal msg, level, _override_destination
|
||||||
with exception.permit(SystemExit):
|
with process.TTYRestorer(sys.stdin.fileno()) as tr:
|
||||||
tattr = termios.tcgetattr(sys.stdin.fileno())
|
with exception.permit(SystemExit):
|
||||||
tattr_orig = tattr.copy()
|
attr = tr.current_attr()
|
||||||
# tcflag_t c_iflag; /* input modes */
|
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
|
||||||
# tcflag_t c_oflag; /* output modes */
|
termios.ONLRET | termios.ONLCR | termios.OPOST
|
||||||
# tcflag_t c_cflag; /* control modes */
|
tr.set_attr(attr)
|
||||||
# tcflag_t c_lflag; /* local modes */
|
_rns_log_orig(msg, level, _override_destination)
|
||||||
# cc_t c_cc[NCCS]; /* special characters */
|
|
||||||
tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST
|
|
||||||
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr)
|
|
||||||
|
|
||||||
_rns_log_orig(msg, level, _override_destination)
|
|
||||||
|
|
||||||
if tattr_orig is not None:
|
|
||||||
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if _loop:
|
if _loop:
|
||||||
_loop.call_soon_threadsafe(inner)
|
_loop.call_soon_threadsafe(_rns_log_inner)
|
||||||
else:
|
else:
|
||||||
inner()
|
_rns_log_inner()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
inner()
|
_rns_log_inner()
|
||||||
|
|
||||||
|
|
||||||
RNS.log = _rns_log
|
RNS.log = _rns_log
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
from types import TracebackType
|
|
||||||
from typing import Type
|
|
||||||
import pytest
|
import pytest
|
||||||
import rnsh.process
|
import rnsh.process
|
||||||
import contextlib
|
import contextlib
|
||||||
@ -9,6 +7,8 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import types
|
||||||
|
import typing
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@ -46,8 +46,8 @@ class State(contextlib.AbstractContextManager):
|
|||||||
if self.process and self.process.running:
|
if self.process and self.process.running:
|
||||||
self.process.terminate(kill_delay=0.1)
|
self.process.terminate(kill_delay=0.1)
|
||||||
|
|
||||||
def __exit__(self, __exc_type: Type[BaseException], __exc_value: BaseException,
|
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||||
__traceback: TracebackType) -> bool:
|
__traceback: types.TracebackType) -> bool:
|
||||||
self.cleanup()
|
self.cleanup()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user