mirror of
https://github.com/markqvist/rnsh.git
synced 2024-10-01 01:15:37 -04:00
Locking in a major increment for integrating the CallbackSubprocess module to rnsh
This commit is contained in:
parent
9186a64962
commit
fcc73ba31a
@ -9,13 +9,10 @@ import pty
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import logging as __logging
|
||||
|
||||
import fcntl
|
||||
import select
|
||||
import termios
|
||||
|
||||
import logging as __logging
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
|
||||
@ -165,8 +162,8 @@ class CallbackSubprocess:
|
||||
assert stdout_callback is not None, "stdout_callback should not be None"
|
||||
assert terminated_callback is not None, "terminated_callback should not be None"
|
||||
|
||||
self.log = module_logger.getChild(self.__class__.__name__)
|
||||
self.log.debug(f"__init__({argv},{term},...")
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._log.debug(f"__init__({argv},{term},...")
|
||||
self._command = argv
|
||||
self._term = term
|
||||
self._loop = loop
|
||||
@ -181,7 +178,7 @@ class CallbackSubprocess:
|
||||
Terminate child process if running
|
||||
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
||||
"""
|
||||
self.log.debug("terminate()")
|
||||
self._log.debug("terminate()")
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
@ -191,7 +188,7 @@ class CallbackSubprocess:
|
||||
pass
|
||||
|
||||
def kill():
|
||||
self.log.debug("kill()")
|
||||
self._log.debug("kill()")
|
||||
try:
|
||||
os.kill(self._pid, signal.SIGHUP)
|
||||
os.kill(self._pid, signal.SIGKILL)
|
||||
@ -201,9 +198,9 @@ class CallbackSubprocess:
|
||||
self._loop.call_later(kill_delay, kill)
|
||||
|
||||
def wait():
|
||||
self.log.debug("wait()")
|
||||
self._log.debug("wait()")
|
||||
os.waitpid(self._pid, 0)
|
||||
self.log.debug("wait() finish")
|
||||
self._log.debug("wait() finish")
|
||||
|
||||
threading.Thread(target=wait).start()
|
||||
|
||||
@ -226,7 +223,7 @@ class CallbackSubprocess:
|
||||
Write bytes to the stdin of the child process.
|
||||
:param data: bytes to write
|
||||
"""
|
||||
self.log.debug(f"write({data})")
|
||||
self._log.debug(f"write({data})")
|
||||
os.write(self._child_fd, data)
|
||||
|
||||
def set_winsize(self, r: int, c: int, h: int, v: int):
|
||||
@ -238,7 +235,7 @@ class CallbackSubprocess:
|
||||
:param v: vertical pixels visible
|
||||
:return:
|
||||
"""
|
||||
self.log.debug(f"set_winsize({r},{c},{h},{v}")
|
||||
self._log.debug(f"set_winsize({r},{c},{h},{v}")
|
||||
tty_set_winsize(self._child_fd, r, c, h, v)
|
||||
|
||||
def copy_winsize(self, fromfd:int):
|
||||
@ -268,7 +265,7 @@ class CallbackSubprocess:
|
||||
"""
|
||||
Start the child process.
|
||||
"""
|
||||
self.log.debug("start()")
|
||||
self._log.debug("start()")
|
||||
parentenv = os.environ.copy()
|
||||
env = {"HOME": parentenv["HOME"],
|
||||
"TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
|
||||
@ -290,11 +287,11 @@ class CallbackSubprocess:
|
||||
try:
|
||||
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
||||
if self._return_code is not None and not process_exists(self._pid):
|
||||
self.log.debug(f"polled return code {self._return_code}")
|
||||
self._log.debug(f"polled return code {self._return_code}")
|
||||
self._terminated_cb(self._return_code)
|
||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||
except Exception as e:
|
||||
self.log.debug(f"Error in process poll: {e}")
|
||||
self._log.debug(f"Error in process poll: {e}")
|
||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||
|
||||
def reader(fd: int, callback: callable):
|
||||
|
116
rnsh/retry.py
Normal file
116
rnsh/retry.py
Normal file
@ -0,0 +1,116 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import logging as __logging
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
class RetryStatus:
|
||||
def __init__(self, id: any, try_limit: int, wait_delay: float, retry_callback: callable[any, int], timeout_callback: callable[any], tries: int = 1):
|
||||
self.id = id
|
||||
self.try_limit = try_limit
|
||||
self.tries = tries
|
||||
self.wait_delay = wait_delay
|
||||
self.retry_callback = retry_callback
|
||||
self.timeout_callback = timeout_callback
|
||||
self.try_time = time.monotonic()
|
||||
self.completed = False
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
return self.try_time + self.wait_delay < time.monotonic() and not self.completed
|
||||
|
||||
@property
|
||||
def timed_out(self):
|
||||
return self.ready and self.tries >= self.try_limit
|
||||
|
||||
def timeout(self):
|
||||
self.completed = True
|
||||
self.timeout_callback(self.id)
|
||||
|
||||
def retry(self):
|
||||
self.tries += 1
|
||||
self.retry_callback(self.id, self.tries)
|
||||
|
||||
class RetryThread:
|
||||
def __init__(self, loop_period: float = 0.25):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._loop_period = loop_period
|
||||
self._statuses: list[RetryStatus] = []
|
||||
self._id_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._thread = threading.Thread(target=self._thread_run())
|
||||
self._run = True
|
||||
self._finished: asyncio.Future | None = None
|
||||
self._thread.start()
|
||||
|
||||
def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None:
|
||||
self._log.debug("stopping timer thread")
|
||||
if loop is None:
|
||||
self._run = False
|
||||
self._thread.join()
|
||||
return None
|
||||
else:
|
||||
self._finished = loop.create_future()
|
||||
return self._finished
|
||||
def _thread_run(self):
|
||||
last_run = time.monotonic()
|
||||
while self._run and self._finished is None:
|
||||
time.sleep(last_run + self._loop_period - time.monotonic())
|
||||
last_run = time.monotonic()
|
||||
ready: list[RetryStatus] = []
|
||||
prune: list[RetryStatus] = []
|
||||
with self._lock:
|
||||
ready.extend(filter(lambda s: s.ready, self._statuses))
|
||||
for retry in ready:
|
||||
try:
|
||||
if not retry.completed:
|
||||
if retry.timed_out:
|
||||
self._log.debug(f"timed out {retry.id} after {retry.try_limit} tries")
|
||||
retry.timeout()
|
||||
prune.append(retry)
|
||||
else:
|
||||
self._log.debug(f"retrying {retry.id}, try {retry.tries + 1}/{retry.try_limit}")
|
||||
retry.retry()
|
||||
except Exception as e:
|
||||
self._log.error(f"error processing retry id {retry.id}: {e}")
|
||||
prune.append(retry)
|
||||
|
||||
with self._lock:
|
||||
for retry in prune:
|
||||
self._log.debug(f"pruned retry {retry.id}, retry count {retry.tries}/{retry.try_limit}")
|
||||
self._statuses.remove(retry)
|
||||
if self._finished is not None:
|
||||
self._finished.set_result(None)
|
||||
|
||||
def _get_id(self):
|
||||
self._id_counter += 1
|
||||
return self._id_counter
|
||||
|
||||
def begin(self, try_limit: int, wait_delay: float, try_callback: callable[[any | None, int], any], timeout_callback: callable[any, int], id: int | None = None) -> any:
|
||||
self._log.debug(f"running first try")
|
||||
id = try_callback(id, 1)
|
||||
self._log.debug(f"first try success, got id {id}")
|
||||
with self._lock:
|
||||
if id is None:
|
||||
id = self._get_id()
|
||||
self._statuses.append(RetryStatus(id=id,
|
||||
tries=1,
|
||||
try_limit=try_limit,
|
||||
wait_delay=wait_delay,
|
||||
retry_callback=try_callback,
|
||||
timeout_callback=timeout_callback))
|
||||
self._log.debug(f"added retry timer for {id}")
|
||||
def complete(self, id: any):
|
||||
assert id is not None
|
||||
with self._lock:
|
||||
status = next(filter(lambda l: l.id == id, self._statuses))
|
||||
assert status is not None
|
||||
status.completed = True
|
||||
self._statuses.remove(status)
|
||||
self._log.debug(f"completed {id}")
|
||||
def complete_all(self):
|
||||
with self._lock:
|
||||
for status in self._statuses:
|
||||
status.completed = True
|
||||
self._log.debug(f"completed {status.id}")
|
||||
self._statuses.clear()
|
710
rnsh/rnsh.py
710
rnsh/rnsh.py
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
import pty
|
||||
import threading
|
||||
import functools
|
||||
|
||||
# MIT License
|
||||
#
|
||||
@ -24,309 +23,82 @@ import threading
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import rnslogging
|
||||
import RNS
|
||||
import subprocess
|
||||
import argparse
|
||||
import shlex
|
||||
import time
|
||||
import sys
|
||||
import tty
|
||||
import os
|
||||
import datetime
|
||||
import select
|
||||
import base64
|
||||
import fcntl
|
||||
import termios
|
||||
import queue
|
||||
import signal
|
||||
import errno
|
||||
import RNS.vendor.umsgpack as umsgpack
|
||||
import process
|
||||
import asyncio
|
||||
import threading
|
||||
import signal
|
||||
import retry
|
||||
|
||||
import logging as __logging
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
def _getLogger(name: str):
|
||||
global module_logger
|
||||
return module_logger.getChild(name)
|
||||
|
||||
from RNS._version import __version__
|
||||
|
||||
APP_NAME = "rnsh"
|
||||
identity = None
|
||||
reticulum = None
|
||||
allow_all = False
|
||||
allowed_identity_hashes = []
|
||||
cmd = None
|
||||
processes = []
|
||||
processes_lock = threading.Lock()
|
||||
_identity = None
|
||||
_reticulum = None
|
||||
_allow_all = False
|
||||
_allowed_identity_hashes = []
|
||||
_cmd: str | None = None
|
||||
DATA_AVAIL_MSG = "data available"
|
||||
_finished: asyncio.Future | None = None
|
||||
_retry_timer = retry.RetryThread()
|
||||
_destination: RNS.Destination | None = None
|
||||
|
||||
def _handle_sigint_with_async_int():
|
||||
if _finished is not None:
|
||||
_finished.set_exception(KeyboardInterrupt())
|
||||
else:
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
def fd_set_non_blocking(fd):
|
||||
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
if fd.isatty():
|
||||
tty.setraw(fd)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
|
||||
def fd_non_blocking_read(fd):
|
||||
# from https://stackoverflow.com/questions/26263636/how-to-check-potentially-empty-stdin-without-waiting-for-input
|
||||
# TODO: Windows is probably different
|
||||
signal.signal(signal.SIGINT, _handle_sigint_with_async_int)
|
||||
|
||||
#return fd.read()
|
||||
try:
|
||||
old_settings = None
|
||||
# try:
|
||||
# old_settings = termios.tcgetattr(fd)
|
||||
# except:
|
||||
# pass
|
||||
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
try:
|
||||
# try:
|
||||
# tty.setraw(fd)
|
||||
# except:
|
||||
# pass
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
|
||||
return os.read(fd.fileno(), 1024)
|
||||
except OSError as ose:
|
||||
if ose.errno != 35:
|
||||
raise ose
|
||||
except Exception as e:
|
||||
RNS.log(f"Raw read error {e}")
|
||||
finally:
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags)
|
||||
# if old_settings is not None:
|
||||
# termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
except:
|
||||
pass
|
||||
|
||||
class NonBlockingStreamReader:
|
||||
|
||||
def __init__(self, stream, callback = None):
|
||||
'''
|
||||
stream: the stream to read from.
|
||||
Usually a process' stdout or stderr.
|
||||
'''
|
||||
|
||||
self._s = stream
|
||||
self._q = queue.Queue()
|
||||
self._callback = callback
|
||||
self._stop_time = None
|
||||
|
||||
def _populateQueue(stream, queue):
|
||||
'''
|
||||
Collect lines from 'stream' and put them in 'quque'.
|
||||
'''
|
||||
# fd_set_non_blocking(stream)
|
||||
run = True
|
||||
while run and not (self._stop_time is not None and (datetime.datetime.now() - self._stop_time).total_seconds() > 0.05):
|
||||
# stream.flush()
|
||||
# line = stream.read(1) #fd_non_blocking_read(stream)
|
||||
timeout = 0.01
|
||||
ready, _, _ = select.select([stream], [], [], timeout)
|
||||
for fd in ready:
|
||||
try:
|
||||
data = os.read(fd, 512)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO:
|
||||
raise
|
||||
# EIO means EOF on some systems
|
||||
run = False
|
||||
else:
|
||||
if not data: # EOF
|
||||
run = False
|
||||
if data is not None and len(data) > 0:
|
||||
if self._callback is not None:
|
||||
self._callback(data)
|
||||
else:
|
||||
queue.put(data)
|
||||
RNS.log("NonBlockingStreamReader exiting", RNS.LOG_DEBUG)
|
||||
os.close(stream)
|
||||
|
||||
self._t = threading.Thread(target = _populateQueue,
|
||||
args = (self._s, self._q))
|
||||
self._t.daemon = True
|
||||
self._t.start() #start collecting lines from the stream
|
||||
|
||||
def read(self, timeout = None):
|
||||
try:
|
||||
result = self._q.get_nowait() if timeout is None else self._q.get(block = timeout is not None,
|
||||
timeout = timeout)
|
||||
return result
|
||||
except TimeoutError:
|
||||
return None
|
||||
|
||||
def is_open(self):
|
||||
return self._t.is_alive()
|
||||
|
||||
def stop(self):
|
||||
if self._stop_time is None:
|
||||
self._stop_time = datetime.datetime.now()
|
||||
|
||||
class UnexpectedEndOfStream(Exception): pass
|
||||
|
||||
def processes_get():
|
||||
processes_lock.acquire()
|
||||
try:
|
||||
return processes.copy()
|
||||
finally:
|
||||
processes_lock.release()
|
||||
|
||||
def processes_add(process):
|
||||
processes_lock.acquire()
|
||||
try:
|
||||
processes.append(process)
|
||||
finally:
|
||||
processes_lock.release()
|
||||
|
||||
def processes_remove(process):
|
||||
if process.link.status == RNS.Link.ACTIVE:
|
||||
return
|
||||
|
||||
processes_lock.acquire()
|
||||
try:
|
||||
if next(filter(lambda p: p == process, processes)) is not None:
|
||||
processes.remove(process)
|
||||
finally:
|
||||
processes_lock.release()
|
||||
|
||||
#### Link Overrides ####
|
||||
_link_handle_request_orig = RNS.Link.handle_request
|
||||
def link_handle_request(self, request_id, unpacked_request):
|
||||
for process in processes_get():
|
||||
if process.link.link_id == self.link_id:
|
||||
RNS.log("Associating packet to link", RNS.LOG_DEBUG)
|
||||
process.request_id = request_id
|
||||
self.last_request_id = request_id
|
||||
_link_handle_request_orig(self, request_id, unpacked_request)
|
||||
|
||||
RNS.Link.handle_request = link_handle_request
|
||||
|
||||
class ProcessState:
|
||||
def __init__(self, command, link, remote_identity, term):
|
||||
self.lock = threading.RLock()
|
||||
self.link = link
|
||||
self.remote_identity = remote_identity
|
||||
self.term = term
|
||||
self.command = command
|
||||
self._stderrbuf = bytearray()
|
||||
self._stdoutbuf = bytearray()
|
||||
RNS.log("Launching " + self.command) # + " for client " + (RNS.prettyhexrep(self.remote_identity) if self.remote_identity else "unknown"), RNS.LOG_DEBUG)
|
||||
env = os.environ.copy()
|
||||
# env["PYTHONUNBUFFERED"] = "1"
|
||||
# env["PS1"] ="\\u:\\h "
|
||||
env["TERM"] = self.term
|
||||
self.mo, so = pty.openpty()
|
||||
self.me, se = pty.openpty()
|
||||
self.mi, si = pty.openpty()
|
||||
self.process = subprocess.Popen(shlex.split(self.command), bufsize=512, stdin=si, stdout=so, stderr=se, preexec_fn=os.setsid, shell=False, env=env)
|
||||
for fd in [so, se, si]:
|
||||
os.close(fd)
|
||||
# tty.setcbreak(self.mo)
|
||||
self.stdout_reader = NonBlockingStreamReader(self.mo, self._stdout_cb)
|
||||
self.stderr_reader = NonBlockingStreamReader(self.me, self._stderr_cb)
|
||||
self.last_update = datetime.datetime.now()
|
||||
self.request_id = None
|
||||
self.notify_tried = 0
|
||||
self.return_code = None
|
||||
|
||||
def _fd_callback(self, fdbuf, data):
|
||||
with self.lock:
|
||||
fdbuf.extend(data)
|
||||
|
||||
def _stdout_cb(self, data):
|
||||
self._fd_callback(self._stdoutbuf, data)
|
||||
|
||||
def _stderr_cb(self, data):
|
||||
self._fd_callback(self._stderrbuf, data)
|
||||
|
||||
def notify_client_data_available(self, chars_available):
|
||||
if (datetime.datetime.now() - self.last_update).total_seconds() < 1:
|
||||
return
|
||||
|
||||
self.last_update = datetime.datetime.now()
|
||||
if self.notify_tried > 15:
|
||||
processes_remove(self)
|
||||
RNS.log(f"Try count exceeded, terminating connection", RNS.LOG_ERROR)
|
||||
self.link.teardown()
|
||||
return
|
||||
|
||||
try:
|
||||
RNS.log(f"Notifying client; try {self.notify_tried} retcode: {self.return_code} chars avail: {chars_available}")
|
||||
RNS.Packet(self.link, DATA_AVAIL_MSG.encode("utf-8")).send()
|
||||
self.notify_tried += 1
|
||||
except Exception as e:
|
||||
RNS.log("Error notifying client: " + str(e), RNS.LOG_ERROR)
|
||||
|
||||
def poll(self, should_notify):
|
||||
self.return_code, chars_available = self.process.poll(), len(self._stdoutbuf) + len(self._stderrbuf)
|
||||
|
||||
if should_notify and self.return_code is not None or chars_available > 0:
|
||||
self.notify_client_data_available(chars_available)
|
||||
|
||||
if self.return_code is not None:
|
||||
self.stdout_reader.stop()
|
||||
self.stderr_reader.stop()
|
||||
|
||||
return self.return_code, chars_available
|
||||
|
||||
def is_finished(self):
|
||||
with self.lock:
|
||||
return self.return_code is not None and not self.stdout_reader.is_open() # and not self.stderr_reader.is_open()
|
||||
|
||||
def read(self): #TODO: limit take sizes?
|
||||
with self.lock:
|
||||
self.notify_tried = 0
|
||||
self.last_update = datetime.datetime.now()
|
||||
stdout = self._stdoutbuf
|
||||
self._stdoutbuf = bytearray()
|
||||
stderr = self._stderrbuf.copy()
|
||||
self._stderrbuf = bytearray()
|
||||
self.return_code = self.process.poll()
|
||||
if self.return_code is not None and len(stdout) == 0 and len(stderr) == 0:
|
||||
self.final_checkin = True
|
||||
return self.process.poll(), stdout, stderr
|
||||
|
||||
def write(self, bytes):
|
||||
os.write(self.mi, bytes)
|
||||
os.fsync(self.mi)
|
||||
|
||||
def terminate(self):
|
||||
chars_available = 0
|
||||
with self.lock:
|
||||
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
|
||||
for fd in [self.mo, self.me, self.mi]:
|
||||
os.close(fd)
|
||||
self.process.terminate()
|
||||
self.process.wait()
|
||||
if self.process.poll() is not None:
|
||||
stdout, stderr = self.process.communicate()
|
||||
self._stdoutbuf += stdout
|
||||
self._stderrbuf += stderr
|
||||
return len(self._stdoutbuf) + len(self._stderrbuf)
|
||||
|
||||
def prepare_identity(identity_path):
|
||||
global identity
|
||||
def _prepare_identity(identity_path):
|
||||
global _identity
|
||||
log = _getLogger("_prepare_identity")
|
||||
if identity_path == None:
|
||||
identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME
|
||||
|
||||
if os.path.isfile(identity_path):
|
||||
identity = RNS.Identity.from_file(identity_path)
|
||||
_identity = RNS.Identity.from_file(identity_path)
|
||||
|
||||
if identity == None:
|
||||
RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO)
|
||||
identity = RNS.Identity()
|
||||
identity.to_file(identity_path)
|
||||
if _identity == None:
|
||||
log.info("No valid saved identity found, creating new...")
|
||||
_identity = RNS.Identity()
|
||||
_identity.to_file(identity_path)
|
||||
|
||||
def listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0, allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
|
||||
global identity, allow_all, allowed_identity_hashes, reticulum, cmd
|
||||
|
||||
cmd = command
|
||||
async def _listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0,
|
||||
allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
|
||||
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||
log = _getLogger("_listen")
|
||||
_cmd = command
|
||||
|
||||
targetloglevel = 3+verbosity-quietness
|
||||
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
|
||||
prepare_identity(identitypath)
|
||||
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
||||
_prepare_identity(identitypath)
|
||||
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
||||
|
||||
if print_identity:
|
||||
print("Identity : "+str(identity))
|
||||
print("Listening on : "+RNS.prettyhexrep(destination.hash))
|
||||
log.info("Identity : " + str(_identity))
|
||||
log.info("Listening on : " + RNS.prettyhexrep(_destination.hash))
|
||||
exit(0)
|
||||
|
||||
if disable_auth:
|
||||
allow_all = True
|
||||
_allow_all = True
|
||||
else:
|
||||
if allowed != None:
|
||||
for a in allowed:
|
||||
@ -336,140 +108,306 @@ def listen(configdir, command, identitypath = None, service_name ="default", ver
|
||||
raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
|
||||
try:
|
||||
destination_hash = bytes.fromhex(a)
|
||||
allowed_identity_hashes.append(destination_hash)
|
||||
_allowed_identity_hashes.append(destination_hash)
|
||||
except Exception as e:
|
||||
raise ValueError("Invalid destination entered. Check your input.")
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
log.error(str(e))
|
||||
exit(1)
|
||||
|
||||
if len(allowed_identity_hashes) < 1 and not disable_auth:
|
||||
print("Warning: No allowed identities configured, rncx will not accept any commands!")
|
||||
if len(_allowed_identity_hashes) < 1 and not disable_auth:
|
||||
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
||||
|
||||
destination.set_link_established_callback(command_link_established)
|
||||
_destination.set_link_established_callback(_listen_link_established)
|
||||
|
||||
if not allow_all:
|
||||
destination.register_request_handler(
|
||||
if not _allow_all:
|
||||
_destination.register_request_handler(
|
||||
path = service_name,
|
||||
response_generator = execute_received_command,
|
||||
response_generator = _listen_request,
|
||||
allow = RNS.Destination.ALLOW_LIST,
|
||||
allowed_list = allowed_identity_hashes
|
||||
allowed_list = _allowed_identity_hashes
|
||||
)
|
||||
else:
|
||||
destination.register_request_handler(
|
||||
_destination.register_request_handler(
|
||||
path = service_name,
|
||||
response_generator = execute_received_command,
|
||||
response_generator = _listen_request,
|
||||
allow = RNS.Destination.ALLOW_ALL,
|
||||
)
|
||||
|
||||
RNS.log("rnsh listening for commands on "+RNS.prettyhexrep(destination.hash))
|
||||
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||
|
||||
if not disable_announce:
|
||||
destination.announce()
|
||||
_destination.announce()
|
||||
|
||||
last = datetime.datetime.now()
|
||||
|
||||
while True:
|
||||
if not disable_announce and (datetime.datetime.now() - last).total_seconds() > 900: # TODO: make parameter
|
||||
last = datetime.datetime.now()
|
||||
destination.announce()
|
||||
last = time.monotonic()
|
||||
|
||||
time.sleep(0.005)
|
||||
for proc in processes_get():
|
||||
try:
|
||||
while True:
|
||||
if not disable_announce and time.monotonic() - last > 900: # TODO: make parameter
|
||||
last = datetime.datetime.now()
|
||||
_destination.announce()
|
||||
try:
|
||||
if proc.link.status == RNS.Link.CLOSED:
|
||||
RNS.log("Link closed, terminating")
|
||||
proc.terminate()
|
||||
proc.poll(should_notify=True)
|
||||
await asyncio.wait_for(_finished, timeout=1.0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
log.warning("Shutting down")
|
||||
for link in list(_destination.links):
|
||||
try:
|
||||
if link.process is not None and link.process.process.running:
|
||||
link.process.process.terminate()
|
||||
except:
|
||||
RNS.log("Error polling process for link " + proc.link.link_id, RNS.LOG_ERROR)
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||
for link in links_still_active:
|
||||
if link.status != RNS.Link.CLOSED:
|
||||
link.teardown()
|
||||
|
||||
if proc.link.status == RNS.Link.CLOSED:
|
||||
processes_remove(proc)
|
||||
class ProcessState:
|
||||
def __init__(self,
|
||||
cmd: str,
|
||||
mdu: int,
|
||||
data_available_callback: callable,
|
||||
terminated_callback: callable,
|
||||
term: str | None,
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
|
||||
self._log = _getLogger(self.__class__.__name__)
|
||||
self._mdu = mdu
|
||||
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||
self._process = process.CallbackSubprocess(argv=shlex.split(cmd),
|
||||
term=term,
|
||||
loop=asyncio.get_running_loop(),
|
||||
stdout_callback=self._stdout_data,
|
||||
terminated_callback=terminated_callback)
|
||||
self._data_buffer = bytearray()
|
||||
self._lock = threading.RLock()
|
||||
self._data_available_cb = data_available_callback
|
||||
self._terminated_cb = terminated_callback
|
||||
self._pending_receipt: RNS.PacketReceipt | None = None
|
||||
self._process.start()
|
||||
self._term_state: [int] | None = None
|
||||
|
||||
@property
|
||||
def mdu(self) -> int:
|
||||
return self._mdu
|
||||
|
||||
@mdu.setter
|
||||
def mdu(self, val: int):
|
||||
self._mdu = val
|
||||
|
||||
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
||||
return self._pending_receipt
|
||||
|
||||
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
|
||||
with self._lock:
|
||||
val = self._pending_receipt
|
||||
self._pending_receipt = None
|
||||
return val
|
||||
|
||||
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
|
||||
with self._lock:
|
||||
self._pending_receipt = receipt
|
||||
|
||||
@property
|
||||
def process(self) -> process.CallbackSubprocess:
|
||||
return self._process
|
||||
|
||||
@property
|
||||
def return_code(self) -> int | None:
|
||||
return self.process.return_code
|
||||
|
||||
@property
|
||||
def lock(self) -> threading.RLock:
|
||||
return self._lock
|
||||
|
||||
def read(self, count: int) -> bytes:
|
||||
with self.lock:
|
||||
take = self._data_buffer[:count-1]
|
||||
self._data_buffer = self._data_buffer[count:]
|
||||
return take
|
||||
|
||||
def _stdout_data(self, data: bytes):
|
||||
total_available = 0
|
||||
with self.lock:
|
||||
self._data_buffer.extend(data)
|
||||
total_available = len(self._data_buffer)
|
||||
try:
|
||||
self._data_available_cb(total_available)
|
||||
except Exception as e:
|
||||
self._log.error(f"Error calling ProcessState data_available_callback {e}")
|
||||
|
||||
def _update_winsz(self):
|
||||
self.process.set_winsize(self._term_state[3],
|
||||
self._term_state[4],
|
||||
self._term_state[5],
|
||||
self._term_state[6])
|
||||
|
||||
|
||||
REQUEST_IDX_STDIN = 0
|
||||
REQUEST_IDX_TERM = 1
|
||||
REQUEST_IDX_TIOS = 2
|
||||
REQUEST_IDX_ROWS = 3
|
||||
REQUEST_IDX_COLS = 4
|
||||
REQUEST_IDX_HPIX = 5
|
||||
REQUEST_IDX_VPIX = 6
|
||||
def process_request(self, data: [any], read_size: int) -> [any]:
|
||||
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
|
||||
term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
||||
tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
||||
rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
||||
cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
|
||||
hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
||||
vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
||||
term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX]
|
||||
response = ProcessState.default_response()
|
||||
|
||||
def command_link_start_process(link, identity, term) -> ProcessState:
|
||||
try:
|
||||
process = ProcessState(cmd, link, identity, term)
|
||||
processes_add(process)
|
||||
return process
|
||||
except Exception as e:
|
||||
RNS.log("Failed to launch process: " + str(e), RNS.LOG_ERROR)
|
||||
link.teardown()
|
||||
def command_link_established(link):
|
||||
global allow_all
|
||||
link.set_remote_identified_callback(initiator_identified)
|
||||
link.set_link_closed_callback(command_link_closed)
|
||||
RNS.log("Shell link "+str(link)+" established")
|
||||
if allow_all:
|
||||
command_link_start_process(link, None)
|
||||
|
||||
def command_link_closed(link):
|
||||
RNS.log("Shell link "+str(link)+" closed")
|
||||
matches = list(filter(lambda p: p.link == link, processes_get()))
|
||||
if len(matches) == 0:
|
||||
return
|
||||
proc = matches[0]
|
||||
try:
|
||||
proc.terminate()
|
||||
except:
|
||||
RNS.log("Error closing process for link " + RNS.prettyhexrep(link.link_id), RNS.LOG_ERROR)
|
||||
finally:
|
||||
processes_remove(proc)
|
||||
|
||||
def initiator_identified(link, identity):
|
||||
global allow_all, cmd
|
||||
RNS.log("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
|
||||
if not allow_all and not identity.hash in allowed_identity_hashes:
|
||||
RNS.log("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
|
||||
link.teardown()
|
||||
|
||||
def execute_received_command(path, data, request_id, remote_identity, requested_at):
|
||||
RNS.log("execute_received_command", RNS.LOG_DEBUG)
|
||||
process = None
|
||||
for proc in processes_get():
|
||||
RNS.log("checking a proc", RNS.LOG_DEBUG)
|
||||
if proc.request_id == request_id:
|
||||
process = proc
|
||||
RNS.log("execute_received_command matched request", RNS.LOG_DEBUG)
|
||||
|
||||
stdin = data[0] # Data passed to stdin
|
||||
if process is None:
|
||||
link = next(filter(lambda l: hasattr(l, "last_request_id") and l.last_request_id == request_id, RNS.Transport.active_links))
|
||||
if link is not None:
|
||||
process = command_link_start_process(link, identity, base64.b64decode(stdin).decode("utf-8") if stdin is not None else "")
|
||||
time.sleep(0.1)
|
||||
|
||||
# if remote_identity != None:
|
||||
# RNS.log("Executing command ["+command+"] for "+RNS.prettyhexrep(remote_identity.hash))
|
||||
# else:
|
||||
# RNS.log("Executing command ["+command+"] for unknown requestor")
|
||||
|
||||
result = [
|
||||
False, # 0: Command was executed
|
||||
None, # 1: Return value
|
||||
None, # 2: Stdout
|
||||
None, # 3: Stderr
|
||||
datetime.datetime.now(), # 4: Timestamp
|
||||
]
|
||||
|
||||
try:
|
||||
if process is not None:
|
||||
result[0] = not process.is_finished()
|
||||
response[ProcessState.RESPONSE_IDX_RUNNING] = not self.process.running
|
||||
if self.process.running:
|
||||
if term_state != self._term_state:
|
||||
self._term_state = term_state
|
||||
self._update_winsz()
|
||||
if stdin is not None and len(stdin) > 0:
|
||||
stdin = base64.b64decode(stdin)
|
||||
process.write(stdin)
|
||||
return_code, stdout, stderr = process.read()
|
||||
result[1] = return_code
|
||||
result[2] = base64.b64encode(stdout).decode("utf-8") if stdout is not None else None
|
||||
result[3] = base64.b64encode(stderr).decode("utf-8") if stderr is not None else None
|
||||
self.process.write(stdin)
|
||||
response[ProcessState.RESPONSE_IDX_RETCODE] = self.return_code
|
||||
stdout = self.read(read_size)
|
||||
with self.lock:
|
||||
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||
response[ProcessState.RESPONSE_IDX_STDOUT] = \
|
||||
base64.b64encode(stdout).decode("utf-8") if stdout is not None and len(stdout) > 0 else None
|
||||
return response
|
||||
|
||||
RESPONSE_IDX_RUNNING = 0
|
||||
RESPONSE_IDX_RETCODE = 1
|
||||
RESPONSE_IDX_RDYBYTE = 2
|
||||
RESPONSE_IDX_STDOUT = 3
|
||||
RESPONSE_IDX_TMSTAMP = 4
|
||||
@staticmethod
|
||||
def default_response() -> [any]:
|
||||
return [
|
||||
False, # 0: Process running
|
||||
None, # 1: Return value
|
||||
0, # 2: Number of outstanding bytes
|
||||
None, # 3: Stdout/Stderr
|
||||
time.time(), # 4: Timestamp
|
||||
]
|
||||
|
||||
|
||||
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||
global _retry_timer
|
||||
log = _getLogger("_subproc_data_ready")
|
||||
process_state: ProcessState = link.process
|
||||
|
||||
def send(timeout: bool, id: any, tries: int) -> any:
|
||||
try:
|
||||
pr = process_state.pending_receipt_take()
|
||||
if pr is not None and pr.get_status() != RNS.PacketReceipt.SENT and pr.get_status() != RNS.PacketReceipt.DELIVERED:
|
||||
if not timeout:
|
||||
_retry_timer.complete(id)
|
||||
log.debug(f"Packet {id} completed with status {pr.status} on link {link}")
|
||||
return link.link_id
|
||||
|
||||
if not timeout:
|
||||
log.info(f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})")
|
||||
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
||||
packet.send()
|
||||
pr = packet.receipt
|
||||
process_state.pending_receipt_put(pr)
|
||||
return link.link_id
|
||||
else:
|
||||
log.error(f"Retry count exceeded, terminating link {link}")
|
||||
_retry_timer.complete(link.link_id)
|
||||
link.teardown()
|
||||
except Exception as e:
|
||||
log.error("Error notifying client: " + str(e))
|
||||
return link.link_id
|
||||
|
||||
with process_state.lock:
|
||||
if process_state.pending_receipt_peek() is None:
|
||||
_retry_timer.begin(try_limit=15,
|
||||
wait_delay=link.rtt * 3 if link.rtt is not None else 1,
|
||||
try_callback=functools.partial(send, False),
|
||||
timeout_callback=functools.partial(send, True),
|
||||
id=None)
|
||||
else:
|
||||
log.debug(f"Notification already pending for link {link}")
|
||||
|
||||
def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||
log = _getLogger("_subproc_terminated")
|
||||
log.info(f"Subprocess terminated ({return_code} for link {link}")
|
||||
link.teardown()
|
||||
|
||||
def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None:
|
||||
global _cmd
|
||||
log = _getLogger("_listen_start_proc")
|
||||
try:
|
||||
link.process = ProcessState(cmd=_cmd,
|
||||
term=term,
|
||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||
return link.process
|
||||
except Exception as e:
|
||||
result[0] = False
|
||||
if process is not None:
|
||||
process.terminate()
|
||||
process.link.teardown()
|
||||
log.error("Failed to launch process: " + str(e))
|
||||
link.teardown()
|
||||
return None
|
||||
def _listen_link_established(link):
|
||||
global _allow_all
|
||||
log = _getLogger("_listen_link_established")
|
||||
link.set_remote_identified_callback(_initiator_identified)
|
||||
link.set_link_closed_callback(_listen_link_closed)
|
||||
log.info("Link "+str(link)+" established")
|
||||
|
||||
def _listen_link_closed(link: RNS.Link):
|
||||
log = _getLogger("_listen_link_closed")
|
||||
# async def cleanup():
|
||||
log.info("Link "+str(link)+" closed")
|
||||
proc: ProcessState | None = link.proc if hasattr(link, "process") else None
|
||||
if proc is None:
|
||||
log.warning(f"No process for link {link}")
|
||||
try:
|
||||
proc.process.terminate()
|
||||
except:
|
||||
log.error(f"Error closing process for link {link}")
|
||||
# asyncio.get_running_loop().call_soon(cleanup)
|
||||
|
||||
def _initiator_identified(link, identity):
|
||||
global _allow_all, _cmd
|
||||
log = _getLogger("_initiator_identified")
|
||||
log.info("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
|
||||
if not _allow_all and not identity.hash in _allowed_identity_hashes:
|
||||
log.warning("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
|
||||
link.teardown()
|
||||
|
||||
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||
global _destination, _retry_timer
|
||||
log = _getLogger("_listen_request")
|
||||
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
||||
_retry_timer.complete(link_id)
|
||||
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links))
|
||||
if link is None:
|
||||
log.error(f"invalid request {request_id}, no link found with id {link_id}")
|
||||
return
|
||||
process_state: ProcessState | None = None
|
||||
try:
|
||||
term = data[1]
|
||||
process_state = link.process if hasattr(link, "process") else None
|
||||
if process_state is None:
|
||||
log.debug(f"process not found for link {link}")
|
||||
process_state = _listen_start_proc(link, term)
|
||||
|
||||
|
||||
# leave significant overhead for metadata and encoding
|
||||
result = process_state.process_request(data, link.MDU * 3 // 2)
|
||||
except Exception as e:
|
||||
result = ProcessState.default_response()
|
||||
try:
|
||||
if process_state is not None:
|
||||
process_state.process.terminate()
|
||||
link.teardown()
|
||||
except Exception as e:
|
||||
log.error(f"Error terminating process for link {link}")
|
||||
|
||||
return result
|
||||
|
||||
@ -552,7 +490,7 @@ def client_packet_handler(message, packet):
|
||||
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG:
|
||||
new_data = True
|
||||
def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT):
|
||||
global identity, reticulum, link, listener_destination, remote_exec_grace
|
||||
global _identity, _reticulum, link, listener_destination, remote_exec_grace
|
||||
|
||||
try:
|
||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
|
||||
@ -566,12 +504,12 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
|
||||
print(str(e))
|
||||
return 241
|
||||
|
||||
if reticulum == None:
|
||||
if _reticulum == None:
|
||||
targetloglevel = 3+verbosity-quietness
|
||||
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||
|
||||
if identity == None:
|
||||
prepare_identity(identitypath)
|
||||
if _identity == None:
|
||||
_prepare_identity(identitypath)
|
||||
|
||||
if not RNS.Transport.has_path(destination_hash):
|
||||
RNS.Transport.request_path(destination_hash)
|
||||
@ -598,7 +536,7 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
|
||||
return 243
|
||||
|
||||
if not noid and not link.did_identify:
|
||||
link.identify(identity)
|
||||
link.identify(_identity)
|
||||
link.did_identify = True
|
||||
|
||||
link.set_packet_callback(client_packet_handler)
|
||||
@ -718,7 +656,7 @@ def main():
|
||||
|
||||
if args.listen or args.print_identity:
|
||||
RNS.log("command " + args.command)
|
||||
listen(
|
||||
_listen(
|
||||
configdir=args.config,
|
||||
command=args.command,
|
||||
identitypath=args.identity,
|
||||
|
60
rnsh/rnslogging.py
Normal file
60
rnsh/rnslogging.py
Normal file
@ -0,0 +1,60 @@
|
||||
import logging
|
||||
from logging import Handler, getLevelName
|
||||
from types import GenericAlias
|
||||
import os
|
||||
|
||||
import RNS
|
||||
|
||||
class RnsHandler(Handler):
|
||||
"""
|
||||
A handler class which writes logging records, appropriately formatted,
|
||||
to the RNS logger.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the handler.
|
||||
"""
|
||||
Handler.__init__(self)
|
||||
|
||||
@staticmethod
|
||||
def get_rns_loglevel(loglevel: int) -> int:
|
||||
if loglevel == logging.CRITICAL:
|
||||
return RNS.LOG_CRITICAL
|
||||
if loglevel == logging.ERROR:
|
||||
return RNS.LOG_ERROR
|
||||
if loglevel == logging.WARNING:
|
||||
return RNS.LOG_WARNING
|
||||
if loglevel == logging.INFO:
|
||||
return RNS.LOG_INFO
|
||||
if loglevel == logging.DEBUG:
|
||||
return RNS.LOG_DEBUG
|
||||
return RNS.LOG_DEBUG
|
||||
|
||||
def emit(self, record):
|
||||
"""
|
||||
Emit a record.
|
||||
"""
|
||||
try:
|
||||
msg = self.format(record)
|
||||
RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno))
|
||||
except RecursionError: # See issue 36272
|
||||
raise
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
def __repr__(self):
|
||||
level = getLevelName(self.level)
|
||||
return '<%s (%s)>' % (self.__class__.__name__, level)
|
||||
|
||||
__class_getitem__ = classmethod(GenericAlias)
|
||||
|
||||
|
||||
log_format = '%(name)-40s %(message)s [%(threadName)s]'
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
#format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s',
|
||||
format=log_format,
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
handlers=[RnsHandler()])
|
Loading…
Reference in New Issue
Block a user