Address several code style concerns

This commit is contained in:
Aaron Heise 2023-02-11 07:38:35 -06:00
parent 2789ef2624
commit 81459efcd6
10 changed files with 196 additions and 147 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ testconfig/
/rnsh.egg-info/ /rnsh.egg-info/
/build/ /build/
/dist/ /dist/
.pytest_cache/

View File

@ -12,7 +12,7 @@ out.
## Quickstart ## Quickstart
Requires Python 3.11+ on Linux or Unix. WSL probably works. Cygwin might work, too. Requires Python 3.10+ on Linux or Unix. WSL probably works. Cygwin might work, too.
- Activate a virtualenv - Activate a virtualenv
- `pip3 install rnsh` - `pip3 install rnsh`

View File

@ -16,6 +16,13 @@ psutil = "^5.9.4"
rnsh = 'rnsh.rnsh:rnsh_cli' rnsh = 'rnsh.rnsh:rnsh_cli'
[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
flake8 = "^6.0.0"
bandit = "^1.7.4"
isort = "^5.12.0"
safety = "^2.3.5"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

34
rnsh/exception.py Normal file
View File

@ -0,0 +1,34 @@
from contextlib import AbstractContextManager
class permit(AbstractContextManager):
"""Context manager to allow specified exceptions
The specified exceptions will be allowed to bubble up. Other
exceptions are suppressed.
After a non-matching exception is suppressed, execution proceeds
with the next statement following the with statement.
with allow(KeyboardInterrupt):
time.sleep(300)
# Execution still resumes here if no KeyboardInterrupt
"""
def __init__(self, *exceptions):
self._exceptions = exceptions
def __enter__(self):
pass
def __exit__(self, exctype, excinst, exctb):
# Unlike isinstance and issubclass, CPython exception handling
# currently only looks at the concrete type hierarchy (ignoring
# the instance and subclass checking hooks). While Guido considers
# that a bug rather than a feature, it's a fairly hard one to fix
# due to various internal implementation details. suppress provides
# the simpler issubclass based semantics, rather than trying to
# exactly reproduce the limitations of the CPython interpreter.
#
# See http://bugs.python.org/issue12029 for more details
return exctype is not None and issubclass(exctype, self._exceptions)

View File

@ -23,24 +23,27 @@
import asyncio import asyncio
import contextlib import contextlib
import errno import errno
import fcntl
import functools import functools
import re import logging as __logging
import os
import pty
import select
import signal import signal
import struct import struct
import sys
import termios
import threading import threading
import tty import tty
import pty
import os
import asyncio
import sys
import fcntl
import psutil import psutil
import select
import termios import rnsh.exception as exception
import logging as __logging
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop = None):
""" """
Add an async reader callback for a tty file descriptor. Add an async reader callback for a tty file descriptor.
@ -60,7 +63,8 @@ def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractE
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.add_reader(fd, callback) loop.add_reader(fd, callback)
def tty_read(fd: int) -> bytes | None:
def tty_read(fd: int) -> bytes:
""" """
Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using
tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys. tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys.
@ -78,7 +82,7 @@ def tty_read(fd: int) -> bytes | None:
break break
for f in ready: for f in ready:
try: try:
data = os.read(fd, 512) data = os.read(f, 512)
except OSError as e: except OSError as e:
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK: if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
raise raise
@ -89,6 +93,7 @@ def tty_read(fd: int) -> bytes | None:
result.extend(data) result.extend(data)
return result return result
def fd_is_closed(fd: int) -> bool: def fd_is_closed(fd: int) -> bool:
""" """
Check if file descriptor is closed Check if file descriptor is closed
@ -100,18 +105,18 @@ def fd_is_closed(fd: int) -> bool:
except OSError as ose: except OSError as ose:
return ose.errno == errno.EBADF return ose.errno == errno.EBADF
def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop | None = None):
def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop = None):
""" """
Remove async reader callbacks for file descriptor. Remove async reader callbacks for file descriptor.
:param fd: file descriptor :param fd: file descriptor
:param loop: asyncio event loop from which to remove callbacks :param loop: asyncio event loop from which to remove callbacks
""" """
try: with exception.permit(SystemExit):
if loop is None: if loop is None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.remove_reader(fd) loop.remove_reader(fd)
except:
pass
def tty_get_winsize(fd: int) -> [int, int, int, int]: def tty_get_winsize(fd: int) -> [int, int, int, int]:
""" """
@ -123,6 +128,7 @@ def tty_get_winsize(fd: int) -> [int, int, int , int]:
rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed) rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed)
return rows, cols, h_pixels, v_pixels return rows, cols, h_pixels, v_pixels
def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int): def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int):
""" """
Set the window size on a tty. Set the window size on a tty.
@ -137,6 +143,7 @@ def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int)
packed = struct.pack('HHHH', rows, cols, h_pixels, v_pixels) packed = struct.pack('HHHH', rows, cols, h_pixels, v_pixels)
fcntl.ioctl(fd, termios.TIOCSWINSZ, packed) fcntl.ioctl(fd, termios.TIOCSWINSZ, packed)
def process_exists(pid) -> bool: def process_exists(pid) -> bool:
""" """
Check For the existence of a unix pid. Check For the existence of a unix pid.
@ -150,6 +157,7 @@ def process_exists(pid) -> bool:
else: else:
return True return True
class TtyRestorer: class TtyRestorer:
def __init__(self, fd: int): def __init__(self, fd: int):
""" """
@ -189,12 +197,12 @@ class CallbackSubprocess:
# time between checks of child process # time between checks of child process
PROCESS_POLL_TIME: float = 0.1 PROCESS_POLL_TIME: float = 0.1
def __init__(self, argv: [str], env: dict | None, loop: asyncio.AbstractEventLoop, stdout_callback: callable, def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
terminated_callback: callable): terminated_callback: callable):
""" """
Fork a child process and generate callbacks with output from the process. Fork a child process and generate callbacks with output from the process.
:param argv: the command line, tokenized. The first element must be the absolute path to an executable file. :param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
:param term: the value that should be set for TERM. If None, the value from the parent process will be used :param env: environment variables to override
:param loop: the asyncio event loop to use :param loop: the asyncio event loop to use
:param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None :param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None
:param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None :param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None
@ -210,9 +218,9 @@ class CallbackSubprocess:
self._loop = loop self._loop = loop
self._stdout_cb = stdout_callback self._stdout_cb = stdout_callback
self._terminated_cb = terminated_callback self._terminated_cb = terminated_callback
self._pid: int | None = None self._pid: int = None
self._child_fd: int | None = None self._child_fd: int = None
self._return_code: int | None = None self._return_code: int = None
def terminate(self, kill_delay: float = 1.0): def terminate(self, kill_delay: float = 1.0):
""" """
@ -223,19 +231,15 @@ class CallbackSubprocess:
if not self.running: if not self.running:
return return
try: with exception.permit(SystemExit):
os.kill(self._pid, signal.SIGTERM) os.kill(self._pid, signal.SIGTERM)
except:
pass
def kill(): def kill():
if process_exists(self._pid): if process_exists(self._pid):
self._log.debug("kill()") self._log.debug("kill()")
try: with exception.permit(SystemExit):
os.kill(self._pid, signal.SIGHUP) os.kill(self._pid, signal.SIGHUP)
os.kill(self._pid, signal.SIGKILL) os.kill(self._pid, signal.SIGKILL)
except:
pass
self._loop.call_later(kill_delay, kill) self._loop.call_later(kill_delay, kill)
@ -288,7 +292,7 @@ class CallbackSubprocess:
r, c, h, v = tty_get_winsize(fromfd) r, c, h, v = tty_get_winsize(fromfd)
self.set_winsize(r, c, h, v) self.set_winsize(r, c, h, v)
def tcsetattr(self, when: int, attr: list[int | list[int | bytes]]): def tcsetattr(self, when: int, attr: list[any]): # actual type is list[int | list[int | bytes]]
""" """
Set tty attributes. Set tty attributes.
:param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH :param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
@ -296,7 +300,7 @@ class CallbackSubprocess:
""" """
termios.tcsetattr(self._child_fd, when, attr) termios.tcsetattr(self._child_fd, when, attr)
def tcgetattr(self) -> list[int | list[int | bytes]]: def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]]
""" """
Get tty attributes. Get tty attributes.
:return: tty attributes value :return: tty attributes value
@ -328,7 +332,6 @@ class CallbackSubprocess:
# env["SHELL"] = program # env["SHELL"] = program
# self._log.debug(f"set login shell {self._command}") # self._log.debug(f"set login shell {self._command}")
self._pid, self._child_fd = pty.fork() self._pid, self._child_fd = pty.fork()
if self._pid == 0: if self._pid == 0:
@ -341,10 +344,8 @@ class CallbackSubprocess:
for c in p.connections(kind='all'): for c in p.connections(kind='all'):
if c == sys.stdin.fileno() or c == sys.stdout.fileno() or c == sys.stderr.fileno(): if c == sys.stdin.fileno() or c == sys.stdout.fileno() or c == sys.stderr.fileno():
continue continue
try: with exception.permit(SystemExit):
os.close(c.fd) os.close(c.fd)
except:
pass
os.setpgrp() os.setpgrp()
os.execvpe(program, self._command, env) os.execvpe(program, self._command, env)
except Exception as err: except Exception as err:
@ -364,21 +365,19 @@ class CallbackSubprocess:
except Exception as e: except Exception as e:
if not hasattr(e, "errno") or e.errno != errno.ECHILD: if not hasattr(e, "errno") or e.errno != errno.ECHILD:
self._log.debug(f"Error in process poll: {e}") self._log.debug(f"Error in process poll: {e}")
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable): def reader(fd: int, callback: callable):
result = bytearray() with exception.permit(SystemExit):
try: data = tty_read(fd)
c = tty_read(fd) if data is not None and len(data) > 0:
if c is not None and len(c) > 0: callback(data)
callback(c)
except:
pass
tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop) tty_add_reader_callback(self._child_fd, functools.partial(reader, self._child_fd, self._stdout_cb), self._loop)
@property @property
def return_code(self) -> int | None: def return_code(self) -> int:
return self._return_code return self._return_code
@ -387,7 +386,6 @@ async def main():
A test driver for the CallbackProcess class. A test driver for the CallbackProcess class.
python ./process.py /bin/zsh --login python ./process.py /bin/zsh --login
""" """
import rnsh.testlogging
log = module_logger.getChild("main") log = module_logger.getChild("main")
if len(sys.argv) <= 1: if len(sys.argv) <= 1:
@ -413,14 +411,14 @@ async def main():
stdout_callback=stdout, stdout_callback=stdout,
terminated_callback=terminated) terminated_callback=terminated)
def sigint_handler(signal, frame): def sigint_handler(sig, frame):
# log.debug("KeyboardInterrupt") # log.debug("KeyboardInterrupt")
if process is None or process.started and not process.running: if process is None or process.started and not process.running:
raise KeyboardInterrupt raise KeyboardInterrupt
elif process.running: elif process.running:
process.write("\x03".encode("utf-8")) process.write("\x03".encode("utf-8"))
def sigwinch_handler(signal, frame): def sigwinch_handler(sig, frame):
# log.debug("WindowChanged") # log.debug("WindowChanged")
process.copy_winsize(sys.stdin.fileno()) process.copy_winsize(sys.stdin.fileno())
@ -443,6 +441,7 @@ async def main():
log.debug(f"got retcode {val}") log.debug(f"got retcode {val}")
return val return val
if __name__ == "__main__": if __name__ == "__main__":
tr = TtyRestorer(sys.stdin.fileno()) tr = TtyRestorer(sys.stdin.fileno())
try: try:

View File

@ -24,6 +24,7 @@ import asyncio
import logging import logging
import threading import threading
import time import time
import rnsh.exception as exception
import logging as __logging import logging as __logging
from typing import Callable from typing import Callable
@ -47,7 +48,9 @@ class RetryStatus:
@property @property
def ready(self): def ready(self):
ready = time.time() > self.try_time + self.wait_delay ready = time.time() > self.try_time + self.wait_delay
# self._log.debug(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} next_try {self.try_time + self.wait_delay} now {time.time()} exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}") self._log.debug(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} " +
f"next_try {self.try_time + self.wait_delay} now {time.time()} " +
f"exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}")
return ready return ready
@property @property
@ -72,11 +75,11 @@ class RetryThread:
self._tag_counter = 0 self._tag_counter = 0
self._lock = threading.RLock() self._lock = threading.RLock()
self._run = True self._run = True
self._finished: asyncio.Future | None = None self._finished: asyncio.Future = None
self._thread = threading.Thread(name=name, target=self._thread_run) self._thread = threading.Thread(name=name, target=self._thread_run)
self._thread.start() self._thread.start()
def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None: def close(self, loop: asyncio.AbstractEventLoop = None) -> asyncio.Future:
self._log.debug("stopping timer thread") self._log.debug("stopping timer thread")
if loop is None: if loop is None:
self._run = False self._run = False
@ -110,10 +113,8 @@ class RetryThread:
with self._lock: with self._lock:
for retry in prune: for retry in prune:
self._log.debug(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}") self._log.debug(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}")
try: with exception.permit(SystemExit):
self._statuses.remove(retry) self._statuses.remove(retry)
except:
pass
if self._finished is not None: if self._finished is not None:
self._finished.set_result(None) self._finished.set_result(None)
@ -126,7 +127,7 @@ class RetryThread:
return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None
def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any], def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any],
timeout_callback: Callable[[any, int], None], tag: int | None = None) -> any: timeout_callback: Callable[[any, int], None], tag: int = None) -> any:
self._log.debug(f"running first try") self._log.debug(f"running first try")
tag = try_callback(tag, 1) tag = try_callback(tag, 1)
self._log.debug(f"first try got id {tag}") self._log.debug(f"first try got id {tag}")

View File

@ -23,22 +23,25 @@
# SOFTWARE. # SOFTWARE.
from __future__ import annotations from __future__ import annotations
import functools
from typing import Callable, TypeVar
import termios
import rnsh.rnslogging as rnslogging
import RNS
import time
import sys
import os
import base64
import rnsh.process as process
import asyncio import asyncio
import threading import base64
import signal import functools
import rnsh.retry as retry
from rnsh.__version import __version__
import logging as __logging import logging as __logging
import os
import signal
import sys
import termios
import threading
import time
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.rnslogging as rnslogging
from rnsh.__version import __version__
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -60,11 +63,12 @@ _retry_timer = retry.RetryThread()
_destination: RNS.Destination | None = None _destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None _loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0): async def _check_finished(timeout: float = 0):
await process.event_wait(_finished, timeout=timeout) await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(signal, frame): def _sigint_handler(sig, frame):
global _finished global _finished
log = _get_logger("_sigint_handler") log = _get_logger("_sigint_handler")
log.debug("SIGINT") log.debug("SIGINT")
@ -91,7 +95,9 @@ def _prepare_identity(identity_path):
_identity = RNS.Identity() _identity = RNS.Identity()
_identity.to_file(identity_path) _identity.to_file(identity_path)
def _print_identity(configdir, identitypath, service_name, include_destination: bool): def _print_identity(configdir, identitypath, service_name, include_destination: bool):
global _reticulum
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO) _reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO)
_prepare_identity(identitypath) _prepare_identity(identitypath)
destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name) destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
@ -121,12 +127,13 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(a) != dest_len: if len(a) != dest_len:
raise ValueError( raise ValueError(
"Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format( "Allowed destination length is invalid, must be {hex} hexadecimal " +
"characters ({byte} bytes).".format(
hex=dest_len, byte=dest_len // 2)) hex=dest_len, byte=dest_len // 2))
try: try:
destination_hash = bytes.fromhex(a) destination_hash = bytes.fromhex(a)
_allowed_identity_hashes.append(destination_hash) _allowed_identity_hashes.append(destination_hash)
except Exception as e: except Exception:
raise ValueError("Invalid destination entered. Check your input.") raise ValueError("Invalid destination entered. Check your input.")
except Exception as e: except Exception as e:
log.error(str(e)) log.error(str(e))
@ -169,12 +176,10 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
except KeyboardInterrupt: except KeyboardInterrupt:
log.warning("Shutting down") log.warning("Shutting down")
for link in list(_destination.links): for link in list(_destination.links):
try: with exception.permit(SystemExit):
proc = ProcessState.get_for_tag(link.link_id) proc = ProcessState.get_for_tag(link.link_id)
if proc is not None and proc.process.running: if proc is not None and proc.process.running:
proc.process.terminate() proc.process.terminate()
except:
pass
await asyncio.sleep(1) await asyncio.sleep(1)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links)) links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active: for link in links_still_active:
@ -197,16 +202,11 @@ class ProcessState:
cls.clear_tag(tag) cls.clear_tag(tag)
cls._processes.append((tag, ps)) cls._processes.append((tag, ps))
@classmethod @classmethod
def clear_tag(cls, tag: any): def clear_tag(cls, tag: any):
with cls._lock: with cls._lock:
try: with exception.permit(SystemExit):
cls._processes.remove(tag) cls._processes.remove(tag)
except:
pass
def __init__(self, def __init__(self,
tag: any, tag: any,
@ -302,7 +302,6 @@ class ProcessState:
except Exception as e: except Exception as e:
self._log.debug(f"failed to update winsz: {e}") self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_STDIN = 0 REQUEST_IDX_STDIN = 0
REQUEST_IDX_TERM = 1 REQUEST_IDX_TERM = 1
REQUEST_IDX_TIOS = 2 REQUEST_IDX_TIOS = 2
@ -406,7 +405,8 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else: else:
if not timeout: if not timeout:
log.info( log.info(
f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})") f"Notifying client try {tries} (retcode: {process_state.return_code} " +
f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8")) packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send() packet.send()
pr = packet.receipt pr = packet.receipt
@ -431,6 +431,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else: else:
log.debug(f"Notification already pending for link {link}") log.debug(f"Notification already pending for link {link}")
def _subproc_terminated(link: RNS.Link, return_code: int): def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop global _loop
log = _get_logger("_subproc_terminated") log = _get_logger("_subproc_terminated")
@ -444,18 +445,20 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def inner(): def inner():
log.debug(f"cleanup culled link {link}") log.debug(f"cleanup culled link {link}")
if link and link.status != RNS.Link.CLOSED: if link and link.status != RNS.Link.CLOSED:
with exception.permit(SystemExit):
try: try:
link.teardown() link.teardown()
except:
pass
finally: finally:
ProcessState.clear_tag(link.link_id) ProcessState.clear_tag(link.link_id)
_loop.call_later(300, inner) _loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0) _loop.call_soon(_subproc_data_ready, link, 0)
_loop.call_soon_threadsafe(cleanup) _loop.call_soon_threadsafe(cleanup)
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, loop: asyncio.AbstractEventLoop) -> ProcessState | None: def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
loop: asyncio.AbstractEventLoop) -> ProcessState | None:
global _cmd global _cmd
log = _get_logger("_listen_start_proc") log = _get_logger("_listen_start_proc")
try: try:
@ -501,7 +504,7 @@ def _initiator_identified(link, identity):
global _allow_all, _cmd, _loop global _allow_all, _cmd, _loop
log = _get_logger("_initiator_identified") log = _get_logger("_initiator_identified")
log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash)) log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash))
if not _allow_all and not identity.hash in _allowed_identity_hashes: if not _allow_all and identity.hash not in _allowed_identity_hashes:
log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING) log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown() link.teardown()
@ -540,7 +543,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
return ProcessState.default_response() return ProcessState.default_response()
async def _spin(until: Callable | None = None, timeout: float | None = None) -> bool: async def _spin(until: callable = None, timeout: float | None = None) -> bool:
if timeout is not None: if timeout is not None:
timeout += time.time() timeout += time.time()
@ -644,7 +647,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
await _spin( await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or ( until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and request_receipt.status != RNS.RequestReceipt.SENT), request_receipt.status != RNS.RequestReceipt.FAILED and
request_receipt.status != RNS.RequestReceipt.SENT),
timeout=timeout timeout=timeout
) )
@ -758,10 +762,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if return_code is not None: if return_code is not None:
log.debug(f"received return code {return_code}, exiting") log.debug(f"received return code {return_code}, exiting")
try: with exception.permit(SystemExit):
_link.teardown() _link.teardown()
except:
pass
return return_code return return_code
except RemoteExecutionError as e: except RemoteExecutionError as e:
print(e.msg) print(e.msg)
@ -772,6 +775,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
_T = TypeVar("_T") _T = TypeVar("_T")
def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]): def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
try: try:
idx = arr.index(at) idx = arr.index(at)
@ -779,6 +783,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
except ValueError: except ValueError:
return arr, [] return arr, []
async def main(): async def main():
global _tr, _finished, _loop global _tr, _finished, _loop
import docopt import docopt
@ -879,9 +884,8 @@ Options:
timeout=args_timeout, timeout=args_timeout,
) )
return return_code if args_mirror else 0 return return_code if args_mirror else 0
except: finally:
_tr.restore() _tr.restore()
raise
else: else:
print("") print("")
print(args) print(args)
@ -889,17 +893,15 @@ Options:
def rnsh_cli(): def rnsh_cli():
return_code = 1
try: try:
return_code = asyncio.run(main()) return_code = asyncio.run(main())
finally: finally:
try: with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno()) process.tty_unset_reader_callbacks(sys.stdin.fileno())
except:
pass
_tr.restore() _tr.restore()
_retry_timer.close() _retry_timer.close()
sys.exit(return_code) sys.exit(return_code or 255)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -20,17 +20,18 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import asyncio
import logging import logging
import sys
import termios
from logging import Handler, getLevelName from logging import Handler, getLevelName
from types import GenericAlias from types import GenericAlias
import os from typing import Any
import tty
from typing import List, Any
import asyncio
import termios
import sys
import RNS import RNS
import json
import rnsh.exception as exception
class RnsHandler(Handler): class RnsHandler(Handler):
""" """
@ -77,6 +78,7 @@ class RnsHandler(Handler):
__class_getitem__ = classmethod(GenericAlias) __class_getitem__ = classmethod(GenericAlias)
log_format = '%(name)-30s %(message)s [%(threadName)s]' log_format = '%(name)-30s %(message)s [%(threadName)s]'
logging.basicConfig( logging.basicConfig(
@ -86,20 +88,25 @@ logging.basicConfig(
datefmt='%Y-%m-%d %H:%M:%S', datefmt='%Y-%m-%d %H:%M:%S',
handlers=[RnsHandler()]) handlers=[RnsHandler()])
_loop: asyncio.AbstractEventLoop | None = None _loop: asyncio.AbstractEventLoop = None
def set_main_loop(loop: asyncio.AbstractEventLoop): def set_main_loop(loop: asyncio.AbstractEventLoop):
global _loop global _loop
_loop = loop _loop = loop
# hack for temporarily overriding term settings to make debug print right # hack for temporarily overriding term settings to make debug print right
_rns_log_orig = RNS.log _rns_log_orig = RNS.log
def _rns_log(msg, level=3, _override_destination=False): def _rns_log(msg, level=3, _override_destination=False):
if not RNS.compact_log_fmt: if not RNS.compact_log_fmt:
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
def inner(): def inner():
tattr_orig: list[Any] | None = None tattr_orig: list[Any] = None
try: with exception.permit(SystemExit):
tattr = termios.tcgetattr(sys.stdin.fileno()) tattr = termios.tcgetattr(sys.stdin.fileno())
tattr_orig = tattr.copy() tattr_orig = tattr.copy()
# tcflag_t c_iflag; /* input modes */ # tcflag_t c_iflag; /* input modes */
@ -109,19 +116,16 @@ def _rns_log(msg, level=3, _override_destination = False):
# cc_t c_cc[NCCS]; /* special characters */ # cc_t c_cc[NCCS]; /* special characters */
tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST tattr[1] = tattr[1] | termios.ONLRET | termios.ONLCR | termios.OPOST
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr) termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr)
except:
pass
_rns_log_orig(msg, level, _override_destination) _rns_log_orig(msg, level, _override_destination)
if tattr_orig is not None: if tattr_orig is not None:
termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig) termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, tattr_orig)
try:
if _loop: if _loop:
_loop.call_soon_threadsafe(inner) _loop.call_soon_threadsafe(inner)
else: else:
inner() inner()
except:
inner()
RNS.log = _rns_log RNS.log = _rns_log

View File

@ -33,3 +33,4 @@ __logging.basicConfig(
format=log_format, format=log_format,
datefmt='%Y-%m-%d %H:%M:%S', datefmt='%Y-%m-%d %H:%M:%S',
handlers=[__logging.StreamHandler()]) handlers=[__logging.StreamHandler()])