mirror of
https://github.com/markqvist/rnsh.git
synced 2025-01-23 04:31:01 -05:00
Protocol versioning, announce period option, old API workaround
This commit is contained in:
parent
9fe37b57ed
commit
da3b390058
@ -8,9 +8,9 @@ readme = "README.md"
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
rns = {git = "https://github.com/markqvist/Reticulum.git", rev = "3706769"}
|
|
||||||
docopt = "^0.6.2"
|
docopt = "^0.6.2"
|
||||||
psutil = "^5.9.4"
|
psutil = "^5.9.4"
|
||||||
|
rns = "^0.4.8"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
rnsh = 'rnsh.rnsh:rnsh_cli'
|
rnsh = 'rnsh.rnsh:rnsh_cli'
|
||||||
|
67
rnsh/hacks.py
Normal file
67
rnsh/hacks.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import RNS
|
||||||
|
import logging
|
||||||
|
module_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfPruningAssociation:
|
||||||
|
def __init__(self, ttl: float, loop: asyncio.AbstractEventLoop):
|
||||||
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
|
self._ttl = ttl
|
||||||
|
self._associations = list()
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._loop = loop
|
||||||
|
# self._log.debug("__init__")
|
||||||
|
|
||||||
|
def _schedule_prune(self, pair):
|
||||||
|
# self._log.debug(f"schedule prune {pair}")
|
||||||
|
|
||||||
|
def schedule_prune_inner(p):
|
||||||
|
# self._log.debug(f"schedule inner {p}")
|
||||||
|
self._loop.call_later(self._ttl, self._prune, p)
|
||||||
|
|
||||||
|
self._loop.call_soon_threadsafe(schedule_prune_inner, pair)
|
||||||
|
|
||||||
|
def _prune(self, pair: any):
|
||||||
|
# self._log.debug(f"prune {pair}")
|
||||||
|
with self._lock:
|
||||||
|
self._associations.remove(pair)
|
||||||
|
|
||||||
|
def get(self, key: any) -> any:
|
||||||
|
# self._log.debug(f"get {key}")
|
||||||
|
with self._lock:
|
||||||
|
pair = next(filter(lambda x: x[0] == key, self._associations), None)
|
||||||
|
if not pair:
|
||||||
|
return None
|
||||||
|
return pair[1]
|
||||||
|
|
||||||
|
def put(self, key: any, value: any):
|
||||||
|
# self._log.debug(f"put {key},{value}")
|
||||||
|
with self._lock:
|
||||||
|
pair = (key, value)
|
||||||
|
self._associations.append(pair)
|
||||||
|
self._schedule_prune(pair)
|
||||||
|
|
||||||
|
|
||||||
|
_request_to_link: SelfPruningAssociation = None
|
||||||
|
_link_handle_request_orig = RNS.Link.handle_request
|
||||||
|
|
||||||
|
|
||||||
|
def _link_handle_request(self, request_id, unpacked_request):
|
||||||
|
global _request_to_link
|
||||||
|
_request_to_link.put(request_id, self.link_id)
|
||||||
|
_link_handle_request_orig(self, request_id, unpacked_request)
|
||||||
|
|
||||||
|
|
||||||
|
def request_request_id_hack(new_api_reponse_generator, loop):
|
||||||
|
global _request_to_link
|
||||||
|
if _request_to_link is None:
|
||||||
|
RNS.Link.handle_request = _link_handle_request
|
||||||
|
_request_to_link = SelfPruningAssociation(1.0, loop)
|
||||||
|
|
||||||
|
def listen_request(path, data, request_id, remote_identity, requested_at):
|
||||||
|
link_id = _request_to_link.get(request_id)
|
||||||
|
return new_api_reponse_generator(path, data, request_id, link_id, remote_identity, requested_at)
|
||||||
|
|
||||||
|
return listen_request
|
@ -27,6 +27,9 @@ import time
|
|||||||
import rnsh.exception as exception
|
import rnsh.exception as exception
|
||||||
import logging as __logging
|
import logging as __logging
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
import types
|
||||||
|
import typing
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -67,7 +70,7 @@ class RetryStatus:
|
|||||||
self.retry_callback(self.tag, self.tries)
|
self.retry_callback(self.tag, self.tries)
|
||||||
|
|
||||||
|
|
||||||
class RetryThread:
|
class RetryThread(AbstractContextManager):
|
||||||
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
|
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
|
||||||
self._log = module_logger.getChild(self.__class__.__name__)
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self._loop_period = loop_period
|
self._loop_period = loop_period
|
||||||
@ -176,3 +179,8 @@ class RetryThread:
|
|||||||
status.completed = True
|
status.completed = True
|
||||||
self._log.debug(f"completed {status.tag}")
|
self._log.debug(f"completed {status.tag}")
|
||||||
self._statuses.clear()
|
self._statuses.clear()
|
||||||
|
|
||||||
|
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||||
|
__traceback: types.TracebackType) -> bool:
|
||||||
|
self.close()
|
||||||
|
return False
|
||||||
|
240
rnsh/rnsh.py
240
rnsh/rnsh.py
@ -35,12 +35,12 @@ import termios
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Callable, TypeVar
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
import RNS
|
import RNS
|
||||||
import rnsh.exception as exception
|
import rnsh.exception as exception
|
||||||
import rnsh.process as process
|
import rnsh.process as process
|
||||||
import rnsh.retry as retry
|
import rnsh.retry as retry
|
||||||
import rnsh.rnslogging as rnslogging
|
import rnsh.rnslogging as rnslogging
|
||||||
|
import rnsh.hacks as hacks
|
||||||
from rnsh.__version import __version__
|
from rnsh.__version import __version__
|
||||||
|
|
||||||
module_logger = __logging.getLogger(__name__)
|
module_logger = __logging.getLogger(__name__)
|
||||||
@ -59,7 +59,7 @@ _allowed_identity_hashes = []
|
|||||||
_cmd: [str] = None
|
_cmd: [str] = None
|
||||||
DATA_AVAIL_MSG = "data available"
|
DATA_AVAIL_MSG = "data available"
|
||||||
_finished: asyncio.Event | None = None
|
_finished: asyncio.Event | None = None
|
||||||
_retry_timer = retry.RetryThread()
|
_retry_timer: retry.RetryThread | None = None
|
||||||
_destination: RNS.Destination | None = None
|
_destination: RNS.Destination | None = None
|
||||||
_loop: asyncio.AbstractEventLoop | None = None
|
_loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
@ -107,8 +107,10 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
|
|||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
# hack_goes_here
|
||||||
|
|
||||||
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
||||||
allowed=None, disable_auth=None, disable_announce=False):
|
allowed=None, disable_auth=None, announce_period=900):
|
||||||
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||||
log = _get_logger("_listen")
|
log = _get_logger("_listen")
|
||||||
_cmd = command
|
_cmd = command
|
||||||
@ -147,14 +149,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
if not _allow_all:
|
if not _allow_all:
|
||||||
_destination.register_request_handler(
|
_destination.register_request_handler(
|
||||||
path="data",
|
path="data",
|
||||||
response_generator=_listen_request,
|
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||||
allow=RNS.Destination.ALLOW_LIST,
|
allow=RNS.Destination.ALLOW_LIST,
|
||||||
allowed_list=_allowed_identity_hashes
|
allowed_list=_allowed_identity_hashes
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_destination.register_request_handler(
|
_destination.register_request_handler(
|
||||||
path="data",
|
path="data",
|
||||||
response_generator=_listen_request,
|
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||||
allow=RNS.Destination.ALLOW_ALL,
|
allow=RNS.Destination.ALLOW_ALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -162,14 +164,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
|
|
||||||
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||||
|
|
||||||
if not disable_announce:
|
if announce_period is not None:
|
||||||
_destination.announce()
|
_destination.announce()
|
||||||
|
|
||||||
last = time.time()
|
last = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if not disable_announce and time.time() - last > 900: # TODO: make parameter
|
if announce_period and 0 < announce_period < time.time() - last:
|
||||||
last = time.time()
|
last = time.time()
|
||||||
_destination.announce()
|
_destination.announce()
|
||||||
await _check_finished(1.0)
|
await _check_finished(1.0)
|
||||||
@ -177,7 +179,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
log.warning("Shutting down")
|
log.warning("Shutting down")
|
||||||
for link in list(_destination.links):
|
for link in list(_destination.links):
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit):
|
||||||
proc = ProcessState.get_for_tag(link.link_id)
|
proc = Session.get_for_tag(link.link_id)
|
||||||
if proc is not None and proc.process.running:
|
if proc is not None and proc.process.running:
|
||||||
proc.process.terminate()
|
proc.process.terminate()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
@ -187,17 +189,35 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
link.teardown()
|
link.teardown()
|
||||||
|
|
||||||
|
|
||||||
class ProcessState:
|
_PROTOCOL_MAGIC = 0xdeadbeef
|
||||||
_processes: [(any, ProcessState)] = []
|
|
||||||
|
|
||||||
|
def _protocol_make_version(version: int):
|
||||||
|
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
|
||||||
|
|
||||||
|
|
||||||
|
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _protocol_split_version(version: int):
|
||||||
|
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
||||||
|
|
||||||
|
|
||||||
|
def _protocol_check_magic(value: int):
|
||||||
|
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
_processes: [(any, Session)] = []
|
||||||
_lock = threading.RLock()
|
_lock = threading.RLock()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_for_tag(cls, tag: any) -> ProcessState | None:
|
def get_for_tag(cls, tag: any) -> Session | None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
|
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def put_for_tag(cls, tag: any, ps: ProcessState):
|
def put_for_tag(cls, tag: any, ps: Session):
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
cls.clear_tag(tag)
|
cls.clear_tag(tag)
|
||||||
cls._processes.append((tag, ps))
|
cls._processes.append((tag, ps))
|
||||||
@ -234,7 +254,7 @@ class ProcessState:
|
|||||||
self._pending_receipt: RNS.PacketReceipt | None = None
|
self._pending_receipt: RNS.PacketReceipt | None = None
|
||||||
self._process.start()
|
self._process.start()
|
||||||
self._term_state: [int] = None
|
self._term_state: [int] = None
|
||||||
ProcessState.put_for_tag(tag, self)
|
Session.put_for_tag(tag, self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mdu(self) -> int:
|
def mdu(self) -> int:
|
||||||
@ -302,37 +322,40 @@ class ProcessState:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._log.debug(f"failed to update winsz: {e}")
|
self._log.debug(f"failed to update winsz: {e}")
|
||||||
|
|
||||||
REQUEST_IDX_STDIN = 0
|
REQUEST_IDX_VERS = 0
|
||||||
REQUEST_IDX_TERM = 1
|
REQUEST_IDX_STDIN = 1
|
||||||
REQUEST_IDX_TIOS = 2
|
REQUEST_IDX_TERM = 2
|
||||||
REQUEST_IDX_ROWS = 3
|
REQUEST_IDX_TIOS = 3
|
||||||
REQUEST_IDX_COLS = 4
|
REQUEST_IDX_ROWS = 4
|
||||||
REQUEST_IDX_HPIX = 5
|
REQUEST_IDX_COLS = 5
|
||||||
REQUEST_IDX_VPIX = 6
|
REQUEST_IDX_HPIX = 6
|
||||||
|
REQUEST_IDX_VPIX = 7
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_request(stdin_fd: int | None) -> [any]:
|
def default_request(stdin_fd: int | None) -> [any]:
|
||||||
|
global _PROTOCOL_VERSION_0
|
||||||
request: list[any] = [
|
request: list[any] = [
|
||||||
None, # 0 Stdin
|
_PROTOCOL_VERSION_0, # 0 Protocol Version
|
||||||
None, # 1 TERM variable
|
None, # 1 Stdin
|
||||||
None, # 2 termios attributes or something
|
None, # 2 TERM variable
|
||||||
None, # 3 terminal rows
|
None, # 3 termios attributes or something
|
||||||
None, # 4 terminal cols
|
None, # 4 terminal rows
|
||||||
None, # 5 terminal horizontal pixels
|
None, # 5 terminal cols
|
||||||
None, # 6 terminal vertical pixels
|
None, # 6 terminal horizontal pixels
|
||||||
|
None, # 7 terminal vertical pixels
|
||||||
].copy()
|
].copy()
|
||||||
|
|
||||||
if stdin_fd is not None:
|
if stdin_fd is not None:
|
||||||
request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||||
request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
||||||
request[ProcessState.REQUEST_IDX_ROWS], \
|
request[Session.REQUEST_IDX_ROWS], \
|
||||||
request[ProcessState.REQUEST_IDX_COLS], \
|
request[Session.REQUEST_IDX_COLS], \
|
||||||
request[ProcessState.REQUEST_IDX_HPIX], \
|
request[Session.REQUEST_IDX_HPIX], \
|
||||||
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
|
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def process_request(self, data: [any], read_size: int) -> [any]:
|
def process_request(self, data: [any], read_size: int) -> [any]:
|
||||||
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
|
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
|
||||||
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
||||||
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
||||||
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
||||||
@ -340,10 +363,10 @@ class ProcessState:
|
|||||||
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
||||||
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
||||||
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
|
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
|
||||||
response = ProcessState.default_response()
|
response = Session.default_response()
|
||||||
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1]
|
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
|
||||||
|
|
||||||
response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running
|
response[Session.RESPONSE_IDX_RUNNING] = self.process.running
|
||||||
if self.process.running:
|
if self.process.running:
|
||||||
if term_state != self._term_state:
|
if term_state != self._term_state:
|
||||||
self._term_state = term_state
|
self._term_state = term_state
|
||||||
@ -351,39 +374,42 @@ class ProcessState:
|
|||||||
if stdin is not None and len(stdin) > 0:
|
if stdin is not None and len(stdin) > 0:
|
||||||
stdin = base64.b64decode(stdin)
|
stdin = base64.b64decode(stdin)
|
||||||
self.process.write(stdin)
|
self.process.write(stdin)
|
||||||
response[ProcessState.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
stdout = self.read(read_size)
|
stdout = self.read(read_size)
|
||||||
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||||
|
|
||||||
if stdout is not None and len(stdout) > 0:
|
if stdout is not None and len(stdout) > 0:
|
||||||
response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
RESPONSE_IDX_RUNNING = 0
|
RESPONSE_IDX_VERSION = 0
|
||||||
RESPONSE_IDX_RETCODE = 1
|
RESPONSE_IDX_RUNNING = 1
|
||||||
RESPONSE_IDX_RDYBYTE = 2
|
RESPONSE_IDX_RETCODE = 2
|
||||||
RESPONSE_IDX_STDOUT = 3
|
RESPONSE_IDX_RDYBYTE = 3
|
||||||
RESPONSE_IDX_TMSTAMP = 4
|
RESPONSE_IDX_STDOUT = 4
|
||||||
|
RESPONSE_IDX_TMSTAMP = 5
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_response() -> [any]:
|
def default_response() -> [any]:
|
||||||
|
global _PROTOCOL_VERSION_0
|
||||||
response: list[any] = [
|
response: list[any] = [
|
||||||
False, # 0: Process running
|
_PROTOCOL_VERSION_0, # 0: Protocol version
|
||||||
None, # 1: Return value
|
False, # 1: Process running
|
||||||
0, # 2: Number of outstanding bytes
|
None, # 2: Return value
|
||||||
None, # 3: Stdout/Stderr
|
0, # 3: Number of outstanding bytes
|
||||||
None, # 4: Timestamp
|
None, # 4: Stdout/Stderr
|
||||||
|
None, # 5: Timestamp
|
||||||
].copy()
|
].copy()
|
||||||
response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time()
|
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||||
global _retry_timer
|
global _retry_timer
|
||||||
log = _get_logger("_subproc_data_ready")
|
log = _get_logger("_subproc_data_ready")
|
||||||
process_state: ProcessState = ProcessState.get_for_tag(link.link_id)
|
session: Session = Session.get_for_tag(link.link_id)
|
||||||
|
|
||||||
def send(timeout: bool, tag: any, tries: int) -> any:
|
def send(timeout: bool, tag: any, tries: int) -> any:
|
||||||
# log.debug("send")
|
# log.debug("send")
|
||||||
@ -392,10 +418,10 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
|||||||
try:
|
try:
|
||||||
if link.status != RNS.Link.ACTIVE:
|
if link.status != RNS.Link.ACTIVE:
|
||||||
_retry_timer.complete(link.link_id)
|
_retry_timer.complete(link.link_id)
|
||||||
process_state.pending_receipt_take()
|
session.pending_receipt_take()
|
||||||
return
|
return
|
||||||
|
|
||||||
pr = process_state.pending_receipt_take()
|
pr = session.pending_receipt_take()
|
||||||
log.debug(f"send inner pr: {pr}")
|
log.debug(f"send inner pr: {pr}")
|
||||||
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
|
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
|
||||||
if not timeout:
|
if not timeout:
|
||||||
@ -405,12 +431,12 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
|||||||
else:
|
else:
|
||||||
if not timeout:
|
if not timeout:
|
||||||
log.info(
|
log.info(
|
||||||
f"Notifying client try {tries} (retcode: {process_state.return_code} " +
|
f"Notifying client try {tries} (retcode: {session.return_code} " +
|
||||||
f"chars avail: {chars_available})")
|
f"chars avail: {chars_available})")
|
||||||
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
||||||
packet.send()
|
packet.send()
|
||||||
pr = packet.receipt
|
pr = packet.receipt
|
||||||
process_state.pending_receipt_put(pr)
|
session.pending_receipt_put(pr)
|
||||||
else:
|
else:
|
||||||
log.error(f"Retry count exceeded, terminating link {link}")
|
log.error(f"Retry count exceeded, terminating link {link}")
|
||||||
_retry_timer.complete(link.link_id)
|
_retry_timer.complete(link.link_id)
|
||||||
@ -421,7 +447,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
|||||||
_loop.call_soon_threadsafe(inner)
|
_loop.call_soon_threadsafe(inner)
|
||||||
return link.link_id
|
return link.link_id
|
||||||
|
|
||||||
with process_state.lock:
|
with session.lock:
|
||||||
if not _retry_timer.has_tag(link.link_id):
|
if not _retry_timer.has_tag(link.link_id):
|
||||||
_retry_timer.begin(try_limit=15,
|
_retry_timer.begin(try_limit=15,
|
||||||
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
|
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
|
||||||
@ -436,7 +462,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
|||||||
global _loop
|
global _loop
|
||||||
log = _get_logger("_subproc_terminated")
|
log = _get_logger("_subproc_terminated")
|
||||||
log.info(f"Subprocess returned {return_code} for link {link}")
|
log.info(f"Subprocess returned {return_code} for link {link}")
|
||||||
proc = ProcessState.get_for_tag(link.link_id)
|
proc = Session.get_for_tag(link.link_id)
|
||||||
if proc is None:
|
if proc is None:
|
||||||
log.debug(f"no proc for link {link}")
|
log.debug(f"no proc for link {link}")
|
||||||
return
|
return
|
||||||
@ -449,7 +475,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
|||||||
try:
|
try:
|
||||||
link.teardown()
|
link.teardown()
|
||||||
finally:
|
finally:
|
||||||
ProcessState.clear_tag(link.link_id)
|
Session.clear_tag(link.link_id)
|
||||||
|
|
||||||
_loop.call_later(300, inner)
|
_loop.call_later(300, inner)
|
||||||
_loop.call_soon(_subproc_data_ready, link, 0)
|
_loop.call_soon(_subproc_data_ready, link, 0)
|
||||||
@ -458,11 +484,11 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
|||||||
|
|
||||||
|
|
||||||
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
|
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
|
||||||
loop: asyncio.AbstractEventLoop) -> ProcessState | None:
|
loop: asyncio.AbstractEventLoop) -> Session | None:
|
||||||
global _cmd
|
global _cmd
|
||||||
log = _get_logger("_listen_start_proc")
|
log = _get_logger("_listen_start_proc")
|
||||||
try:
|
try:
|
||||||
return ProcessState(tag=link.link_id,
|
return Session(tag=link.link_id,
|
||||||
cmd=_cmd,
|
cmd=_cmd,
|
||||||
term=term,
|
term=term,
|
||||||
remote_identity=remote_identity,
|
remote_identity=remote_identity,
|
||||||
@ -488,7 +514,7 @@ def _listen_link_closed(link: RNS.Link):
|
|||||||
log = _get_logger("_listen_link_closed")
|
log = _get_logger("_listen_link_closed")
|
||||||
# async def cleanup():
|
# async def cleanup():
|
||||||
log.info("Link " + str(link) + " closed")
|
log.info("Link " + str(link) + " closed")
|
||||||
proc: ProcessState | None = ProcessState.get_for_tag(link.link_id)
|
proc: Session | None = Session.get_for_tag(link.link_id)
|
||||||
if proc is None:
|
if proc is None:
|
||||||
log.warning(f"No process for link {link}")
|
log.warning(f"No process for link {link}")
|
||||||
else:
|
else:
|
||||||
@ -497,7 +523,7 @@ def _listen_link_closed(link: RNS.Link):
|
|||||||
_retry_timer.complete(link.link_id)
|
_retry_timer.complete(link.link_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error closing process for link {link}: {e}")
|
log.error(f"Error closing process for link {link}: {e}")
|
||||||
ProcessState.clear_tag(link.link_id)
|
Session.clear_tag(link.link_id)
|
||||||
|
|
||||||
|
|
||||||
def _initiator_identified(link, identity):
|
def _initiator_identified(link, identity):
|
||||||
@ -513,34 +539,48 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||||||
global _destination, _retry_timer, _loop
|
global _destination, _retry_timer, _loop
|
||||||
log = _get_logger("_listen_request")
|
log = _get_logger("_listen_request")
|
||||||
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
||||||
|
if not hasattr(data, "__len__") or len(data) < 1:
|
||||||
|
raise Exception("Request data invalid")
|
||||||
_retry_timer.complete(link_id)
|
_retry_timer.complete(link_id)
|
||||||
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
|
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
|
||||||
if link is None:
|
if link is None:
|
||||||
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
|
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
|
||||||
process_state: ProcessState | None = None
|
|
||||||
|
remote_version = data[Session.REQUEST_IDX_VERS]
|
||||||
|
if not _protocol_check_magic(remote_version):
|
||||||
|
raise Exception("Request magic incorrect")
|
||||||
|
|
||||||
|
if not remote_version == _PROTOCOL_VERSION_0:
|
||||||
|
response = Session.default_response()
|
||||||
|
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||||
|
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||||
|
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||||
|
return response
|
||||||
|
|
||||||
|
session: Session | None = None
|
||||||
try:
|
try:
|
||||||
term = data[ProcessState.REQUEST_IDX_TERM]
|
term = data[Session.REQUEST_IDX_TERM]
|
||||||
process_state = ProcessState.get_for_tag(link.link_id)
|
session = Session.get_for_tag(link.link_id)
|
||||||
if process_state is None:
|
if session is None:
|
||||||
log.debug(f"Process not found for link {link}")
|
log.debug(f"Process not found for link {link}")
|
||||||
process_state = _listen_start_proc(link=link,
|
session = _listen_start_proc(link=link,
|
||||||
term=term,
|
term=term,
|
||||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||||
loop=_loop)
|
loop=_loop)
|
||||||
|
|
||||||
# leave significant headroom for metadata and encoding
|
# leave significant headroom for metadata and encoding
|
||||||
result = process_state.process_request(data, link.MDU * 4 // 3)
|
result = session.process_request(data, link.MDU * 4 // 3)
|
||||||
return result
|
return result
|
||||||
# return ProcessState.default_response()
|
# return ProcessState.default_response()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error procesing request for link {link}: {e}")
|
log.error(f"Error procesing request for link {link}: {e}")
|
||||||
try:
|
try:
|
||||||
if process_state is not None and process_state.process.running:
|
if session is not None and session.process.running:
|
||||||
process_state.process.terminate()
|
session.process.terminate()
|
||||||
except Exception as ee:
|
except Exception as ee:
|
||||||
log.debug(f"Error terminating process for link {link}: {ee}")
|
log.debug(f"Error terminating process for link {link}: {ee}")
|
||||||
|
|
||||||
return ProcessState.default_response()
|
return Session.default_response()
|
||||||
|
|
||||||
|
|
||||||
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
||||||
@ -558,7 +598,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
|||||||
_link: RNS.Link | None = None
|
_link: RNS.Link | None = None
|
||||||
_remote_exec_grace = 2.0
|
_remote_exec_grace = 2.0
|
||||||
_new_data: asyncio.Event | None = None
|
_new_data: asyncio.Event | None = None
|
||||||
_tr = process.TTYRestorer(sys.stdin.fileno())
|
_tr: process.TTYRestorer | None = None
|
||||||
|
|
||||||
|
|
||||||
def _client_packet_handler(message, packet):
|
def _client_packet_handler(message, packet):
|
||||||
@ -632,8 +672,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
|
|
||||||
_link.set_packet_callback(_client_packet_handler)
|
_link.set_packet_callback(_client_packet_handler)
|
||||||
|
|
||||||
request = ProcessState.default_request(sys.stdin.fileno())
|
request = Session.default_request(sys.stdin.fileno())
|
||||||
request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
|
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
|
||||||
|
|
||||||
# TODO: Tune
|
# TODO: Tune
|
||||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||||
@ -671,18 +711,27 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
|
|
||||||
if request_receipt.response is not None:
|
if request_receipt.response is not None:
|
||||||
try:
|
try:
|
||||||
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True
|
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
||||||
return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE]
|
if not _protocol_check_magic(version):
|
||||||
ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0
|
raise RemoteExecutionError("Protocol error")
|
||||||
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT]
|
elif version != _PROTOCOL_VERSION_0:
|
||||||
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP]
|
raise RemoteExecutionError("Protocol version mismatch")
|
||||||
|
|
||||||
|
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
|
||||||
|
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
||||||
|
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
||||||
|
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
||||||
|
if stdout is not None:
|
||||||
|
stdout = base64.b64decode(stdout)
|
||||||
|
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
||||||
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
||||||
|
except RemoteExecutionError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RemoteExecutionError(f"Received invalid response") from e
|
raise RemoteExecutionError(f"Received invalid response") from e
|
||||||
|
|
||||||
_tr.raw()
|
_tr.raw()
|
||||||
if stdout is not None:
|
if stdout is not None:
|
||||||
stdout = base64.b64decode(stdout)
|
|
||||||
# log.debug(f"stdout: {stdout}")
|
# log.debug(f"stdout: {stdout}")
|
||||||
os.write(sys.stdout.fileno(), stdout)
|
os.write(sys.stdout.fileno(), stdout)
|
||||||
|
|
||||||
@ -784,7 +833,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
|
|||||||
return arr, []
|
return arr, []
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def _rnsh_cli_main():
|
||||||
global _tr, _finished, _loop
|
global _tr, _finished, _loop
|
||||||
import docopt
|
import docopt
|
||||||
log = _get_logger("main")
|
log = _get_logger("main")
|
||||||
@ -797,8 +846,8 @@ async def main():
|
|||||||
Usage:
|
Usage:
|
||||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||||
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
|
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||||
[--] <program> [<arg>...]
|
[--] <program> [<arg> ...]
|
||||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||||
rnsh -h
|
rnsh -h
|
||||||
@ -810,7 +859,8 @@ Options:
|
|||||||
-s NAME --service NAME Listen on/connect to specific service name if not default
|
-s NAME --service NAME Listen on/connect to specific service name if not default
|
||||||
-p --print-identity Print identity information and exit
|
-p --print-identity Print identity information and exit
|
||||||
-l --listen Listen (server) mode
|
-l --listen Listen (server) mode
|
||||||
-b --no-announce Do not announce service
|
-b --announce PERIOD Announce on startup and every PERIOD seconds
|
||||||
|
Specify 0 for PERIOD to announce on startup only.
|
||||||
-a HASH --allowed HASH Specify identities allowed to connect
|
-a HASH --allowed HASH Specify identities allowed to connect
|
||||||
-n --no-auth Disable authentication
|
-n --no-auth Disable authentication
|
||||||
-N --no-id Disable identify on connect
|
-N --no-id Disable identify on connect
|
||||||
@ -837,7 +887,13 @@ Options:
|
|||||||
args_print_identity = args.get("--print-identity", None) or False
|
args_print_identity = args.get("--print-identity", None) or False
|
||||||
args_verbose = args.get("--verbose", None) or 0
|
args_verbose = args.get("--verbose", None) or 0
|
||||||
args_quiet = args.get("--quiet", None) or 0
|
args_quiet = args.get("--quiet", None) or 0
|
||||||
args_no_announce = args.get("--no-announce", None) or False
|
args_announce = args.get("--announce", None)
|
||||||
|
try:
|
||||||
|
if args_announce:
|
||||||
|
args_announce = int(args_announce)
|
||||||
|
except ValueError:
|
||||||
|
print("Invalid value for --announce")
|
||||||
|
return 1
|
||||||
args_no_auth = args.get("--no-auth", None) or False
|
args_no_auth = args.get("--no-auth", None) or False
|
||||||
args_allowed = args.get("--allowed", None) or []
|
args_allowed = args.get("--allowed", None) or []
|
||||||
args_program = args.get("<program>", None)
|
args_program = args.get("<program>", None)
|
||||||
@ -868,7 +924,7 @@ Options:
|
|||||||
quietness=args_quiet,
|
quietness=args_quiet,
|
||||||
allowed=args_allowed,
|
allowed=args_allowed,
|
||||||
disable_auth=args_no_auth,
|
disable_auth=args_no_auth,
|
||||||
disable_announce=args_no_announce,
|
announce_period=args_announce,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args_destination is not None and args_service_name is not None:
|
if args_destination is not None and args_service_name is not None:
|
||||||
@ -893,14 +949,14 @@ Options:
|
|||||||
|
|
||||||
|
|
||||||
def rnsh_cli():
|
def rnsh_cli():
|
||||||
try:
|
global _tr, _retry_timer
|
||||||
return_code = asyncio.run(main())
|
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
|
||||||
finally:
|
with retry.RetryThread() as _retry_timer:
|
||||||
|
return_code = asyncio.run(_rnsh_cli_main())
|
||||||
|
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit):
|
||||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||||
|
|
||||||
_tr.restore()
|
|
||||||
_retry_timer.close()
|
|
||||||
sys.exit(return_code or 255)
|
sys.exit(return_code or 255)
|
||||||
|
|
||||||
|
|
||||||
|
44
tests/test_hacks.py
Normal file
44
tests/test_hacks.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import rnsh.hacks as hacks
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLink:
|
||||||
|
def __init__(self, link_id):
|
||||||
|
self.link_id = link_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pruning():
|
||||||
|
def listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||||
|
assert path == 1
|
||||||
|
assert data == 2
|
||||||
|
assert request_id == 3
|
||||||
|
assert remote_identity == 4
|
||||||
|
assert requested_at == 5
|
||||||
|
assert link_id == 6
|
||||||
|
return 7
|
||||||
|
|
||||||
|
lhr_called = 0
|
||||||
|
link = FakeLink(6)
|
||||||
|
|
||||||
|
def link_handle_request(self, request_id, unpacked_request):
|
||||||
|
nonlocal lhr_called
|
||||||
|
lhr_called += 1
|
||||||
|
|
||||||
|
old_func = hacks.request_request_id_hack(listen_request, asyncio.get_running_loop())
|
||||||
|
hacks._link_handle_request_orig = link_handle_request
|
||||||
|
hacks._link_handle_request(link, 3, None)
|
||||||
|
link_id = hacks._request_to_link.get(3)
|
||||||
|
assert link_id == link.link_id
|
||||||
|
result = old_func(1, 2, 3, 4, 5)
|
||||||
|
assert result == 7
|
||||||
|
link_id = hacks._request_to_link.get(3)
|
||||||
|
assert link_id == 6
|
||||||
|
await asyncio.sleep(1.5)
|
||||||
|
link_id = hacks._request_to_link.get(3)
|
||||||
|
assert link_id is None
|
15
tests/test_rnsh.py
Normal file
15
tests/test_rnsh.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import logging
|
||||||
|
import rnsh.rnsh
|
||||||
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_magic():
|
||||||
|
magic = rnsh.rnsh._PROTOCOL_VERSION_0
|
||||||
|
# magic for version 0 is generated, make sure it comes out as expected
|
||||||
|
assert magic == 0xdeadbeef00000000
|
||||||
|
# verify the checker thinks it's right
|
||||||
|
assert rnsh.rnsh._protocol_check_magic(magic)
|
||||||
|
# scramble the magic
|
||||||
|
magic = magic | 0x00ffff0000000000
|
||||||
|
# make sure it fails now
|
||||||
|
assert not rnsh.rnsh._protocol_check_magic(magic)
|
Loading…
Reference in New Issue
Block a user