Address several code style concerns

This commit is contained in:
Aaron Heise 2023-02-11 07:38:35 -06:00
parent 2789ef2624
commit 81459efcd6
10 changed files with 196 additions and 147 deletions

View file

@ -23,22 +23,25 @@
# SOFTWARE.
from __future__ import annotations
import functools
from typing import Callable, TypeVar
import termios
import rnsh.rnslogging as rnslogging
import RNS
import time
import sys
import os
import base64
import rnsh.process as process
import asyncio
import threading
import signal
import rnsh.retry as retry
from rnsh.__version import __version__
import base64
import functools
import logging as __logging
import os
import signal
import sys
import termios
import threading
import time
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.rnslogging as rnslogging
from rnsh.__version import __version__
module_logger = __logging.getLogger(__name__)
@ -60,11 +63,12 @@ _retry_timer = retry.RetryThread()
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0):
await process.event_wait(_finished, timeout=timeout)
await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(signal, frame):
def _sigint_handler(sig, frame):
global _finished
log = _get_logger("_sigint_handler")
log.debug("SIGINT")
@ -91,7 +95,9 @@ def _prepare_identity(identity_path):
_identity = RNS.Identity()
_identity.to_file(identity_path)
def _print_identity(configdir, identitypath, service_name, include_destination: bool):
global _reticulum
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO)
_prepare_identity(identitypath)
destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
@ -121,12 +127,13 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(a) != dest_len:
raise ValueError(
"Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(
"Allowed destination length is invalid, must be {hex} hexadecimal " +
"characters ({byte} bytes).".format(
hex=dest_len, byte=dest_len // 2))
try:
destination_hash = bytes.fromhex(a)
_allowed_identity_hashes.append(destination_hash)
except Exception as e:
except Exception:
raise ValueError("Invalid destination entered. Check your input.")
except Exception as e:
log.error(str(e))
@ -169,12 +176,10 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
except KeyboardInterrupt:
log.warning("Shutting down")
for link in list(_destination.links):
try:
with exception.permit(SystemExit):
proc = ProcessState.get_for_tag(link.link_id)
if proc is not None and proc.process.running:
proc.process.terminate()
except:
pass
await asyncio.sleep(1)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
@ -197,16 +202,11 @@ class ProcessState:
cls.clear_tag(tag)
cls._processes.append((tag, ps))
@classmethod
def clear_tag(cls, tag: any):
with cls._lock:
try:
with exception.permit(SystemExit):
cls._processes.remove(tag)
except:
pass
def __init__(self,
tag: any,
@ -222,8 +222,8 @@ class ProcessState:
self._mdu = mdu
self._loop = loop if loop is not None else asyncio.get_running_loop()
self._process = process.CallbackSubprocess(argv=cmd,
env={ "TERM": term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": remote_identity or ""},
env={"TERM": term or os.environ.get("TERM", None),
"RNS_REMOTE_IDENTITY": remote_identity or ""},
loop=loop,
stdout_callback=self._stdout_data,
terminated_callback=terminated_callback)
@ -302,7 +302,6 @@ class ProcessState:
except Exception as e:
self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_STDIN = 0
REQUEST_IDX_TERM = 1
REQUEST_IDX_TIOS = 2
@ -327,9 +326,9 @@ class ProcessState:
request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[ProcessState.REQUEST_IDX_ROWS], \
request[ProcessState.REQUEST_IDX_COLS], \
request[ProcessState.REQUEST_IDX_HPIX], \
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
request[ProcessState.REQUEST_IDX_COLS], \
request[ProcessState.REQUEST_IDX_HPIX], \
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
return request
def process_request(self, data: [any], read_size: int) -> [any]:
@ -342,7 +341,7 @@ class ProcessState:
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
response = ProcessState.default_response()
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX+1]
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1]
response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running
if self.process.running:
@ -365,17 +364,17 @@ class ProcessState:
RESPONSE_IDX_RUNNING = 0
RESPONSE_IDX_RETCODE = 1
RESPONSE_IDX_RDYBYTE = 2
RESPONSE_IDX_STDOUT = 3
RESPONSE_IDX_STDOUT = 3
RESPONSE_IDX_TMSTAMP = 4
@staticmethod
def default_response() -> [any]:
response: list[any] = [
False, # 0: Process running
None, # 1: Return value
0, # 2: Number of outstanding bytes
None, # 3: Stdout/Stderr
None, # 4: Timestamp
False, # 0: Process running
None, # 1: Return value
0, # 2: Number of outstanding bytes
None, # 3: Stdout/Stderr
None, # 4: Timestamp
].copy()
response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time()
return response
@ -406,7 +405,8 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else:
if not timeout:
log.info(
f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})")
f"Notifying client try {tries} (retcode: {process_state.return_code} " +
f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
@ -431,6 +431,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else:
log.debug(f"Notification already pending for link {link}")
def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop
log = _get_logger("_subproc_terminated")
@ -444,18 +445,20 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def inner():
log.debug(f"cleanup culled link {link}")
if link and link.status != RNS.Link.CLOSED:
try:
link.teardown()
except:
pass
finally:
ProcessState.clear_tag(link.link_id)
with exception.permit(SystemExit):
try:
link.teardown()
finally:
ProcessState.clear_tag(link.link_id)
_loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0)
_loop.call_soon_threadsafe(cleanup)
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, loop: asyncio.AbstractEventLoop) -> ProcessState | None:
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
loop: asyncio.AbstractEventLoop) -> ProcessState | None:
global _cmd
log = _get_logger("_listen_start_proc")
try:
@ -501,7 +504,7 @@ def _initiator_identified(link, identity):
global _allow_all, _cmd, _loop
log = _get_logger("_initiator_identified")
log.info("Initiator of link " + str(link) + " identified as " + RNS.prettyhexrep(identity.hash))
if not _allow_all and not identity.hash in _allowed_identity_hashes:
if not _allow_all and identity.hash not in _allowed_identity_hashes:
log.warning("Identity " + RNS.prettyhexrep(identity.hash) + " not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
@ -540,7 +543,7 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
return ProcessState.default_response()
async def _spin(until: Callable | None = None, timeout: float | None = None) -> bool:
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
if timeout is not None:
timeout += time.time()
@ -644,7 +647,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
await _spin(
until=lambda: _link.status == RNS.Link.CLOSED or (
request_receipt.status != RNS.RequestReceipt.FAILED and request_receipt.status != RNS.RequestReceipt.SENT),
request_receipt.status != RNS.RequestReceipt.FAILED and
request_receipt.status != RNS.RequestReceipt.SENT),
timeout=timeout
)
@ -667,11 +671,11 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if request_receipt.response is not None:
try:
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT]
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP]
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT]
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
@ -758,10 +762,9 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
if return_code is not None:
log.debug(f"received return code {return_code}, exiting")
try:
with exception.permit(SystemExit):
_link.teardown()
except:
pass
return return_code
except RemoteExecutionError as e:
print(e.msg)
@ -772,13 +775,15 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
_T = TypeVar("_T")
def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
try:
idx = arr.index(at)
return arr[:idx], arr[idx+1:]
return arr[:idx], arr[idx + 1:]
except ValueError:
return arr, []
async def main():
global _tr, _finished, _loop
import docopt
@ -879,9 +884,8 @@ Options:
timeout=args_timeout,
)
return return_code if args_mirror else 0
except:
finally:
_tr.restore()
raise
else:
print("")
print(args)
@ -889,17 +893,15 @@ Options:
def rnsh_cli():
return_code = 1
try:
return_code = asyncio.run(main())
finally:
try:
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
except:
pass
_tr.restore()
_retry_timer.close()
sys.exit(return_code)
sys.exit(return_code or 255)
if __name__ == "__main__":