Locking in a major increment for integrating the CallbackSubprocess module to rnsh

This commit is contained in:
Aaron Heise 2023-02-09 19:58:46 -06:00
parent 9186a64962
commit fcc73ba31a
4 changed files with 512 additions and 401 deletions

View File

@ -9,13 +9,10 @@ import pty
import os import os
import asyncio import asyncio
import sys import sys
import logging as __logging
import fcntl import fcntl
import select import select
import termios import termios
import logging as __logging
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None): def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
@ -165,8 +162,8 @@ class CallbackSubprocess:
assert stdout_callback is not None, "stdout_callback should not be None" assert stdout_callback is not None, "stdout_callback should not be None"
assert terminated_callback is not None, "terminated_callback should not be None" assert terminated_callback is not None, "terminated_callback should not be None"
self.log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self.log.debug(f"__init__({argv},{term},...") self._log.debug(f"__init__({argv},{term},...")
self._command = argv self._command = argv
self._term = term self._term = term
self._loop = loop self._loop = loop
@ -181,7 +178,7 @@ class CallbackSubprocess:
Terminate child process if running Terminate child process if running
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL :param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
""" """
self.log.debug("terminate()") self._log.debug("terminate()")
if not self.running: if not self.running:
return return
@ -191,7 +188,7 @@ class CallbackSubprocess:
pass pass
def kill(): def kill():
self.log.debug("kill()") self._log.debug("kill()")
try: try:
os.kill(self._pid, signal.SIGHUP) os.kill(self._pid, signal.SIGHUP)
os.kill(self._pid, signal.SIGKILL) os.kill(self._pid, signal.SIGKILL)
@ -201,9 +198,9 @@ class CallbackSubprocess:
self._loop.call_later(kill_delay, kill) self._loop.call_later(kill_delay, kill)
def wait(): def wait():
self.log.debug("wait()") self._log.debug("wait()")
os.waitpid(self._pid, 0) os.waitpid(self._pid, 0)
self.log.debug("wait() finish") self._log.debug("wait() finish")
threading.Thread(target=wait).start() threading.Thread(target=wait).start()
@ -226,7 +223,7 @@ class CallbackSubprocess:
Write bytes to the stdin of the child process. Write bytes to the stdin of the child process.
:param data: bytes to write :param data: bytes to write
""" """
self.log.debug(f"write({data})") self._log.debug(f"write({data})")
os.write(self._child_fd, data) os.write(self._child_fd, data)
def set_winsize(self, r: int, c: int, h: int, v: int): def set_winsize(self, r: int, c: int, h: int, v: int):
@ -238,7 +235,7 @@ class CallbackSubprocess:
:param v: vertical pixels visible :param v: vertical pixels visible
:return: :return:
""" """
self.log.debug(f"set_winsize({r},{c},{h},{v}") self._log.debug(f"set_winsize({r},{c},{h},{v}")
tty_set_winsize(self._child_fd, r, c, h, v) tty_set_winsize(self._child_fd, r, c, h, v)
def copy_winsize(self, fromfd:int): def copy_winsize(self, fromfd:int):
@ -268,7 +265,7 @@ class CallbackSubprocess:
""" """
Start the child process. Start the child process.
""" """
self.log.debug("start()") self._log.debug("start()")
parentenv = os.environ.copy() parentenv = os.environ.copy()
env = {"HOME": parentenv["HOME"], env = {"HOME": parentenv["HOME"],
"TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"), "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
@ -290,11 +287,11 @@ class CallbackSubprocess:
try: try:
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG) pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
if self._return_code is not None and not process_exists(self._pid): if self._return_code is not None and not process_exists(self._pid):
self.log.debug(f"polled return code {self._return_code}") self._log.debug(f"polled return code {self._return_code}")
self._terminated_cb(self._return_code) self._terminated_cb(self._return_code)
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
except Exception as e: except Exception as e:
self.log.debug(f"Error in process poll: {e}") self._log.debug(f"Error in process poll: {e}")
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable): def reader(fd: int, callback: callable):

116
rnsh/retry.py Normal file
View File

@ -0,0 +1,116 @@
import asyncio
import threading
import time
import logging as __logging
module_logger = __logging.getLogger(__name__)
class RetryStatus:
def __init__(self, id: any, try_limit: int, wait_delay: float, retry_callback: callable[any, int], timeout_callback: callable[any], tries: int = 1):
self.id = id
self.try_limit = try_limit
self.tries = tries
self.wait_delay = wait_delay
self.retry_callback = retry_callback
self.timeout_callback = timeout_callback
self.try_time = time.monotonic()
self.completed = False
@property
def ready(self):
return self.try_time + self.wait_delay < time.monotonic() and not self.completed
@property
def timed_out(self):
return self.ready and self.tries >= self.try_limit
def timeout(self):
self.completed = True
self.timeout_callback(self.id)
def retry(self):
self.tries += 1
self.retry_callback(self.id, self.tries)
class RetryThread:
def __init__(self, loop_period: float = 0.25):
self._log = module_logger.getChild(self.__class__.__name__)
self._loop_period = loop_period
self._statuses: list[RetryStatus] = []
self._id_counter = 0
self._lock = threading.RLock()
self._thread = threading.Thread(target=self._thread_run())
self._run = True
self._finished: asyncio.Future | None = None
self._thread.start()
def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None:
self._log.debug("stopping timer thread")
if loop is None:
self._run = False
self._thread.join()
return None
else:
self._finished = loop.create_future()
return self._finished
def _thread_run(self):
last_run = time.monotonic()
while self._run and self._finished is None:
time.sleep(last_run + self._loop_period - time.monotonic())
last_run = time.monotonic()
ready: list[RetryStatus] = []
prune: list[RetryStatus] = []
with self._lock:
ready.extend(filter(lambda s: s.ready, self._statuses))
for retry in ready:
try:
if not retry.completed:
if retry.timed_out:
self._log.debug(f"timed out {retry.id} after {retry.try_limit} tries")
retry.timeout()
prune.append(retry)
else:
self._log.debug(f"retrying {retry.id}, try {retry.tries + 1}/{retry.try_limit}")
retry.retry()
except Exception as e:
self._log.error(f"error processing retry id {retry.id}: {e}")
prune.append(retry)
with self._lock:
for retry in prune:
self._log.debug(f"pruned retry {retry.id}, retry count {retry.tries}/{retry.try_limit}")
self._statuses.remove(retry)
if self._finished is not None:
self._finished.set_result(None)
def _get_id(self):
self._id_counter += 1
return self._id_counter
def begin(self, try_limit: int, wait_delay: float, try_callback: callable[[any | None, int], any], timeout_callback: callable[any, int], id: int | None = None) -> any:
self._log.debug(f"running first try")
id = try_callback(id, 1)
self._log.debug(f"first try success, got id {id}")
with self._lock:
if id is None:
id = self._get_id()
self._statuses.append(RetryStatus(id=id,
tries=1,
try_limit=try_limit,
wait_delay=wait_delay,
retry_callback=try_callback,
timeout_callback=timeout_callback))
self._log.debug(f"added retry timer for {id}")
def complete(self, id: any):
assert id is not None
with self._lock:
status = next(filter(lambda l: l.id == id, self._statuses))
assert status is not None
status.completed = True
self._statuses.remove(status)
self._log.debug(f"completed {id}")
def complete_all(self):
with self._lock:
for status in self._statuses:
status.completed = True
self._log.debug(f"completed {status.id}")
self._statuses.clear()

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import pty import functools
import threading
# MIT License # MIT License
# #
@ -24,309 +23,82 @@ import threading
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import rnslogging
import RNS import RNS
import subprocess
import argparse import argparse
import shlex import shlex
import time import time
import sys import sys
import tty
import os import os
import datetime import datetime
import select
import base64 import base64
import fcntl
import termios
import queue
import signal
import errno
import RNS.vendor.umsgpack as umsgpack import RNS.vendor.umsgpack as umsgpack
import process
import asyncio
import threading
import signal
import retry
import logging as __logging
module_logger = __logging.getLogger(__name__)
def _getLogger(name: str):
global module_logger
return module_logger.getChild(name)
from RNS._version import __version__ from RNS._version import __version__
APP_NAME = "rnsh" APP_NAME = "rnsh"
identity = None _identity = None
reticulum = None _reticulum = None
allow_all = False _allow_all = False
allowed_identity_hashes = [] _allowed_identity_hashes = []
cmd = None _cmd: str | None = None
processes = []
processes_lock = threading.Lock()
DATA_AVAIL_MSG = "data available" DATA_AVAIL_MSG = "data available"
_finished: asyncio.Future | None = None
_retry_timer = retry.RetryThread()
_destination: RNS.Destination | None = None
def _handle_sigint_with_async_int():
if _finished is not None:
_finished.set_exception(KeyboardInterrupt())
else:
raise KeyboardInterrupt()
def fd_set_non_blocking(fd): signal.signal(signal.SIGINT, _handle_sigint_with_async_int)
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
if fd.isatty():
tty.setraw(fd)
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
def fd_non_blocking_read(fd):
# from https://stackoverflow.com/questions/26263636/how-to-check-potentially-empty-stdin-without-waiting-for-input
# TODO: Windows is probably different
#return fd.read() def _prepare_identity(identity_path):
try: global _identity
old_settings = None log = _getLogger("_prepare_identity")
# try:
# old_settings = termios.tcgetattr(fd)
# except:
# pass
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
try:
# try:
# tty.setraw(fd)
# except:
# pass
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
return os.read(fd.fileno(), 1024)
except OSError as ose:
if ose.errno != 35:
raise ose
except Exception as e:
RNS.log(f"Raw read error {e}")
finally:
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags)
# if old_settings is not None:
# termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
except:
pass
class NonBlockingStreamReader:
def __init__(self, stream, callback = None):
'''
stream: the stream to read from.
Usually a process' stdout or stderr.
'''
self._s = stream
self._q = queue.Queue()
self._callback = callback
self._stop_time = None
def _populateQueue(stream, queue):
'''
Collect lines from 'stream' and put them in 'quque'.
'''
# fd_set_non_blocking(stream)
run = True
while run and not (self._stop_time is not None and (datetime.datetime.now() - self._stop_time).total_seconds() > 0.05):
# stream.flush()
# line = stream.read(1) #fd_non_blocking_read(stream)
timeout = 0.01
ready, _, _ = select.select([stream], [], [], timeout)
for fd in ready:
try:
data = os.read(fd, 512)
except OSError as e:
if e.errno != errno.EIO:
raise
# EIO means EOF on some systems
run = False
else:
if not data: # EOF
run = False
if data is not None and len(data) > 0:
if self._callback is not None:
self._callback(data)
else:
queue.put(data)
RNS.log("NonBlockingStreamReader exiting", RNS.LOG_DEBUG)
os.close(stream)
self._t = threading.Thread(target = _populateQueue,
args = (self._s, self._q))
self._t.daemon = True
self._t.start() #start collecting lines from the stream
def read(self, timeout = None):
try:
result = self._q.get_nowait() if timeout is None else self._q.get(block = timeout is not None,
timeout = timeout)
return result
except TimeoutError:
return None
def is_open(self):
return self._t.is_alive()
def stop(self):
if self._stop_time is None:
self._stop_time = datetime.datetime.now()
class UnexpectedEndOfStream(Exception): pass
def processes_get():
processes_lock.acquire()
try:
return processes.copy()
finally:
processes_lock.release()
def processes_add(process):
processes_lock.acquire()
try:
processes.append(process)
finally:
processes_lock.release()
def processes_remove(process):
if process.link.status == RNS.Link.ACTIVE:
return
processes_lock.acquire()
try:
if next(filter(lambda p: p == process, processes)) is not None:
processes.remove(process)
finally:
processes_lock.release()
#### Link Overrides ####
_link_handle_request_orig = RNS.Link.handle_request
def link_handle_request(self, request_id, unpacked_request):
for process in processes_get():
if process.link.link_id == self.link_id:
RNS.log("Associating packet to link", RNS.LOG_DEBUG)
process.request_id = request_id
self.last_request_id = request_id
_link_handle_request_orig(self, request_id, unpacked_request)
RNS.Link.handle_request = link_handle_request
class ProcessState:
def __init__(self, command, link, remote_identity, term):
self.lock = threading.RLock()
self.link = link
self.remote_identity = remote_identity
self.term = term
self.command = command
self._stderrbuf = bytearray()
self._stdoutbuf = bytearray()
RNS.log("Launching " + self.command) # + " for client " + (RNS.prettyhexrep(self.remote_identity) if self.remote_identity else "unknown"), RNS.LOG_DEBUG)
env = os.environ.copy()
# env["PYTHONUNBUFFERED"] = "1"
# env["PS1"] ="\\u:\\h "
env["TERM"] = self.term
self.mo, so = pty.openpty()
self.me, se = pty.openpty()
self.mi, si = pty.openpty()
self.process = subprocess.Popen(shlex.split(self.command), bufsize=512, stdin=si, stdout=so, stderr=se, preexec_fn=os.setsid, shell=False, env=env)
for fd in [so, se, si]:
os.close(fd)
# tty.setcbreak(self.mo)
self.stdout_reader = NonBlockingStreamReader(self.mo, self._stdout_cb)
self.stderr_reader = NonBlockingStreamReader(self.me, self._stderr_cb)
self.last_update = datetime.datetime.now()
self.request_id = None
self.notify_tried = 0
self.return_code = None
def _fd_callback(self, fdbuf, data):
with self.lock:
fdbuf.extend(data)
def _stdout_cb(self, data):
self._fd_callback(self._stdoutbuf, data)
def _stderr_cb(self, data):
self._fd_callback(self._stderrbuf, data)
def notify_client_data_available(self, chars_available):
if (datetime.datetime.now() - self.last_update).total_seconds() < 1:
return
self.last_update = datetime.datetime.now()
if self.notify_tried > 15:
processes_remove(self)
RNS.log(f"Try count exceeded, terminating connection", RNS.LOG_ERROR)
self.link.teardown()
return
try:
RNS.log(f"Notifying client; try {self.notify_tried} retcode: {self.return_code} chars avail: {chars_available}")
RNS.Packet(self.link, DATA_AVAIL_MSG.encode("utf-8")).send()
self.notify_tried += 1
except Exception as e:
RNS.log("Error notifying client: " + str(e), RNS.LOG_ERROR)
def poll(self, should_notify):
self.return_code, chars_available = self.process.poll(), len(self._stdoutbuf) + len(self._stderrbuf)
if should_notify and self.return_code is not None or chars_available > 0:
self.notify_client_data_available(chars_available)
if self.return_code is not None:
self.stdout_reader.stop()
self.stderr_reader.stop()
return self.return_code, chars_available
def is_finished(self):
with self.lock:
return self.return_code is not None and not self.stdout_reader.is_open() # and not self.stderr_reader.is_open()
def read(self): #TODO: limit take sizes?
with self.lock:
self.notify_tried = 0
self.last_update = datetime.datetime.now()
stdout = self._stdoutbuf
self._stdoutbuf = bytearray()
stderr = self._stderrbuf.copy()
self._stderrbuf = bytearray()
self.return_code = self.process.poll()
if self.return_code is not None and len(stdout) == 0 and len(stderr) == 0:
self.final_checkin = True
return self.process.poll(), stdout, stderr
def write(self, bytes):
os.write(self.mi, bytes)
os.fsync(self.mi)
def terminate(self):
chars_available = 0
with self.lock:
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
for fd in [self.mo, self.me, self.mi]:
os.close(fd)
self.process.terminate()
self.process.wait()
if self.process.poll() is not None:
stdout, stderr = self.process.communicate()
self._stdoutbuf += stdout
self._stderrbuf += stderr
return len(self._stdoutbuf) + len(self._stderrbuf)
def prepare_identity(identity_path):
global identity
if identity_path == None: if identity_path == None:
identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME
if os.path.isfile(identity_path): if os.path.isfile(identity_path):
identity = RNS.Identity.from_file(identity_path) _identity = RNS.Identity.from_file(identity_path)
if identity == None: if _identity == None:
RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO) log.info("No valid saved identity found, creating new...")
identity = RNS.Identity() _identity = RNS.Identity()
identity.to_file(identity_path) _identity.to_file(identity_path)
def listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0, allowed = [], print_identity = False, disable_auth = None, disable_announce=False): async def _listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0,
global identity, allow_all, allowed_identity_hashes, reticulum, cmd allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
cmd = command log = _getLogger("_listen")
_cmd = command
targetloglevel = 3+verbosity-quietness targetloglevel = 3+verbosity-quietness
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel) _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
prepare_identity(identitypath) _prepare_identity(identitypath)
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name) _destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
if print_identity: if print_identity:
print("Identity : "+str(identity)) log.info("Identity : " + str(_identity))
print("Listening on : "+RNS.prettyhexrep(destination.hash)) log.info("Listening on : " + RNS.prettyhexrep(_destination.hash))
exit(0) exit(0)
if disable_auth: if disable_auth:
allow_all = True _allow_all = True
else: else:
if allowed != None: if allowed != None:
for a in allowed: for a in allowed:
@ -336,140 +108,306 @@ def listen(configdir, command, identitypath = None, service_name ="default", ver
raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2)) raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
try: try:
destination_hash = bytes.fromhex(a) destination_hash = bytes.fromhex(a)
allowed_identity_hashes.append(destination_hash) _allowed_identity_hashes.append(destination_hash)
except Exception as e: except Exception as e:
raise ValueError("Invalid destination entered. Check your input.") raise ValueError("Invalid destination entered. Check your input.")
except Exception as e: except Exception as e:
print(str(e)) log.error(str(e))
exit(1) exit(1)
if len(allowed_identity_hashes) < 1 and not disable_auth: if len(_allowed_identity_hashes) < 1 and not disable_auth:
print("Warning: No allowed identities configured, rncx will not accept any commands!") log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
destination.set_link_established_callback(command_link_established) _destination.set_link_established_callback(_listen_link_established)
if not allow_all: if not _allow_all:
destination.register_request_handler( _destination.register_request_handler(
path = service_name, path = service_name,
response_generator = execute_received_command, response_generator = _listen_request,
allow = RNS.Destination.ALLOW_LIST, allow = RNS.Destination.ALLOW_LIST,
allowed_list = allowed_identity_hashes allowed_list = _allowed_identity_hashes
) )
else: else:
destination.register_request_handler( _destination.register_request_handler(
path = service_name, path = service_name,
response_generator = execute_received_command, response_generator = _listen_request,
allow = RNS.Destination.ALLOW_ALL, allow = RNS.Destination.ALLOW_ALL,
) )
RNS.log("rnsh listening for commands on "+RNS.prettyhexrep(destination.hash)) log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
if not disable_announce: if not disable_announce:
destination.announce() _destination.announce()
last = datetime.datetime.now() last = time.monotonic()
while True:
if not disable_announce and (datetime.datetime.now() - last).total_seconds() > 900: # TODO: make parameter
last = datetime.datetime.now()
destination.announce()
time.sleep(0.005) try:
for proc in processes_get(): while True:
if not disable_announce and time.monotonic() - last > 900: # TODO: make parameter
last = datetime.datetime.now()
_destination.announce()
try: try:
if proc.link.status == RNS.Link.CLOSED: await asyncio.wait_for(_finished, timeout=1.0)
RNS.log("Link closed, terminating") except TimeoutError:
proc.terminate() pass
proc.poll(should_notify=True) except KeyboardInterrupt:
log.warning("Shutting down")
for link in list(_destination.links):
try:
if link.process is not None and link.process.process.running:
link.process.process.terminate()
except: except:
RNS.log("Error polling process for link " + proc.link.link_id, RNS.LOG_ERROR) pass
await asyncio.sleep(1)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status != RNS.Link.CLOSED:
link.teardown()
if proc.link.status == RNS.Link.CLOSED: class ProcessState:
processes_remove(proc) def __init__(self,
cmd: str,
mdu: int,
data_available_callback: callable,
terminated_callback: callable,
term: str | None,
loop: asyncio.AbstractEventLoop = None):
self._log = _getLogger(self.__class__.__name__)
self._mdu = mdu
self._loop = loop if loop is not None else asyncio.get_running_loop()
self._process = process.CallbackSubprocess(argv=shlex.split(cmd),
term=term,
loop=asyncio.get_running_loop(),
stdout_callback=self._stdout_data,
terminated_callback=terminated_callback)
self._data_buffer = bytearray()
self._lock = threading.RLock()
self._data_available_cb = data_available_callback
self._terminated_cb = terminated_callback
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] | None = None
@property
def mdu(self) -> int:
return self._mdu
@mdu.setter
def mdu(self, val: int):
self._mdu = val
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
return self._pending_receipt
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
with self._lock:
val = self._pending_receipt
self._pending_receipt = None
return val
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
with self._lock:
self._pending_receipt = receipt
@property
def process(self) -> process.CallbackSubprocess:
return self._process
@property
def return_code(self) -> int | None:
return self.process.return_code
@property
def lock(self) -> threading.RLock:
return self._lock
def read(self, count: int) -> bytes:
with self.lock:
take = self._data_buffer[:count-1]
self._data_buffer = self._data_buffer[count:]
return take
def _stdout_data(self, data: bytes):
total_available = 0
with self.lock:
self._data_buffer.extend(data)
total_available = len(self._data_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"Error calling ProcessState data_available_callback {e}")
def _update_winsz(self):
self.process.set_winsize(self._term_state[3],
self._term_state[4],
self._term_state[5],
self._term_state[6])
REQUEST_IDX_STDIN = 0
REQUEST_IDX_TERM = 1
REQUEST_IDX_TIOS = 2
REQUEST_IDX_ROWS = 3
REQUEST_IDX_COLS = 4
REQUEST_IDX_HPIX = 5
REQUEST_IDX_VPIX = 6
def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX]
response = ProcessState.default_response()
def command_link_start_process(link, identity, term) -> ProcessState: response[ProcessState.RESPONSE_IDX_RUNNING] = not self.process.running
try: if self.process.running:
process = ProcessState(cmd, link, identity, term) if term_state != self._term_state:
processes_add(process) self._term_state = term_state
return process self._update_winsz()
except Exception as e:
RNS.log("Failed to launch process: " + str(e), RNS.LOG_ERROR)
link.teardown()
def command_link_established(link):
global allow_all
link.set_remote_identified_callback(initiator_identified)
link.set_link_closed_callback(command_link_closed)
RNS.log("Shell link "+str(link)+" established")
if allow_all:
command_link_start_process(link, None)
def command_link_closed(link):
RNS.log("Shell link "+str(link)+" closed")
matches = list(filter(lambda p: p.link == link, processes_get()))
if len(matches) == 0:
return
proc = matches[0]
try:
proc.terminate()
except:
RNS.log("Error closing process for link " + RNS.prettyhexrep(link.link_id), RNS.LOG_ERROR)
finally:
processes_remove(proc)
def initiator_identified(link, identity):
global allow_all, cmd
RNS.log("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
if not allow_all and not identity.hash in allowed_identity_hashes:
RNS.log("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
def execute_received_command(path, data, request_id, remote_identity, requested_at):
RNS.log("execute_received_command", RNS.LOG_DEBUG)
process = None
for proc in processes_get():
RNS.log("checking a proc", RNS.LOG_DEBUG)
if proc.request_id == request_id:
process = proc
RNS.log("execute_received_command matched request", RNS.LOG_DEBUG)
stdin = data[0] # Data passed to stdin
if process is None:
link = next(filter(lambda l: hasattr(l, "last_request_id") and l.last_request_id == request_id, RNS.Transport.active_links))
if link is not None:
process = command_link_start_process(link, identity, base64.b64decode(stdin).decode("utf-8") if stdin is not None else "")
time.sleep(0.1)
# if remote_identity != None:
# RNS.log("Executing command ["+command+"] for "+RNS.prettyhexrep(remote_identity.hash))
# else:
# RNS.log("Executing command ["+command+"] for unknown requestor")
result = [
False, # 0: Command was executed
None, # 1: Return value
None, # 2: Stdout
None, # 3: Stderr
datetime.datetime.now(), # 4: Timestamp
]
try:
if process is not None:
result[0] = not process.is_finished()
if stdin is not None and len(stdin) > 0: if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin) stdin = base64.b64decode(stdin)
process.write(stdin) self.process.write(stdin)
return_code, stdout, stderr = process.read() response[ProcessState.RESPONSE_IDX_RETCODE] = self.return_code
result[1] = return_code stdout = self.read(read_size)
result[2] = base64.b64encode(stdout).decode("utf-8") if stdout is not None else None with self.lock:
result[3] = base64.b64encode(stderr).decode("utf-8") if stderr is not None else None response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
response[ProcessState.RESPONSE_IDX_STDOUT] = \
base64.b64encode(stdout).decode("utf-8") if stdout is not None and len(stdout) > 0 else None
return response
RESPONSE_IDX_RUNNING = 0
RESPONSE_IDX_RETCODE = 1
RESPONSE_IDX_RDYBYTE = 2
RESPONSE_IDX_STDOUT = 3
RESPONSE_IDX_TMSTAMP = 4
@staticmethod
def default_response() -> [any]:
return [
False, # 0: Process running
None, # 1: Return value
0, # 2: Number of outstanding bytes
None, # 3: Stdout/Stderr
time.time(), # 4: Timestamp
]
def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer
log = _getLogger("_subproc_data_ready")
process_state: ProcessState = link.process
def send(timeout: bool, id: any, tries: int) -> any:
try:
pr = process_state.pending_receipt_take()
if pr is not None and pr.get_status() != RNS.PacketReceipt.SENT and pr.get_status() != RNS.PacketReceipt.DELIVERED:
if not timeout:
_retry_timer.complete(id)
log.debug(f"Packet {id} completed with status {pr.status} on link {link}")
return link.link_id
if not timeout:
log.info(f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
process_state.pending_receipt_put(pr)
return link.link_id
else:
log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id)
link.teardown()
except Exception as e:
log.error("Error notifying client: " + str(e))
return link.link_id
with process_state.lock:
if process_state.pending_receipt_peek() is None:
_retry_timer.begin(try_limit=15,
wait_delay=link.rtt * 3 if link.rtt is not None else 1,
try_callback=functools.partial(send, False),
timeout_callback=functools.partial(send, True),
id=None)
else:
log.debug(f"Notification already pending for link {link}")
def _subproc_terminated(link: RNS.Link, return_code: int):
log = _getLogger("_subproc_terminated")
log.info(f"Subprocess terminated ({return_code} for link {link}")
link.teardown()
def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None:
global _cmd
log = _getLogger("_listen_start_proc")
try:
link.process = ProcessState(cmd=_cmd,
term=term,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
return link.process
except Exception as e: except Exception as e:
result[0] = False log.error("Failed to launch process: " + str(e))
if process is not None: link.teardown()
process.terminate() return None
process.link.teardown() def _listen_link_established(link):
global _allow_all
log = _getLogger("_listen_link_established")
link.set_remote_identified_callback(_initiator_identified)
link.set_link_closed_callback(_listen_link_closed)
log.info("Link "+str(link)+" established")
def _listen_link_closed(link: RNS.Link):
log = _getLogger("_listen_link_closed")
# async def cleanup():
log.info("Link "+str(link)+" closed")
proc: ProcessState | None = link.proc if hasattr(link, "process") else None
if proc is None:
log.warning(f"No process for link {link}")
try:
proc.process.terminate()
except:
log.error(f"Error closing process for link {link}")
# asyncio.get_running_loop().call_soon(cleanup)
def _initiator_identified(link, identity):
global _allow_all, _cmd
log = _getLogger("_initiator_identified")
log.info("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
if not _allow_all and not identity.hash in _allowed_identity_hashes:
log.warning("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer
log = _getLogger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links))
if link is None:
log.error(f"invalid request {request_id}, no link found with id {link_id}")
return
process_state: ProcessState | None = None
try:
term = data[1]
process_state = link.process if hasattr(link, "process") else None
if process_state is None:
log.debug(f"process not found for link {link}")
process_state = _listen_start_proc(link, term)
# leave significant overhead for metadata and encoding
result = process_state.process_request(data, link.MDU * 3 // 2)
except Exception as e:
result = ProcessState.default_response()
try:
if process_state is not None:
process_state.process.terminate()
link.teardown()
except Exception as e:
log.error(f"Error terminating process for link {link}")
return result return result
@ -552,7 +490,7 @@ def client_packet_handler(message, packet):
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG: if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG:
new_data = True new_data = True
def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT): def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT):
global identity, reticulum, link, listener_destination, remote_exec_grace global _identity, _reticulum, link, listener_destination, remote_exec_grace
try: try:
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
@ -566,12 +504,12 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
print(str(e)) print(str(e))
return 241 return 241
if reticulum == None: if _reticulum == None:
targetloglevel = 3+verbosity-quietness targetloglevel = 3+verbosity-quietness
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel) _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
if identity == None: if _identity == None:
prepare_identity(identitypath) _prepare_identity(identitypath)
if not RNS.Transport.has_path(destination_hash): if not RNS.Transport.has_path(destination_hash):
RNS.Transport.request_path(destination_hash) RNS.Transport.request_path(destination_hash)
@ -598,7 +536,7 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
return 243 return 243
if not noid and not link.did_identify: if not noid and not link.did_identify:
link.identify(identity) link.identify(_identity)
link.did_identify = True link.did_identify = True
link.set_packet_callback(client_packet_handler) link.set_packet_callback(client_packet_handler)
@ -718,7 +656,7 @@ def main():
if args.listen or args.print_identity: if args.listen or args.print_identity:
RNS.log("command " + args.command) RNS.log("command " + args.command)
listen( _listen(
configdir=args.config, configdir=args.config,
command=args.command, command=args.command,
identitypath=args.identity, identitypath=args.identity,

60
rnsh/rnslogging.py Normal file
View File

@ -0,0 +1,60 @@
import logging
from logging import Handler, getLevelName
from types import GenericAlias
import os
import RNS
class RnsHandler(Handler):
"""
A handler class which writes logging records, appropriately formatted,
to the RNS logger.
"""
def __init__(self):
"""
Initialize the handler.
"""
Handler.__init__(self)
@staticmethod
def get_rns_loglevel(loglevel: int) -> int:
if loglevel == logging.CRITICAL:
return RNS.LOG_CRITICAL
if loglevel == logging.ERROR:
return RNS.LOG_ERROR
if loglevel == logging.WARNING:
return RNS.LOG_WARNING
if loglevel == logging.INFO:
return RNS.LOG_INFO
if loglevel == logging.DEBUG:
return RNS.LOG_DEBUG
return RNS.LOG_DEBUG
def emit(self, record):
"""
Emit a record.
"""
try:
msg = self.format(record)
RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno))
except RecursionError: # See issue 36272
raise
except Exception:
self.handleError(record)
def __repr__(self):
level = getLevelName(self.level)
return '<%s (%s)>' % (self.__class__.__name__, level)
__class_getitem__ = classmethod(GenericAlias)
log_format = '%(name)-40s %(message)s [%(threadName)s]'
logging.basicConfig(
level=logging.INFO,
#format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s',
format=log_format,
datefmt='%Y-%m-%d %H:%M:%S',
handlers=[RnsHandler()])