Locking in a major increment for integrating the CallbackSubprocess module to rnsh

This commit is contained in:
Aaron Heise 2023-02-09 19:58:46 -06:00
parent 9186a64962
commit fcc73ba31a
4 changed files with 512 additions and 401 deletions

View File

@ -9,13 +9,10 @@ import pty
import os
import asyncio
import sys
import logging as __logging
import fcntl
import select
import termios
import logging as __logging
module_logger = __logging.getLogger(__name__)
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
@ -165,8 +162,8 @@ class CallbackSubprocess:
assert stdout_callback is not None, "stdout_callback should not be None"
assert terminated_callback is not None, "terminated_callback should not be None"
self.log = module_logger.getChild(self.__class__.__name__)
self.log.debug(f"__init__({argv},{term},...")
self._log = module_logger.getChild(self.__class__.__name__)
self._log.debug(f"__init__({argv},{term},...")
self._command = argv
self._term = term
self._loop = loop
@ -181,7 +178,7 @@ class CallbackSubprocess:
Terminate child process if running
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
"""
self.log.debug("terminate()")
self._log.debug("terminate()")
if not self.running:
return
@ -191,7 +188,7 @@ class CallbackSubprocess:
pass
def kill():
self.log.debug("kill()")
self._log.debug("kill()")
try:
os.kill(self._pid, signal.SIGHUP)
os.kill(self._pid, signal.SIGKILL)
@ -201,9 +198,9 @@ class CallbackSubprocess:
self._loop.call_later(kill_delay, kill)
def wait():
self.log.debug("wait()")
self._log.debug("wait()")
os.waitpid(self._pid, 0)
self.log.debug("wait() finish")
self._log.debug("wait() finish")
threading.Thread(target=wait).start()
@ -226,7 +223,7 @@ class CallbackSubprocess:
Write bytes to the stdin of the child process.
:param data: bytes to write
"""
self.log.debug(f"write({data})")
self._log.debug(f"write({data})")
os.write(self._child_fd, data)
def set_winsize(self, r: int, c: int, h: int, v: int):
@ -238,7 +235,7 @@ class CallbackSubprocess:
:param v: vertical pixels visible
:return:
"""
self.log.debug(f"set_winsize({r},{c},{h},{v}")
self._log.debug(f"set_winsize({r},{c},{h},{v}")
tty_set_winsize(self._child_fd, r, c, h, v)
def copy_winsize(self, fromfd:int):
@ -268,7 +265,7 @@ class CallbackSubprocess:
"""
Start the child process.
"""
self.log.debug("start()")
self._log.debug("start()")
parentenv = os.environ.copy()
env = {"HOME": parentenv["HOME"],
"TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
@ -290,11 +287,11 @@ class CallbackSubprocess:
try:
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
if self._return_code is not None and not process_exists(self._pid):
self.log.debug(f"polled return code {self._return_code}")
self._log.debug(f"polled return code {self._return_code}")
self._terminated_cb(self._return_code)
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
except Exception as e:
self.log.debug(f"Error in process poll: {e}")
self._log.debug(f"Error in process poll: {e}")
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def reader(fd: int, callback: callable):

116
rnsh/retry.py Normal file
View File

@ -0,0 +1,116 @@
import asyncio
import threading
import time
import logging as __logging
module_logger = __logging.getLogger(__name__)
class RetryStatus:
def __init__(self, id: any, try_limit: int, wait_delay: float, retry_callback: callable[any, int], timeout_callback: callable[any], tries: int = 1):
self.id = id
self.try_limit = try_limit
self.tries = tries
self.wait_delay = wait_delay
self.retry_callback = retry_callback
self.timeout_callback = timeout_callback
self.try_time = time.monotonic()
self.completed = False
@property
def ready(self):
return self.try_time + self.wait_delay < time.monotonic() and not self.completed
@property
def timed_out(self):
return self.ready and self.tries >= self.try_limit
def timeout(self):
self.completed = True
self.timeout_callback(self.id)
def retry(self):
self.tries += 1
self.retry_callback(self.id, self.tries)
class RetryThread:
def __init__(self, loop_period: float = 0.25):
self._log = module_logger.getChild(self.__class__.__name__)
self._loop_period = loop_period
self._statuses: list[RetryStatus] = []
self._id_counter = 0
self._lock = threading.RLock()
self._thread = threading.Thread(target=self._thread_run())
self._run = True
self._finished: asyncio.Future | None = None
self._thread.start()
def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None:
self._log.debug("stopping timer thread")
if loop is None:
self._run = False
self._thread.join()
return None
else:
self._finished = loop.create_future()
return self._finished
def _thread_run(self):
last_run = time.monotonic()
while self._run and self._finished is None:
time.sleep(last_run + self._loop_period - time.monotonic())
last_run = time.monotonic()
ready: list[RetryStatus] = []
prune: list[RetryStatus] = []
with self._lock:
ready.extend(filter(lambda s: s.ready, self._statuses))
for retry in ready:
try:
if not retry.completed:
if retry.timed_out:
self._log.debug(f"timed out {retry.id} after {retry.try_limit} tries")
retry.timeout()
prune.append(retry)
else:
self._log.debug(f"retrying {retry.id}, try {retry.tries + 1}/{retry.try_limit}")
retry.retry()
except Exception as e:
self._log.error(f"error processing retry id {retry.id}: {e}")
prune.append(retry)
with self._lock:
for retry in prune:
self._log.debug(f"pruned retry {retry.id}, retry count {retry.tries}/{retry.try_limit}")
self._statuses.remove(retry)
if self._finished is not None:
self._finished.set_result(None)
def _get_id(self):
self._id_counter += 1
return self._id_counter
def begin(self, try_limit: int, wait_delay: float, try_callback: callable[[any | None, int], any], timeout_callback: callable[any, int], id: int | None = None) -> any:
self._log.debug(f"running first try")
id = try_callback(id, 1)
self._log.debug(f"first try success, got id {id}")
with self._lock:
if id is None:
id = self._get_id()
self._statuses.append(RetryStatus(id=id,
tries=1,
try_limit=try_limit,
wait_delay=wait_delay,
retry_callback=try_callback,
timeout_callback=timeout_callback))
self._log.debug(f"added retry timer for {id}")
def complete(self, id: any):
assert id is not None
with self._lock:
status = next(filter(lambda l: l.id == id, self._statuses))
assert status is not None
status.completed = True
self._statuses.remove(status)
self._log.debug(f"completed {id}")
def complete_all(self):
with self._lock:
for status in self._statuses:
status.completed = True
self._log.debug(f"completed {status.id}")
self._statuses.clear()

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
import pty
import threading
import functools
# MIT License
#
@ -24,309 +23,82 @@ import threading
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import rnslogging
import RNS
import subprocess
import argparse
import shlex
import time
import sys
import tty
import os
import datetime
import select
import base64
import fcntl
import termios
import queue
import signal
import errno
import RNS.vendor.umsgpack as umsgpack
import process
import asyncio
import threading
import signal
import retry
import logging as __logging
module_logger = __logging.getLogger(__name__)
def _getLogger(name: str):
global module_logger
return module_logger.getChild(name)
from RNS._version import __version__
APP_NAME = "rnsh"
identity = None
reticulum = None
allow_all = False
allowed_identity_hashes = []
cmd = None
processes = []
processes_lock = threading.Lock()
_identity = None
_reticulum = None
_allow_all = False
_allowed_identity_hashes = []
_cmd: str | None = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Future | None = None
_retry_timer = retry.RetryThread()
_destination: RNS.Destination | None = None
def _handle_sigint_with_async_int():
if _finished is not None:
_finished.set_exception(KeyboardInterrupt())
else:
raise KeyboardInterrupt()
def fd_set_non_blocking(fd):
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
if fd.isatty():
tty.setraw(fd)
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
def fd_non_blocking_read(fd):
# from https://stackoverflow.com/questions/26263636/how-to-check-potentially-empty-stdin-without-waiting-for-input
# TODO: Windows is probably different
signal.signal(signal.SIGINT, _handle_sigint_with_async_int)
#return fd.read()
try:
old_settings = None
# try:
# old_settings = termios.tcgetattr(fd)
# except:
# pass
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
try:
# try:
# tty.setraw(fd)
# except:
# pass
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
return os.read(fd.fileno(), 1024)
except OSError as ose:
if ose.errno != 35:
raise ose
except Exception as e:
RNS.log(f"Raw read error {e}")
finally:
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags)
# if old_settings is not None:
# termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
except:
pass
class NonBlockingStreamReader:
def __init__(self, stream, callback = None):
'''
stream: the stream to read from.
Usually a process' stdout or stderr.
'''
self._s = stream
self._q = queue.Queue()
self._callback = callback
self._stop_time = None
def _populateQueue(stream, queue):
'''
Collect lines from 'stream' and put them in 'quque'.
'''
# fd_set_non_blocking(stream)
run = True
while run and not (self._stop_time is not None and (datetime.datetime.now() - self._stop_time).total_seconds() > 0.05):
# stream.flush()
# line = stream.read(1) #fd_non_blocking_read(stream)
timeout = 0.01
ready, _, _ = select.select([stream], [], [], timeout)
for fd in ready:
try:
data = os.read(fd, 512)
except OSError as e:
if e.errno != errno.EIO:
raise
# EIO means EOF on some systems
run = False
else:
if not data: # EOF
run = False
if data is not None and len(data) > 0:
if self._callback is not None:
self._callback(data)
else:
queue.put(data)
RNS.log("NonBlockingStreamReader exiting", RNS.LOG_DEBUG)
os.close(stream)
self._t = threading.Thread(target = _populateQueue,
args = (self._s, self._q))
self._t.daemon = True
self._t.start() #start collecting lines from the stream
def read(self, timeout = None):
try:
result = self._q.get_nowait() if timeout is None else self._q.get(block = timeout is not None,
timeout = timeout)
return result
except TimeoutError:
return None
def is_open(self):
return self._t.is_alive()
def stop(self):
if self._stop_time is None:
self._stop_time = datetime.datetime.now()
class UnexpectedEndOfStream(Exception): pass
def processes_get():
processes_lock.acquire()
try:
return processes.copy()
finally:
processes_lock.release()
def processes_add(process):
processes_lock.acquire()
try:
processes.append(process)
finally:
processes_lock.release()
def processes_remove(process):
if process.link.status == RNS.Link.ACTIVE:
return
processes_lock.acquire()
try:
if next(filter(lambda p: p == process, processes)) is not None:
processes.remove(process)
finally:
processes_lock.release()
#### Link Overrides ####
_link_handle_request_orig = RNS.Link.handle_request
def link_handle_request(self, request_id, unpacked_request):
for process in processes_get():
if process.link.link_id == self.link_id:
RNS.log("Associating packet to link", RNS.LOG_DEBUG)
process.request_id = request_id
self.last_request_id = request_id
_link_handle_request_orig(self, request_id, unpacked_request)
RNS.Link.handle_request = link_handle_request
class ProcessState:
def __init__(self, command, link, remote_identity, term):
self.lock = threading.RLock()
self.link = link
self.remote_identity = remote_identity
self.term = term
self.command = command
self._stderrbuf = bytearray()
self._stdoutbuf = bytearray()
RNS.log("Launching " + self.command) # + " for client " + (RNS.prettyhexrep(self.remote_identity) if self.remote_identity else "unknown"), RNS.LOG_DEBUG)
env = os.environ.copy()
# env["PYTHONUNBUFFERED"] = "1"
# env["PS1"] ="\\u:\\h "
env["TERM"] = self.term
self.mo, so = pty.openpty()
self.me, se = pty.openpty()
self.mi, si = pty.openpty()
self.process = subprocess.Popen(shlex.split(self.command), bufsize=512, stdin=si, stdout=so, stderr=se, preexec_fn=os.setsid, shell=False, env=env)
for fd in [so, se, si]:
os.close(fd)
# tty.setcbreak(self.mo)
self.stdout_reader = NonBlockingStreamReader(self.mo, self._stdout_cb)
self.stderr_reader = NonBlockingStreamReader(self.me, self._stderr_cb)
self.last_update = datetime.datetime.now()
self.request_id = None
self.notify_tried = 0
self.return_code = None
def _fd_callback(self, fdbuf, data):
with self.lock:
fdbuf.extend(data)
def _stdout_cb(self, data):
self._fd_callback(self._stdoutbuf, data)
def _stderr_cb(self, data):
self._fd_callback(self._stderrbuf, data)
def notify_client_data_available(self, chars_available):
if (datetime.datetime.now() - self.last_update).total_seconds() < 1:
return
self.last_update = datetime.datetime.now()
if self.notify_tried > 15:
processes_remove(self)
RNS.log(f"Try count exceeded, terminating connection", RNS.LOG_ERROR)
self.link.teardown()
return
try:
RNS.log(f"Notifying client; try {self.notify_tried} retcode: {self.return_code} chars avail: {chars_available}")
RNS.Packet(self.link, DATA_AVAIL_MSG.encode("utf-8")).send()
self.notify_tried += 1
except Exception as e:
RNS.log("Error notifying client: " + str(e), RNS.LOG_ERROR)
def poll(self, should_notify):
self.return_code, chars_available = self.process.poll(), len(self._stdoutbuf) + len(self._stderrbuf)
if should_notify and self.return_code is not None or chars_available > 0:
self.notify_client_data_available(chars_available)
if self.return_code is not None:
self.stdout_reader.stop()
self.stderr_reader.stop()
return self.return_code, chars_available
def is_finished(self):
with self.lock:
return self.return_code is not None and not self.stdout_reader.is_open() # and not self.stderr_reader.is_open()
def read(self): #TODO: limit take sizes?
with self.lock:
self.notify_tried = 0
self.last_update = datetime.datetime.now()
stdout = self._stdoutbuf
self._stdoutbuf = bytearray()
stderr = self._stderrbuf.copy()
self._stderrbuf = bytearray()
self.return_code = self.process.poll()
if self.return_code is not None and len(stdout) == 0 and len(stderr) == 0:
self.final_checkin = True
return self.process.poll(), stdout, stderr
def write(self, bytes):
os.write(self.mi, bytes)
os.fsync(self.mi)
def terminate(self):
chars_available = 0
with self.lock:
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
for fd in [self.mo, self.me, self.mi]:
os.close(fd)
self.process.terminate()
self.process.wait()
if self.process.poll() is not None:
stdout, stderr = self.process.communicate()
self._stdoutbuf += stdout
self._stderrbuf += stderr
return len(self._stdoutbuf) + len(self._stderrbuf)
def prepare_identity(identity_path):
global identity
def _prepare_identity(identity_path):
global _identity
log = _getLogger("_prepare_identity")
if identity_path == None:
identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME
if os.path.isfile(identity_path):
identity = RNS.Identity.from_file(identity_path)
_identity = RNS.Identity.from_file(identity_path)
if identity == None:
RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO)
identity = RNS.Identity()
identity.to_file(identity_path)
if _identity == None:
log.info("No valid saved identity found, creating new...")
_identity = RNS.Identity()
_identity.to_file(identity_path)
def listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0, allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
global identity, allow_all, allowed_identity_hashes, reticulum, cmd
cmd = command
async def _listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0,
allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
log = _getLogger("_listen")
_cmd = command
targetloglevel = 3+verbosity-quietness
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
prepare_identity(identitypath)
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
_prepare_identity(identitypath)
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
if print_identity:
print("Identity : "+str(identity))
print("Listening on : "+RNS.prettyhexrep(destination.hash))
log.info("Identity : " + str(_identity))
log.info("Listening on : " + RNS.prettyhexrep(_destination.hash))
exit(0)
if disable_auth:
allow_all = True
_allow_all = True
else:
if allowed != None:
for a in allowed:
@ -336,140 +108,306 @@ def listen(configdir, command, identitypath = None, service_name ="default", ver
raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
try:
destination_hash = bytes.fromhex(a)
allowed_identity_hashes.append(destination_hash)
_allowed_identity_hashes.append(destination_hash)
except Exception as e:
raise ValueError("Invalid destination entered. Check your input.")
except Exception as e:
print(str(e))
log.error(str(e))
exit(1)
if len(allowed_identity_hashes) < 1 and not disable_auth:
print("Warning: No allowed identities configured, rncx will not accept any commands!")
if len(_allowed_identity_hashes) < 1 and not disable_auth:
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
destination.set_link_established_callback(command_link_established)
_destination.set_link_established_callback(_listen_link_established)
if not allow_all:
destination.register_request_handler(
if not _allow_all:
_destination.register_request_handler(
path = service_name,
response_generator = execute_received_command,
response_generator = _listen_request,
allow = RNS.Destination.ALLOW_LIST,
allowed_list = allowed_identity_hashes
allowed_list = _allowed_identity_hashes
)
else:
destination.register_request_handler(
_destination.register_request_handler(
path = service_name,
response_generator = execute_received_command,
response_generator = _listen_request,
allow = RNS.Destination.ALLOW_ALL,
)
RNS.log("rnsh listening for commands on "+RNS.prettyhexrep(destination.hash))
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
if not disable_announce:
destination.announce()
_destination.announce()
last = datetime.datetime.now()
last = time.monotonic()
while True:
if not disable_announce and (datetime.datetime.now() - last).total_seconds() > 900: # TODO: make parameter
last = datetime.datetime.now()
destination.announce()
time.sleep(0.005)
for proc in processes_get():
try:
while True:
if not disable_announce and time.monotonic() - last > 900: # TODO: make parameter
last = datetime.datetime.now()
_destination.announce()
try:
if proc.link.status == RNS.Link.CLOSED:
RNS.log("Link closed, terminating")
proc.terminate()
proc.poll(should_notify=True)
await asyncio.wait_for(_finished, timeout=1.0)
except TimeoutError:
pass
except KeyboardInterrupt:
log.warning("Shutting down")
for link in list(_destination.links):
try:
if link.process is not None and link.process.process.running:
link.process.process.terminate()
except:
RNS.log("Error polling process for link " + proc.link.link_id, RNS.LOG_ERROR)
pass
await asyncio.sleep(1)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status != RNS.Link.CLOSED:
link.teardown()
if proc.link.status == RNS.Link.CLOSED:
processes_remove(proc)
class ProcessState:
def __init__(self,
cmd: str,
mdu: int,
data_available_callback: callable,
terminated_callback: callable,
term: str | None,
loop: asyncio.AbstractEventLoop = None):
self._log = _getLogger(self.__class__.__name__)
self._mdu = mdu
self._loop = loop if loop is not None else asyncio.get_running_loop()
self._process = process.CallbackSubprocess(argv=shlex.split(cmd),
term=term,
loop=asyncio.get_running_loop(),
stdout_callback=self._stdout_data,
terminated_callback=terminated_callback)
self._data_buffer = bytearray()
self._lock = threading.RLock()
self._data_available_cb = data_available_callback
self._terminated_cb = terminated_callback
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] | None = None
@property
def mdu(self) -> int:
return self._mdu
@mdu.setter
def mdu(self, val: int):
self._mdu = val
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
return self._pending_receipt
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
with self._lock:
val = self._pending_receipt
self._pending_receipt = None
return val
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
with self._lock:
self._pending_receipt = receipt
@property
def process(self) -> process.CallbackSubprocess:
return self._process
@property
def return_code(self) -> int | None:
return self.process.return_code
@property
def lock(self) -> threading.RLock:
return self._lock
def read(self, count: int) -> bytes:
with self.lock:
take = self._data_buffer[:count-1]
self._data_buffer = self._data_buffer[count:]
return take
def _stdout_data(self, data: bytes):
total_available = 0
with self.lock:
self._data_buffer.extend(data)
total_available = len(self._data_buffer)
try:
self._data_available_cb(total_available)
except Exception as e:
self._log.error(f"Error calling ProcessState data_available_callback {e}")
def _update_winsz(self):
self.process.set_winsize(self._term_state[3],
self._term_state[4],
self._term_state[5],
self._term_state[6])
REQUEST_IDX_STDIN = 0
REQUEST_IDX_TERM = 1
REQUEST_IDX_TIOS = 2
REQUEST_IDX_ROWS = 3
REQUEST_IDX_COLS = 4
REQUEST_IDX_HPIX = 5
REQUEST_IDX_VPIX = 6
def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX]
response = ProcessState.default_response()
def command_link_start_process(link, identity, term) -> ProcessState:
try:
process = ProcessState(cmd, link, identity, term)
processes_add(process)
return process
except Exception as e:
RNS.log("Failed to launch process: " + str(e), RNS.LOG_ERROR)
link.teardown()
def command_link_established(link):
global allow_all
link.set_remote_identified_callback(initiator_identified)
link.set_link_closed_callback(command_link_closed)
RNS.log("Shell link "+str(link)+" established")
if allow_all:
command_link_start_process(link, None)
def command_link_closed(link):
RNS.log("Shell link "+str(link)+" closed")
matches = list(filter(lambda p: p.link == link, processes_get()))
if len(matches) == 0:
return
proc = matches[0]
try:
proc.terminate()
except:
RNS.log("Error closing process for link " + RNS.prettyhexrep(link.link_id), RNS.LOG_ERROR)
finally:
processes_remove(proc)
def initiator_identified(link, identity):
global allow_all, cmd
RNS.log("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
if not allow_all and not identity.hash in allowed_identity_hashes:
RNS.log("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
def execute_received_command(path, data, request_id, remote_identity, requested_at):
RNS.log("execute_received_command", RNS.LOG_DEBUG)
process = None
for proc in processes_get():
RNS.log("checking a proc", RNS.LOG_DEBUG)
if proc.request_id == request_id:
process = proc
RNS.log("execute_received_command matched request", RNS.LOG_DEBUG)
stdin = data[0] # Data passed to stdin
if process is None:
link = next(filter(lambda l: hasattr(l, "last_request_id") and l.last_request_id == request_id, RNS.Transport.active_links))
if link is not None:
process = command_link_start_process(link, identity, base64.b64decode(stdin).decode("utf-8") if stdin is not None else "")
time.sleep(0.1)
# if remote_identity != None:
# RNS.log("Executing command ["+command+"] for "+RNS.prettyhexrep(remote_identity.hash))
# else:
# RNS.log("Executing command ["+command+"] for unknown requestor")
result = [
False, # 0: Command was executed
None, # 1: Return value
None, # 2: Stdout
None, # 3: Stderr
datetime.datetime.now(), # 4: Timestamp
]
try:
if process is not None:
result[0] = not process.is_finished()
response[ProcessState.RESPONSE_IDX_RUNNING] = not self.process.running
if self.process.running:
if term_state != self._term_state:
self._term_state = term_state
self._update_winsz()
if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin)
process.write(stdin)
return_code, stdout, stderr = process.read()
result[1] = return_code
result[2] = base64.b64encode(stdout).decode("utf-8") if stdout is not None else None
result[3] = base64.b64encode(stderr).decode("utf-8") if stderr is not None else None
self.process.write(stdin)
response[ProcessState.RESPONSE_IDX_RETCODE] = self.return_code
stdout = self.read(read_size)
with self.lock:
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
response[ProcessState.RESPONSE_IDX_STDOUT] = \
base64.b64encode(stdout).decode("utf-8") if stdout is not None and len(stdout) > 0 else None
return response
RESPONSE_IDX_RUNNING = 0
RESPONSE_IDX_RETCODE = 1
RESPONSE_IDX_RDYBYTE = 2
RESPONSE_IDX_STDOUT = 3
RESPONSE_IDX_TMSTAMP = 4
@staticmethod
def default_response() -> [any]:
return [
False, # 0: Process running
None, # 1: Return value
0, # 2: Number of outstanding bytes
None, # 3: Stdout/Stderr
time.time(), # 4: Timestamp
]
def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer
log = _getLogger("_subproc_data_ready")
process_state: ProcessState = link.process
def send(timeout: bool, id: any, tries: int) -> any:
try:
pr = process_state.pending_receipt_take()
if pr is not None and pr.get_status() != RNS.PacketReceipt.SENT and pr.get_status() != RNS.PacketReceipt.DELIVERED:
if not timeout:
_retry_timer.complete(id)
log.debug(f"Packet {id} completed with status {pr.status} on link {link}")
return link.link_id
if not timeout:
log.info(f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
process_state.pending_receipt_put(pr)
return link.link_id
else:
log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id)
link.teardown()
except Exception as e:
log.error("Error notifying client: " + str(e))
return link.link_id
with process_state.lock:
if process_state.pending_receipt_peek() is None:
_retry_timer.begin(try_limit=15,
wait_delay=link.rtt * 3 if link.rtt is not None else 1,
try_callback=functools.partial(send, False),
timeout_callback=functools.partial(send, True),
id=None)
else:
log.debug(f"Notification already pending for link {link}")
def _subproc_terminated(link: RNS.Link, return_code: int):
log = _getLogger("_subproc_terminated")
log.info(f"Subprocess terminated ({return_code} for link {link}")
link.teardown()
def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None:
global _cmd
log = _getLogger("_listen_start_proc")
try:
link.process = ProcessState(cmd=_cmd,
term=term,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
return link.process
except Exception as e:
result[0] = False
if process is not None:
process.terminate()
process.link.teardown()
log.error("Failed to launch process: " + str(e))
link.teardown()
return None
def _listen_link_established(link):
global _allow_all
log = _getLogger("_listen_link_established")
link.set_remote_identified_callback(_initiator_identified)
link.set_link_closed_callback(_listen_link_closed)
log.info("Link "+str(link)+" established")
def _listen_link_closed(link: RNS.Link):
log = _getLogger("_listen_link_closed")
# async def cleanup():
log.info("Link "+str(link)+" closed")
proc: ProcessState | None = link.proc if hasattr(link, "process") else None
if proc is None:
log.warning(f"No process for link {link}")
try:
proc.process.terminate()
except:
log.error(f"Error closing process for link {link}")
# asyncio.get_running_loop().call_soon(cleanup)
def _initiator_identified(link, identity):
global _allow_all, _cmd
log = _getLogger("_initiator_identified")
log.info("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
if not _allow_all and not identity.hash in _allowed_identity_hashes:
log.warning("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
link.teardown()
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
global _destination, _retry_timer
log = _getLogger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links))
if link is None:
log.error(f"invalid request {request_id}, no link found with id {link_id}")
return
process_state: ProcessState | None = None
try:
term = data[1]
process_state = link.process if hasattr(link, "process") else None
if process_state is None:
log.debug(f"process not found for link {link}")
process_state = _listen_start_proc(link, term)
# leave significant overhead for metadata and encoding
result = process_state.process_request(data, link.MDU * 3 // 2)
except Exception as e:
result = ProcessState.default_response()
try:
if process_state is not None:
process_state.process.terminate()
link.teardown()
except Exception as e:
log.error(f"Error terminating process for link {link}")
return result
@ -552,7 +490,7 @@ def client_packet_handler(message, packet):
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG:
new_data = True
def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT):
global identity, reticulum, link, listener_destination, remote_exec_grace
global _identity, _reticulum, link, listener_destination, remote_exec_grace
try:
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
@ -566,12 +504,12 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
print(str(e))
return 241
if reticulum == None:
if _reticulum == None:
targetloglevel = 3+verbosity-quietness
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
if identity == None:
prepare_identity(identitypath)
if _identity == None:
_prepare_identity(identitypath)
if not RNS.Transport.has_path(destination_hash):
RNS.Transport.request_path(destination_hash)
@ -598,7 +536,7 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
return 243
if not noid and not link.did_identify:
link.identify(identity)
link.identify(_identity)
link.did_identify = True
link.set_packet_callback(client_packet_handler)
@ -718,7 +656,7 @@ def main():
if args.listen or args.print_identity:
RNS.log("command " + args.command)
listen(
_listen(
configdir=args.config,
command=args.command,
identitypath=args.identity,

60
rnsh/rnslogging.py Normal file
View File

@ -0,0 +1,60 @@
import logging
from logging import Handler, getLevelName
from types import GenericAlias
import os
import RNS
class RnsHandler(Handler):
"""
A handler class which writes logging records, appropriately formatted,
to the RNS logger.
"""
def __init__(self):
"""
Initialize the handler.
"""
Handler.__init__(self)
@staticmethod
def get_rns_loglevel(loglevel: int) -> int:
if loglevel == logging.CRITICAL:
return RNS.LOG_CRITICAL
if loglevel == logging.ERROR:
return RNS.LOG_ERROR
if loglevel == logging.WARNING:
return RNS.LOG_WARNING
if loglevel == logging.INFO:
return RNS.LOG_INFO
if loglevel == logging.DEBUG:
return RNS.LOG_DEBUG
return RNS.LOG_DEBUG
def emit(self, record):
"""
Emit a record.
"""
try:
msg = self.format(record)
RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno))
except RecursionError: # See issue 36272
raise
except Exception:
self.handleError(record)
def __repr__(self):
level = getLevelName(self.level)
return '<%s (%s)>' % (self.__class__.__name__, level)
__class_getitem__ = classmethod(GenericAlias)
log_format = '%(name)-40s %(message)s [%(threadName)s]'
logging.basicConfig(
level=logging.INFO,
#format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s',
format=log_format,
datefmt='%Y-%m-%d %H:%M:%S',
handlers=[RnsHandler()])