mirror of
https://github.com/markqvist/rnsh.git
synced 2025-01-22 04:01:00 -05:00
Locking in a major increment for integrating the CallbackSubprocess module to rnsh
This commit is contained in:
parent
9186a64962
commit
fcc73ba31a
@ -9,13 +9,10 @@ import pty
|
|||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import logging as __logging
|
|
||||||
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import select
|
import select
|
||||||
import termios
|
import termios
|
||||||
|
import logging as __logging
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
|
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop | None = None):
|
||||||
@ -165,8 +162,8 @@ class CallbackSubprocess:
|
|||||||
assert stdout_callback is not None, "stdout_callback should not be None"
|
assert stdout_callback is not None, "stdout_callback should not be None"
|
||||||
assert terminated_callback is not None, "terminated_callback should not be None"
|
assert terminated_callback is not None, "terminated_callback should not be None"
|
||||||
|
|
||||||
self.log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self.log.debug(f"__init__({argv},{term},...")
|
self._log.debug(f"__init__({argv},{term},...")
|
||||||
self._command = argv
|
self._command = argv
|
||||||
self._term = term
|
self._term = term
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
@ -181,7 +178,7 @@ class CallbackSubprocess:
|
|||||||
Terminate child process if running
|
Terminate child process if running
|
||||||
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
||||||
"""
|
"""
|
||||||
self.log.debug("terminate()")
|
self._log.debug("terminate()")
|
||||||
if not self.running:
|
if not self.running:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -191,7 +188,7 @@ class CallbackSubprocess:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def kill():
|
def kill():
|
||||||
self.log.debug("kill()")
|
self._log.debug("kill()")
|
||||||
try:
|
try:
|
||||||
os.kill(self._pid, signal.SIGHUP)
|
os.kill(self._pid, signal.SIGHUP)
|
||||||
os.kill(self._pid, signal.SIGKILL)
|
os.kill(self._pid, signal.SIGKILL)
|
||||||
@ -201,9 +198,9 @@ class CallbackSubprocess:
|
|||||||
self._loop.call_later(kill_delay, kill)
|
self._loop.call_later(kill_delay, kill)
|
||||||
|
|
||||||
def wait():
|
def wait():
|
||||||
self.log.debug("wait()")
|
self._log.debug("wait()")
|
||||||
os.waitpid(self._pid, 0)
|
os.waitpid(self._pid, 0)
|
||||||
self.log.debug("wait() finish")
|
self._log.debug("wait() finish")
|
||||||
|
|
||||||
threading.Thread(target=wait).start()
|
threading.Thread(target=wait).start()
|
||||||
|
|
||||||
@ -226,7 +223,7 @@ class CallbackSubprocess:
|
|||||||
Write bytes to the stdin of the child process.
|
Write bytes to the stdin of the child process.
|
||||||
:param data: bytes to write
|
:param data: bytes to write
|
||||||
"""
|
"""
|
||||||
self.log.debug(f"write({data})")
|
self._log.debug(f"write({data})")
|
||||||
os.write(self._child_fd, data)
|
os.write(self._child_fd, data)
|
||||||
|
|
||||||
def set_winsize(self, r: int, c: int, h: int, v: int):
|
def set_winsize(self, r: int, c: int, h: int, v: int):
|
||||||
@ -238,7 +235,7 @@ class CallbackSubprocess:
|
|||||||
:param v: vertical pixels visible
|
:param v: vertical pixels visible
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
self.log.debug(f"set_winsize({r},{c},{h},{v}")
|
self._log.debug(f"set_winsize({r},{c},{h},{v}")
|
||||||
tty_set_winsize(self._child_fd, r, c, h, v)
|
tty_set_winsize(self._child_fd, r, c, h, v)
|
||||||
|
|
||||||
def copy_winsize(self, fromfd:int):
|
def copy_winsize(self, fromfd:int):
|
||||||
@ -268,7 +265,7 @@ class CallbackSubprocess:
|
|||||||
"""
|
"""
|
||||||
Start the child process.
|
Start the child process.
|
||||||
"""
|
"""
|
||||||
self.log.debug("start()")
|
self._log.debug("start()")
|
||||||
parentenv = os.environ.copy()
|
parentenv = os.environ.copy()
|
||||||
env = {"HOME": parentenv["HOME"],
|
env = {"HOME": parentenv["HOME"],
|
||||||
"TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
|
"TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
|
||||||
@ -290,11 +287,11 @@ class CallbackSubprocess:
|
|||||||
try:
|
try:
|
||||||
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
||||||
if self._return_code is not None and not process_exists(self._pid):
|
if self._return_code is not None and not process_exists(self._pid):
|
||||||
self.log.debug(f"polled return code {self._return_code}")
|
self._log.debug(f"polled return code {self._return_code}")
|
||||||
self._terminated_cb(self._return_code)
|
self._terminated_cb(self._return_code)
|
||||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.debug(f"Error in process poll: {e}")
|
self._log.debug(f"Error in process poll: {e}")
|
||||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||||
|
|
||||||
def reader(fd: int, callback: callable):
|
def reader(fd: int, callback: callable):
|
||||||
|
116
rnsh/retry.py
Normal file
116
rnsh/retry.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import logging as __logging
|
||||||
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class RetryStatus:
|
||||||
|
def __init__(self, id: any, try_limit: int, wait_delay: float, retry_callback: callable[any, int], timeout_callback: callable[any], tries: int = 1):
|
||||||
|
self.id = id
|
||||||
|
self.try_limit = try_limit
|
||||||
|
self.tries = tries
|
||||||
|
self.wait_delay = wait_delay
|
||||||
|
self.retry_callback = retry_callback
|
||||||
|
self.timeout_callback = timeout_callback
|
||||||
|
self.try_time = time.monotonic()
|
||||||
|
self.completed = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ready(self):
|
||||||
|
return self.try_time + self.wait_delay < time.monotonic() and not self.completed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timed_out(self):
|
||||||
|
return self.ready and self.tries >= self.try_limit
|
||||||
|
|
||||||
|
def timeout(self):
|
||||||
|
self.completed = True
|
||||||
|
self.timeout_callback(self.id)
|
||||||
|
|
||||||
|
def retry(self):
|
||||||
|
self.tries += 1
|
||||||
|
self.retry_callback(self.id, self.tries)
|
||||||
|
|
||||||
|
class RetryThread:
|
||||||
|
def __init__(self, loop_period: float = 0.25):
|
||||||
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
|
self._loop_period = loop_period
|
||||||
|
self._statuses: list[RetryStatus] = []
|
||||||
|
self._id_counter = 0
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._thread = threading.Thread(target=self._thread_run())
|
||||||
|
self._run = True
|
||||||
|
self._finished: asyncio.Future | None = None
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def close(self, loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Future | None:
|
||||||
|
self._log.debug("stopping timer thread")
|
||||||
|
if loop is None:
|
||||||
|
self._run = False
|
||||||
|
self._thread.join()
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
self._finished = loop.create_future()
|
||||||
|
return self._finished
|
||||||
|
def _thread_run(self):
|
||||||
|
last_run = time.monotonic()
|
||||||
|
while self._run and self._finished is None:
|
||||||
|
time.sleep(last_run + self._loop_period - time.monotonic())
|
||||||
|
last_run = time.monotonic()
|
||||||
|
ready: list[RetryStatus] = []
|
||||||
|
prune: list[RetryStatus] = []
|
||||||
|
with self._lock:
|
||||||
|
ready.extend(filter(lambda s: s.ready, self._statuses))
|
||||||
|
for retry in ready:
|
||||||
|
try:
|
||||||
|
if not retry.completed:
|
||||||
|
if retry.timed_out:
|
||||||
|
self._log.debug(f"timed out {retry.id} after {retry.try_limit} tries")
|
||||||
|
retry.timeout()
|
||||||
|
prune.append(retry)
|
||||||
|
else:
|
||||||
|
self._log.debug(f"retrying {retry.id}, try {retry.tries + 1}/{retry.try_limit}")
|
||||||
|
retry.retry()
|
||||||
|
except Exception as e:
|
||||||
|
self._log.error(f"error processing retry id {retry.id}: {e}")
|
||||||
|
prune.append(retry)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
for retry in prune:
|
||||||
|
self._log.debug(f"pruned retry {retry.id}, retry count {retry.tries}/{retry.try_limit}")
|
||||||
|
self._statuses.remove(retry)
|
||||||
|
if self._finished is not None:
|
||||||
|
self._finished.set_result(None)
|
||||||
|
|
||||||
|
def _get_id(self):
|
||||||
|
self._id_counter += 1
|
||||||
|
return self._id_counter
|
||||||
|
|
||||||
|
def begin(self, try_limit: int, wait_delay: float, try_callback: callable[[any | None, int], any], timeout_callback: callable[any, int], id: int | None = None) -> any:
|
||||||
|
self._log.debug(f"running first try")
|
||||||
|
id = try_callback(id, 1)
|
||||||
|
self._log.debug(f"first try success, got id {id}")
|
||||||
|
with self._lock:
|
||||||
|
if id is None:
|
||||||
|
id = self._get_id()
|
||||||
|
self._statuses.append(RetryStatus(id=id,
|
||||||
|
tries=1,
|
||||||
|
try_limit=try_limit,
|
||||||
|
wait_delay=wait_delay,
|
||||||
|
retry_callback=try_callback,
|
||||||
|
timeout_callback=timeout_callback))
|
||||||
|
self._log.debug(f"added retry timer for {id}")
|
||||||
|
def complete(self, id: any):
|
||||||
|
assert id is not None
|
||||||
|
with self._lock:
|
||||||
|
status = next(filter(lambda l: l.id == id, self._statuses))
|
||||||
|
assert status is not None
|
||||||
|
status.completed = True
|
||||||
|
self._statuses.remove(status)
|
||||||
|
self._log.debug(f"completed {id}")
|
||||||
|
def complete_all(self):
|
||||||
|
with self._lock:
|
||||||
|
for status in self._statuses:
|
||||||
|
status.completed = True
|
||||||
|
self._log.debug(f"completed {status.id}")
|
||||||
|
self._statuses.clear()
|
710
rnsh/rnsh.py
710
rnsh/rnsh.py
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import pty
|
import functools
|
||||||
import threading
|
|
||||||
|
|
||||||
# MIT License
|
# MIT License
|
||||||
#
|
#
|
||||||
@ -24,309 +23,82 @@ import threading
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
import rnslogging
|
||||||
import RNS
|
import RNS
|
||||||
import subprocess
|
|
||||||
import argparse
|
import argparse
|
||||||
import shlex
|
import shlex
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
import tty
|
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
import select
|
|
||||||
import base64
|
import base64
|
||||||
import fcntl
|
|
||||||
import termios
|
|
||||||
import queue
|
|
||||||
import signal
|
|
||||||
import errno
|
|
||||||
import RNS.vendor.umsgpack as umsgpack
|
import RNS.vendor.umsgpack as umsgpack
|
||||||
|
import process
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import signal
|
||||||
|
import retry
|
||||||
|
|
||||||
|
import logging as __logging
|
||||||
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
def _getLogger(name: str):
|
||||||
|
global module_logger
|
||||||
|
return module_logger.getChild(name)
|
||||||
|
|
||||||
from RNS._version import __version__
|
from RNS._version import __version__
|
||||||
|
|
||||||
APP_NAME = "rnsh"
|
APP_NAME = "rnsh"
|
||||||
identity = None
|
_identity = None
|
||||||
reticulum = None
|
_reticulum = None
|
||||||
allow_all = False
|
_allow_all = False
|
||||||
allowed_identity_hashes = []
|
_allowed_identity_hashes = []
|
||||||
cmd = None
|
_cmd: str | None = None
|
||||||
processes = []
|
|
||||||
processes_lock = threading.Lock()
|
|
||||||
DATA_AVAIL_MSG = "data available"
|
DATA_AVAIL_MSG = "data available"
|
||||||
|
_finished: asyncio.Future | None = None
|
||||||
|
_retry_timer = retry.RetryThread()
|
||||||
|
_destination: RNS.Destination | None = None
|
||||||
|
|
||||||
|
def _handle_sigint_with_async_int():
|
||||||
|
if _finished is not None:
|
||||||
|
_finished.set_exception(KeyboardInterrupt())
|
||||||
|
else:
|
||||||
|
raise KeyboardInterrupt()
|
||||||
|
|
||||||
def fd_set_non_blocking(fd):
|
signal.signal(signal.SIGINT, _handle_sigint_with_async_int)
|
||||||
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
|
||||||
if fd.isatty():
|
|
||||||
tty.setraw(fd)
|
|
||||||
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
|
|
||||||
def fd_non_blocking_read(fd):
|
|
||||||
# from https://stackoverflow.com/questions/26263636/how-to-check-potentially-empty-stdin-without-waiting-for-input
|
|
||||||
# TODO: Windows is probably different
|
|
||||||
|
|
||||||
#return fd.read()
|
def _prepare_identity(identity_path):
|
||||||
try:
|
global _identity
|
||||||
old_settings = None
|
log = _getLogger("_prepare_identity")
|
||||||
# try:
|
|
||||||
# old_settings = termios.tcgetattr(fd)
|
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
|
||||||
try:
|
|
||||||
# try:
|
|
||||||
# tty.setraw(fd)
|
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags | os.O_NONBLOCK)
|
|
||||||
return os.read(fd.fileno(), 1024)
|
|
||||||
except OSError as ose:
|
|
||||||
if ose.errno != 35:
|
|
||||||
raise ose
|
|
||||||
except Exception as e:
|
|
||||||
RNS.log(f"Raw read error {e}")
|
|
||||||
finally:
|
|
||||||
fcntl.fcntl(fd, fcntl.F_SETFL, old_flags)
|
|
||||||
# if old_settings is not None:
|
|
||||||
# termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class NonBlockingStreamReader:
|
|
||||||
|
|
||||||
def __init__(self, stream, callback = None):
|
|
||||||
'''
|
|
||||||
stream: the stream to read from.
|
|
||||||
Usually a process' stdout or stderr.
|
|
||||||
'''
|
|
||||||
|
|
||||||
self._s = stream
|
|
||||||
self._q = queue.Queue()
|
|
||||||
self._callback = callback
|
|
||||||
self._stop_time = None
|
|
||||||
|
|
||||||
def _populateQueue(stream, queue):
|
|
||||||
'''
|
|
||||||
Collect lines from 'stream' and put them in 'quque'.
|
|
||||||
'''
|
|
||||||
# fd_set_non_blocking(stream)
|
|
||||||
run = True
|
|
||||||
while run and not (self._stop_time is not None and (datetime.datetime.now() - self._stop_time).total_seconds() > 0.05):
|
|
||||||
# stream.flush()
|
|
||||||
# line = stream.read(1) #fd_non_blocking_read(stream)
|
|
||||||
timeout = 0.01
|
|
||||||
ready, _, _ = select.select([stream], [], [], timeout)
|
|
||||||
for fd in ready:
|
|
||||||
try:
|
|
||||||
data = os.read(fd, 512)
|
|
||||||
except OSError as e:
|
|
||||||
if e.errno != errno.EIO:
|
|
||||||
raise
|
|
||||||
# EIO means EOF on some systems
|
|
||||||
run = False
|
|
||||||
else:
|
|
||||||
if not data: # EOF
|
|
||||||
run = False
|
|
||||||
if data is not None and len(data) > 0:
|
|
||||||
if self._callback is not None:
|
|
||||||
self._callback(data)
|
|
||||||
else:
|
|
||||||
queue.put(data)
|
|
||||||
RNS.log("NonBlockingStreamReader exiting", RNS.LOG_DEBUG)
|
|
||||||
os.close(stream)
|
|
||||||
|
|
||||||
self._t = threading.Thread(target = _populateQueue,
|
|
||||||
args = (self._s, self._q))
|
|
||||||
self._t.daemon = True
|
|
||||||
self._t.start() #start collecting lines from the stream
|
|
||||||
|
|
||||||
def read(self, timeout = None):
|
|
||||||
try:
|
|
||||||
result = self._q.get_nowait() if timeout is None else self._q.get(block = timeout is not None,
|
|
||||||
timeout = timeout)
|
|
||||||
return result
|
|
||||||
except TimeoutError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def is_open(self):
|
|
||||||
return self._t.is_alive()
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
if self._stop_time is None:
|
|
||||||
self._stop_time = datetime.datetime.now()
|
|
||||||
|
|
||||||
class UnexpectedEndOfStream(Exception): pass
|
|
||||||
|
|
||||||
def processes_get():
|
|
||||||
processes_lock.acquire()
|
|
||||||
try:
|
|
||||||
return processes.copy()
|
|
||||||
finally:
|
|
||||||
processes_lock.release()
|
|
||||||
|
|
||||||
def processes_add(process):
|
|
||||||
processes_lock.acquire()
|
|
||||||
try:
|
|
||||||
processes.append(process)
|
|
||||||
finally:
|
|
||||||
processes_lock.release()
|
|
||||||
|
|
||||||
def processes_remove(process):
|
|
||||||
if process.link.status == RNS.Link.ACTIVE:
|
|
||||||
return
|
|
||||||
|
|
||||||
processes_lock.acquire()
|
|
||||||
try:
|
|
||||||
if next(filter(lambda p: p == process, processes)) is not None:
|
|
||||||
processes.remove(process)
|
|
||||||
finally:
|
|
||||||
processes_lock.release()
|
|
||||||
|
|
||||||
#### Link Overrides ####
|
|
||||||
_link_handle_request_orig = RNS.Link.handle_request
|
|
||||||
def link_handle_request(self, request_id, unpacked_request):
|
|
||||||
for process in processes_get():
|
|
||||||
if process.link.link_id == self.link_id:
|
|
||||||
RNS.log("Associating packet to link", RNS.LOG_DEBUG)
|
|
||||||
process.request_id = request_id
|
|
||||||
self.last_request_id = request_id
|
|
||||||
_link_handle_request_orig(self, request_id, unpacked_request)
|
|
||||||
|
|
||||||
RNS.Link.handle_request = link_handle_request
|
|
||||||
|
|
||||||
class ProcessState:
|
|
||||||
def __init__(self, command, link, remote_identity, term):
|
|
||||||
self.lock = threading.RLock()
|
|
||||||
self.link = link
|
|
||||||
self.remote_identity = remote_identity
|
|
||||||
self.term = term
|
|
||||||
self.command = command
|
|
||||||
self._stderrbuf = bytearray()
|
|
||||||
self._stdoutbuf = bytearray()
|
|
||||||
RNS.log("Launching " + self.command) # + " for client " + (RNS.prettyhexrep(self.remote_identity) if self.remote_identity else "unknown"), RNS.LOG_DEBUG)
|
|
||||||
env = os.environ.copy()
|
|
||||||
# env["PYTHONUNBUFFERED"] = "1"
|
|
||||||
# env["PS1"] ="\\u:\\h "
|
|
||||||
env["TERM"] = self.term
|
|
||||||
self.mo, so = pty.openpty()
|
|
||||||
self.me, se = pty.openpty()
|
|
||||||
self.mi, si = pty.openpty()
|
|
||||||
self.process = subprocess.Popen(shlex.split(self.command), bufsize=512, stdin=si, stdout=so, stderr=se, preexec_fn=os.setsid, shell=False, env=env)
|
|
||||||
for fd in [so, se, si]:
|
|
||||||
os.close(fd)
|
|
||||||
# tty.setcbreak(self.mo)
|
|
||||||
self.stdout_reader = NonBlockingStreamReader(self.mo, self._stdout_cb)
|
|
||||||
self.stderr_reader = NonBlockingStreamReader(self.me, self._stderr_cb)
|
|
||||||
self.last_update = datetime.datetime.now()
|
|
||||||
self.request_id = None
|
|
||||||
self.notify_tried = 0
|
|
||||||
self.return_code = None
|
|
||||||
|
|
||||||
def _fd_callback(self, fdbuf, data):
|
|
||||||
with self.lock:
|
|
||||||
fdbuf.extend(data)
|
|
||||||
|
|
||||||
def _stdout_cb(self, data):
|
|
||||||
self._fd_callback(self._stdoutbuf, data)
|
|
||||||
|
|
||||||
def _stderr_cb(self, data):
|
|
||||||
self._fd_callback(self._stderrbuf, data)
|
|
||||||
|
|
||||||
def notify_client_data_available(self, chars_available):
|
|
||||||
if (datetime.datetime.now() - self.last_update).total_seconds() < 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.last_update = datetime.datetime.now()
|
|
||||||
if self.notify_tried > 15:
|
|
||||||
processes_remove(self)
|
|
||||||
RNS.log(f"Try count exceeded, terminating connection", RNS.LOG_ERROR)
|
|
||||||
self.link.teardown()
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
RNS.log(f"Notifying client; try {self.notify_tried} retcode: {self.return_code} chars avail: {chars_available}")
|
|
||||||
RNS.Packet(self.link, DATA_AVAIL_MSG.encode("utf-8")).send()
|
|
||||||
self.notify_tried += 1
|
|
||||||
except Exception as e:
|
|
||||||
RNS.log("Error notifying client: " + str(e), RNS.LOG_ERROR)
|
|
||||||
|
|
||||||
def poll(self, should_notify):
|
|
||||||
self.return_code, chars_available = self.process.poll(), len(self._stdoutbuf) + len(self._stderrbuf)
|
|
||||||
|
|
||||||
if should_notify and self.return_code is not None or chars_available > 0:
|
|
||||||
self.notify_client_data_available(chars_available)
|
|
||||||
|
|
||||||
if self.return_code is not None:
|
|
||||||
self.stdout_reader.stop()
|
|
||||||
self.stderr_reader.stop()
|
|
||||||
|
|
||||||
return self.return_code, chars_available
|
|
||||||
|
|
||||||
def is_finished(self):
|
|
||||||
with self.lock:
|
|
||||||
return self.return_code is not None and not self.stdout_reader.is_open() # and not self.stderr_reader.is_open()
|
|
||||||
|
|
||||||
def read(self): #TODO: limit take sizes?
|
|
||||||
with self.lock:
|
|
||||||
self.notify_tried = 0
|
|
||||||
self.last_update = datetime.datetime.now()
|
|
||||||
stdout = self._stdoutbuf
|
|
||||||
self._stdoutbuf = bytearray()
|
|
||||||
stderr = self._stderrbuf.copy()
|
|
||||||
self._stderrbuf = bytearray()
|
|
||||||
self.return_code = self.process.poll()
|
|
||||||
if self.return_code is not None and len(stdout) == 0 and len(stderr) == 0:
|
|
||||||
self.final_checkin = True
|
|
||||||
return self.process.poll(), stdout, stderr
|
|
||||||
|
|
||||||
def write(self, bytes):
|
|
||||||
os.write(self.mi, bytes)
|
|
||||||
os.fsync(self.mi)
|
|
||||||
|
|
||||||
def terminate(self):
|
|
||||||
chars_available = 0
|
|
||||||
with self.lock:
|
|
||||||
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
|
|
||||||
for fd in [self.mo, self.me, self.mi]:
|
|
||||||
os.close(fd)
|
|
||||||
self.process.terminate()
|
|
||||||
self.process.wait()
|
|
||||||
if self.process.poll() is not None:
|
|
||||||
stdout, stderr = self.process.communicate()
|
|
||||||
self._stdoutbuf += stdout
|
|
||||||
self._stderrbuf += stderr
|
|
||||||
return len(self._stdoutbuf) + len(self._stderrbuf)
|
|
||||||
|
|
||||||
def prepare_identity(identity_path):
|
|
||||||
global identity
|
|
||||||
if identity_path == None:
|
if identity_path == None:
|
||||||
identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME
|
identity_path = RNS.Reticulum.identitypath+"/"+APP_NAME
|
||||||
|
|
||||||
if os.path.isfile(identity_path):
|
if os.path.isfile(identity_path):
|
||||||
identity = RNS.Identity.from_file(identity_path)
|
_identity = RNS.Identity.from_file(identity_path)
|
||||||
|
|
||||||
if identity == None:
|
if _identity == None:
|
||||||
RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO)
|
log.info("No valid saved identity found, creating new...")
|
||||||
identity = RNS.Identity()
|
_identity = RNS.Identity()
|
||||||
identity.to_file(identity_path)
|
_identity.to_file(identity_path)
|
||||||
|
|
||||||
def listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0, allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
|
async def _listen(configdir, command, identitypath = None, service_name ="default", verbosity = 0, quietness = 0,
|
||||||
global identity, allow_all, allowed_identity_hashes, reticulum, cmd
|
allowed = [], print_identity = False, disable_auth = None, disable_announce=False):
|
||||||
|
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||||
cmd = command
|
log = _getLogger("_listen")
|
||||||
|
_cmd = command
|
||||||
|
|
||||||
targetloglevel = 3+verbosity-quietness
|
targetloglevel = 3+verbosity-quietness
|
||||||
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||||
|
|
||||||
prepare_identity(identitypath)
|
_prepare_identity(identitypath)
|
||||||
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
||||||
|
|
||||||
if print_identity:
|
if print_identity:
|
||||||
print("Identity : "+str(identity))
|
log.info("Identity : " + str(_identity))
|
||||||
print("Listening on : "+RNS.prettyhexrep(destination.hash))
|
log.info("Listening on : " + RNS.prettyhexrep(_destination.hash))
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
if disable_auth:
|
if disable_auth:
|
||||||
allow_all = True
|
_allow_all = True
|
||||||
else:
|
else:
|
||||||
if allowed != None:
|
if allowed != None:
|
||||||
for a in allowed:
|
for a in allowed:
|
||||||
@ -336,140 +108,306 @@ def listen(configdir, command, identitypath = None, service_name ="default", ver
|
|||||||
raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
|
raise ValueError("Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2))
|
||||||
try:
|
try:
|
||||||
destination_hash = bytes.fromhex(a)
|
destination_hash = bytes.fromhex(a)
|
||||||
allowed_identity_hashes.append(destination_hash)
|
_allowed_identity_hashes.append(destination_hash)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError("Invalid destination entered. Check your input.")
|
raise ValueError("Invalid destination entered. Check your input.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e))
|
log.error(str(e))
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
if len(allowed_identity_hashes) < 1 and not disable_auth:
|
if len(_allowed_identity_hashes) < 1 and not disable_auth:
|
||||||
print("Warning: No allowed identities configured, rncx will not accept any commands!")
|
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
|
||||||
|
|
||||||
destination.set_link_established_callback(command_link_established)
|
_destination.set_link_established_callback(_listen_link_established)
|
||||||
|
|
||||||
if not allow_all:
|
if not _allow_all:
|
||||||
destination.register_request_handler(
|
_destination.register_request_handler(
|
||||||
path = service_name,
|
path = service_name,
|
||||||
response_generator = execute_received_command,
|
response_generator = _listen_request,
|
||||||
allow = RNS.Destination.ALLOW_LIST,
|
allow = RNS.Destination.ALLOW_LIST,
|
||||||
allowed_list = allowed_identity_hashes
|
allowed_list = _allowed_identity_hashes
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
destination.register_request_handler(
|
_destination.register_request_handler(
|
||||||
path = service_name,
|
path = service_name,
|
||||||
response_generator = execute_received_command,
|
response_generator = _listen_request,
|
||||||
allow = RNS.Destination.ALLOW_ALL,
|
allow = RNS.Destination.ALLOW_ALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
RNS.log("rnsh listening for commands on "+RNS.prettyhexrep(destination.hash))
|
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||||
|
|
||||||
if not disable_announce:
|
if not disable_announce:
|
||||||
destination.announce()
|
_destination.announce()
|
||||||
|
|
||||||
last = datetime.datetime.now()
|
last = time.monotonic()
|
||||||
|
|
||||||
while True:
|
|
||||||
if not disable_announce and (datetime.datetime.now() - last).total_seconds() > 900: # TODO: make parameter
|
|
||||||
last = datetime.datetime.now()
|
|
||||||
destination.announce()
|
|
||||||
|
|
||||||
time.sleep(0.005)
|
try:
|
||||||
for proc in processes_get():
|
while True:
|
||||||
|
if not disable_announce and time.monotonic() - last > 900: # TODO: make parameter
|
||||||
|
last = datetime.datetime.now()
|
||||||
|
_destination.announce()
|
||||||
try:
|
try:
|
||||||
if proc.link.status == RNS.Link.CLOSED:
|
await asyncio.wait_for(_finished, timeout=1.0)
|
||||||
RNS.log("Link closed, terminating")
|
except TimeoutError:
|
||||||
proc.terminate()
|
pass
|
||||||
proc.poll(should_notify=True)
|
except KeyboardInterrupt:
|
||||||
|
log.warning("Shutting down")
|
||||||
|
for link in list(_destination.links):
|
||||||
|
try:
|
||||||
|
if link.process is not None and link.process.process.running:
|
||||||
|
link.process.process.terminate()
|
||||||
except:
|
except:
|
||||||
RNS.log("Error polling process for link " + proc.link.link_id, RNS.LOG_ERROR)
|
pass
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||||
|
for link in links_still_active:
|
||||||
|
if link.status != RNS.Link.CLOSED:
|
||||||
|
link.teardown()
|
||||||
|
|
||||||
if proc.link.status == RNS.Link.CLOSED:
|
class ProcessState:
|
||||||
processes_remove(proc)
|
def __init__(self,
|
||||||
|
cmd: str,
|
||||||
|
mdu: int,
|
||||||
|
data_available_callback: callable,
|
||||||
|
terminated_callback: callable,
|
||||||
|
term: str | None,
|
||||||
|
loop: asyncio.AbstractEventLoop = None):
|
||||||
|
|
||||||
|
self._log = _getLogger(self.__class__.__name__)
|
||||||
|
self._mdu = mdu
|
||||||
|
self._loop = loop if loop is not None else asyncio.get_running_loop()
|
||||||
|
self._process = process.CallbackSubprocess(argv=shlex.split(cmd),
|
||||||
|
term=term,
|
||||||
|
loop=asyncio.get_running_loop(),
|
||||||
|
stdout_callback=self._stdout_data,
|
||||||
|
terminated_callback=terminated_callback)
|
||||||
|
self._data_buffer = bytearray()
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._data_available_cb = data_available_callback
|
||||||
|
self._terminated_cb = terminated_callback
|
||||||
|
self._pending_receipt: RNS.PacketReceipt | None = None
|
||||||
|
self._process.start()
|
||||||
|
self._term_state: [int] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mdu(self) -> int:
|
||||||
|
return self._mdu
|
||||||
|
|
||||||
|
@mdu.setter
|
||||||
|
def mdu(self, val: int):
|
||||||
|
self._mdu = val
|
||||||
|
|
||||||
|
def pending_receipt_peek(self) -> RNS.PacketReceipt | None:
|
||||||
|
return self._pending_receipt
|
||||||
|
|
||||||
|
def pending_receipt_take(self) -> RNS.PacketReceipt | None:
|
||||||
|
with self._lock:
|
||||||
|
val = self._pending_receipt
|
||||||
|
self._pending_receipt = None
|
||||||
|
return val
|
||||||
|
|
||||||
|
def pending_receipt_put(self, receipt: RNS.PacketReceipt | None):
|
||||||
|
with self._lock:
|
||||||
|
self._pending_receipt = receipt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def process(self) -> process.CallbackSubprocess:
|
||||||
|
return self._process
|
||||||
|
|
||||||
|
@property
|
||||||
|
def return_code(self) -> int | None:
|
||||||
|
return self.process.return_code
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lock(self) -> threading.RLock:
|
||||||
|
return self._lock
|
||||||
|
|
||||||
|
def read(self, count: int) -> bytes:
|
||||||
|
with self.lock:
|
||||||
|
take = self._data_buffer[:count-1]
|
||||||
|
self._data_buffer = self._data_buffer[count:]
|
||||||
|
return take
|
||||||
|
|
||||||
|
def _stdout_data(self, data: bytes):
|
||||||
|
total_available = 0
|
||||||
|
with self.lock:
|
||||||
|
self._data_buffer.extend(data)
|
||||||
|
total_available = len(self._data_buffer)
|
||||||
|
try:
|
||||||
|
self._data_available_cb(total_available)
|
||||||
|
except Exception as e:
|
||||||
|
self._log.error(f"Error calling ProcessState data_available_callback {e}")
|
||||||
|
|
||||||
|
def _update_winsz(self):
|
||||||
|
self.process.set_winsize(self._term_state[3],
|
||||||
|
self._term_state[4],
|
||||||
|
self._term_state[5],
|
||||||
|
self._term_state[6])
|
||||||
|
|
||||||
|
|
||||||
|
REQUEST_IDX_STDIN = 0
|
||||||
|
REQUEST_IDX_TERM = 1
|
||||||
|
REQUEST_IDX_TIOS = 2
|
||||||
|
REQUEST_IDX_ROWS = 3
|
||||||
|
REQUEST_IDX_COLS = 4
|
||||||
|
REQUEST_IDX_HPIX = 5
|
||||||
|
REQUEST_IDX_VPIX = 6
|
||||||
|
def process_request(self, data: [any], read_size: int) -> [any]:
|
||||||
|
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
|
||||||
|
term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
||||||
|
tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
||||||
|
rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
||||||
|
cols = data[ProcessState.REQUEST_IDX_COLS] # window cols
|
||||||
|
hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
||||||
|
vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
||||||
|
term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX]
|
||||||
|
response = ProcessState.default_response()
|
||||||
|
|
||||||
def command_link_start_process(link, identity, term) -> ProcessState:
|
response[ProcessState.RESPONSE_IDX_RUNNING] = not self.process.running
|
||||||
try:
|
if self.process.running:
|
||||||
process = ProcessState(cmd, link, identity, term)
|
if term_state != self._term_state:
|
||||||
processes_add(process)
|
self._term_state = term_state
|
||||||
return process
|
self._update_winsz()
|
||||||
except Exception as e:
|
|
||||||
RNS.log("Failed to launch process: " + str(e), RNS.LOG_ERROR)
|
|
||||||
link.teardown()
|
|
||||||
def command_link_established(link):
|
|
||||||
global allow_all
|
|
||||||
link.set_remote_identified_callback(initiator_identified)
|
|
||||||
link.set_link_closed_callback(command_link_closed)
|
|
||||||
RNS.log("Shell link "+str(link)+" established")
|
|
||||||
if allow_all:
|
|
||||||
command_link_start_process(link, None)
|
|
||||||
|
|
||||||
def command_link_closed(link):
|
|
||||||
RNS.log("Shell link "+str(link)+" closed")
|
|
||||||
matches = list(filter(lambda p: p.link == link, processes_get()))
|
|
||||||
if len(matches) == 0:
|
|
||||||
return
|
|
||||||
proc = matches[0]
|
|
||||||
try:
|
|
||||||
proc.terminate()
|
|
||||||
except:
|
|
||||||
RNS.log("Error closing process for link " + RNS.prettyhexrep(link.link_id), RNS.LOG_ERROR)
|
|
||||||
finally:
|
|
||||||
processes_remove(proc)
|
|
||||||
|
|
||||||
def initiator_identified(link, identity):
|
|
||||||
global allow_all, cmd
|
|
||||||
RNS.log("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
|
|
||||||
if not allow_all and not identity.hash in allowed_identity_hashes:
|
|
||||||
RNS.log("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
|
|
||||||
link.teardown()
|
|
||||||
|
|
||||||
def execute_received_command(path, data, request_id, remote_identity, requested_at):
|
|
||||||
RNS.log("execute_received_command", RNS.LOG_DEBUG)
|
|
||||||
process = None
|
|
||||||
for proc in processes_get():
|
|
||||||
RNS.log("checking a proc", RNS.LOG_DEBUG)
|
|
||||||
if proc.request_id == request_id:
|
|
||||||
process = proc
|
|
||||||
RNS.log("execute_received_command matched request", RNS.LOG_DEBUG)
|
|
||||||
|
|
||||||
stdin = data[0] # Data passed to stdin
|
|
||||||
if process is None:
|
|
||||||
link = next(filter(lambda l: hasattr(l, "last_request_id") and l.last_request_id == request_id, RNS.Transport.active_links))
|
|
||||||
if link is not None:
|
|
||||||
process = command_link_start_process(link, identity, base64.b64decode(stdin).decode("utf-8") if stdin is not None else "")
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
# if remote_identity != None:
|
|
||||||
# RNS.log("Executing command ["+command+"] for "+RNS.prettyhexrep(remote_identity.hash))
|
|
||||||
# else:
|
|
||||||
# RNS.log("Executing command ["+command+"] for unknown requestor")
|
|
||||||
|
|
||||||
result = [
|
|
||||||
False, # 0: Command was executed
|
|
||||||
None, # 1: Return value
|
|
||||||
None, # 2: Stdout
|
|
||||||
None, # 3: Stderr
|
|
||||||
datetime.datetime.now(), # 4: Timestamp
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
if process is not None:
|
|
||||||
result[0] = not process.is_finished()
|
|
||||||
if stdin is not None and len(stdin) > 0:
|
if stdin is not None and len(stdin) > 0:
|
||||||
stdin = base64.b64decode(stdin)
|
stdin = base64.b64decode(stdin)
|
||||||
process.write(stdin)
|
self.process.write(stdin)
|
||||||
return_code, stdout, stderr = process.read()
|
response[ProcessState.RESPONSE_IDX_RETCODE] = self.return_code
|
||||||
result[1] = return_code
|
stdout = self.read(read_size)
|
||||||
result[2] = base64.b64encode(stdout).decode("utf-8") if stdout is not None else None
|
with self.lock:
|
||||||
result[3] = base64.b64encode(stderr).decode("utf-8") if stderr is not None else None
|
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||||
|
response[ProcessState.RESPONSE_IDX_STDOUT] = \
|
||||||
|
base64.b64encode(stdout).decode("utf-8") if stdout is not None and len(stdout) > 0 else None
|
||||||
|
return response
|
||||||
|
|
||||||
|
RESPONSE_IDX_RUNNING = 0
|
||||||
|
RESPONSE_IDX_RETCODE = 1
|
||||||
|
RESPONSE_IDX_RDYBYTE = 2
|
||||||
|
RESPONSE_IDX_STDOUT = 3
|
||||||
|
RESPONSE_IDX_TMSTAMP = 4
|
||||||
|
@staticmethod
|
||||||
|
def default_response() -> [any]:
|
||||||
|
return [
|
||||||
|
False, # 0: Process running
|
||||||
|
None, # 1: Return value
|
||||||
|
0, # 2: Number of outstanding bytes
|
||||||
|
None, # 3: Stdout/Stderr
|
||||||
|
time.time(), # 4: Timestamp
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||||
|
global _retry_timer
|
||||||
|
log = _getLogger("_subproc_data_ready")
|
||||||
|
process_state: ProcessState = link.process
|
||||||
|
|
||||||
|
def send(timeout: bool, id: any, tries: int) -> any:
|
||||||
|
try:
|
||||||
|
pr = process_state.pending_receipt_take()
|
||||||
|
if pr is not None and pr.get_status() != RNS.PacketReceipt.SENT and pr.get_status() != RNS.PacketReceipt.DELIVERED:
|
||||||
|
if not timeout:
|
||||||
|
_retry_timer.complete(id)
|
||||||
|
log.debug(f"Packet {id} completed with status {pr.status} on link {link}")
|
||||||
|
return link.link_id
|
||||||
|
|
||||||
|
if not timeout:
|
||||||
|
log.info(f"Notifying client try {tries} (retcode: {process_state.return_code} chars avail: {chars_available})")
|
||||||
|
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
||||||
|
packet.send()
|
||||||
|
pr = packet.receipt
|
||||||
|
process_state.pending_receipt_put(pr)
|
||||||
|
return link.link_id
|
||||||
|
else:
|
||||||
|
log.error(f"Retry count exceeded, terminating link {link}")
|
||||||
|
_retry_timer.complete(link.link_id)
|
||||||
|
link.teardown()
|
||||||
|
except Exception as e:
|
||||||
|
log.error("Error notifying client: " + str(e))
|
||||||
|
return link.link_id
|
||||||
|
|
||||||
|
with process_state.lock:
|
||||||
|
if process_state.pending_receipt_peek() is None:
|
||||||
|
_retry_timer.begin(try_limit=15,
|
||||||
|
wait_delay=link.rtt * 3 if link.rtt is not None else 1,
|
||||||
|
try_callback=functools.partial(send, False),
|
||||||
|
timeout_callback=functools.partial(send, True),
|
||||||
|
id=None)
|
||||||
|
else:
|
||||||
|
log.debug(f"Notification already pending for link {link}")
|
||||||
|
|
||||||
|
def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||||
|
log = _getLogger("_subproc_terminated")
|
||||||
|
log.info(f"Subprocess terminated ({return_code} for link {link}")
|
||||||
|
link.teardown()
|
||||||
|
|
||||||
|
def _listen_start_proc(link: RNS.Link, term: str) -> ProcessState | None:
|
||||||
|
global _cmd
|
||||||
|
log = _getLogger("_listen_start_proc")
|
||||||
|
try:
|
||||||
|
link.process = ProcessState(cmd=_cmd,
|
||||||
|
term=term,
|
||||||
|
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||||
|
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||||
|
return link.process
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result[0] = False
|
log.error("Failed to launch process: " + str(e))
|
||||||
if process is not None:
|
link.teardown()
|
||||||
process.terminate()
|
return None
|
||||||
process.link.teardown()
|
def _listen_link_established(link):
|
||||||
|
global _allow_all
|
||||||
|
log = _getLogger("_listen_link_established")
|
||||||
|
link.set_remote_identified_callback(_initiator_identified)
|
||||||
|
link.set_link_closed_callback(_listen_link_closed)
|
||||||
|
log.info("Link "+str(link)+" established")
|
||||||
|
|
||||||
|
def _listen_link_closed(link: RNS.Link):
|
||||||
|
log = _getLogger("_listen_link_closed")
|
||||||
|
# async def cleanup():
|
||||||
|
log.info("Link "+str(link)+" closed")
|
||||||
|
proc: ProcessState | None = link.proc if hasattr(link, "process") else None
|
||||||
|
if proc is None:
|
||||||
|
log.warning(f"No process for link {link}")
|
||||||
|
try:
|
||||||
|
proc.process.terminate()
|
||||||
|
except:
|
||||||
|
log.error(f"Error closing process for link {link}")
|
||||||
|
# asyncio.get_running_loop().call_soon(cleanup)
|
||||||
|
|
||||||
|
def _initiator_identified(link, identity):
|
||||||
|
global _allow_all, _cmd
|
||||||
|
log = _getLogger("_initiator_identified")
|
||||||
|
log.info("Initiator of link "+str(link)+" identified as "+RNS.prettyhexrep(identity.hash))
|
||||||
|
if not _allow_all and not identity.hash in _allowed_identity_hashes:
|
||||||
|
log.warning("Identity "+RNS.prettyhexrep(identity.hash)+" not allowed, tearing down link", RNS.LOG_WARNING)
|
||||||
|
link.teardown()
|
||||||
|
|
||||||
|
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||||
|
global _destination, _retry_timer
|
||||||
|
log = _getLogger("_listen_request")
|
||||||
|
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
||||||
|
_retry_timer.complete(link_id)
|
||||||
|
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links))
|
||||||
|
if link is None:
|
||||||
|
log.error(f"invalid request {request_id}, no link found with id {link_id}")
|
||||||
|
return
|
||||||
|
process_state: ProcessState | None = None
|
||||||
|
try:
|
||||||
|
term = data[1]
|
||||||
|
process_state = link.process if hasattr(link, "process") else None
|
||||||
|
if process_state is None:
|
||||||
|
log.debug(f"process not found for link {link}")
|
||||||
|
process_state = _listen_start_proc(link, term)
|
||||||
|
|
||||||
|
|
||||||
|
# leave significant overhead for metadata and encoding
|
||||||
|
result = process_state.process_request(data, link.MDU * 3 // 2)
|
||||||
|
except Exception as e:
|
||||||
|
result = ProcessState.default_response()
|
||||||
|
try:
|
||||||
|
if process_state is not None:
|
||||||
|
process_state.process.terminate()
|
||||||
|
link.teardown()
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error terminating process for link {link}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -552,7 +490,7 @@ def client_packet_handler(message, packet):
|
|||||||
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG:
|
if message is not None and message.decode("utf-8") == DATA_AVAIL_MSG:
|
||||||
new_data = True
|
new_data = True
|
||||||
def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT):
|
def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid = False, destination = None, service_name = "default", stdin = None, timeout = RNS.Transport.PATH_REQUEST_TIMEOUT):
|
||||||
global identity, reticulum, link, listener_destination, remote_exec_grace
|
global _identity, _reticulum, link, listener_destination, remote_exec_grace
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
|
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
|
||||||
@ -566,12 +504,12 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
|
|||||||
print(str(e))
|
print(str(e))
|
||||||
return 241
|
return 241
|
||||||
|
|
||||||
if reticulum == None:
|
if _reticulum == None:
|
||||||
targetloglevel = 3+verbosity-quietness
|
targetloglevel = 3+verbosity-quietness
|
||||||
reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||||
|
|
||||||
if identity == None:
|
if _identity == None:
|
||||||
prepare_identity(identitypath)
|
_prepare_identity(identitypath)
|
||||||
|
|
||||||
if not RNS.Transport.has_path(destination_hash):
|
if not RNS.Transport.has_path(destination_hash):
|
||||||
RNS.Transport.request_path(destination_hash)
|
RNS.Transport.request_path(destination_hash)
|
||||||
@ -598,7 +536,7 @@ def execute(configdir, identitypath = None, verbosity = 0, quietness = 0, noid =
|
|||||||
return 243
|
return 243
|
||||||
|
|
||||||
if not noid and not link.did_identify:
|
if not noid and not link.did_identify:
|
||||||
link.identify(identity)
|
link.identify(_identity)
|
||||||
link.did_identify = True
|
link.did_identify = True
|
||||||
|
|
||||||
link.set_packet_callback(client_packet_handler)
|
link.set_packet_callback(client_packet_handler)
|
||||||
@ -718,7 +656,7 @@ def main():
|
|||||||
|
|
||||||
if args.listen or args.print_identity:
|
if args.listen or args.print_identity:
|
||||||
RNS.log("command " + args.command)
|
RNS.log("command " + args.command)
|
||||||
listen(
|
_listen(
|
||||||
configdir=args.config,
|
configdir=args.config,
|
||||||
command=args.command,
|
command=args.command,
|
||||||
identitypath=args.identity,
|
identitypath=args.identity,
|
||||||
|
60
rnsh/rnslogging.py
Normal file
60
rnsh/rnslogging.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import logging
|
||||||
|
from logging import Handler, getLevelName
|
||||||
|
from types import GenericAlias
|
||||||
|
import os
|
||||||
|
|
||||||
|
import RNS
|
||||||
|
|
||||||
|
class RnsHandler(Handler):
|
||||||
|
"""
|
||||||
|
A handler class which writes logging records, appropriately formatted,
|
||||||
|
to the RNS logger.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the handler.
|
||||||
|
"""
|
||||||
|
Handler.__init__(self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_rns_loglevel(loglevel: int) -> int:
|
||||||
|
if loglevel == logging.CRITICAL:
|
||||||
|
return RNS.LOG_CRITICAL
|
||||||
|
if loglevel == logging.ERROR:
|
||||||
|
return RNS.LOG_ERROR
|
||||||
|
if loglevel == logging.WARNING:
|
||||||
|
return RNS.LOG_WARNING
|
||||||
|
if loglevel == logging.INFO:
|
||||||
|
return RNS.LOG_INFO
|
||||||
|
if loglevel == logging.DEBUG:
|
||||||
|
return RNS.LOG_DEBUG
|
||||||
|
return RNS.LOG_DEBUG
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
"""
|
||||||
|
Emit a record.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
msg = self.format(record)
|
||||||
|
RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno))
|
||||||
|
except RecursionError: # See issue 36272
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
level = getLevelName(self.level)
|
||||||
|
return '<%s (%s)>' % (self.__class__.__name__, level)
|
||||||
|
|
||||||
|
__class_getitem__ = classmethod(GenericAlias)
|
||||||
|
|
||||||
|
|
||||||
|
log_format = '%(name)-40s %(message)s [%(threadName)s]'
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
#format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s',
|
||||||
|
format=log_format,
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
handlers=[RnsHandler()])
|
Loading…
Reference in New Issue
Block a user