mirror of
https://github.com/markqvist/rnsh.git
synced 2025-01-05 20:30:53 -05:00
Protocol versioning, announce period option, old API workaround
This commit is contained in:
parent
9fe37b57ed
commit
da3b390058
@ -8,9 +8,9 @@ readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
rns = {git = "https://github.com/markqvist/Reticulum.git", rev = "3706769"}
|
||||
docopt = "^0.6.2"
|
||||
psutil = "^5.9.4"
|
||||
rns = "^0.4.8"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
rnsh = 'rnsh.rnsh:rnsh_cli'
|
||||
|
67
rnsh/hacks.py
Normal file
67
rnsh/hacks.py
Normal file
@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import RNS
|
||||
import logging
|
||||
module_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SelfPruningAssociation:
|
||||
def __init__(self, ttl: float, loop: asyncio.AbstractEventLoop):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._ttl = ttl
|
||||
self._associations = list()
|
||||
self._lock = threading.RLock()
|
||||
self._loop = loop
|
||||
# self._log.debug("__init__")
|
||||
|
||||
def _schedule_prune(self, pair):
|
||||
# self._log.debug(f"schedule prune {pair}")
|
||||
|
||||
def schedule_prune_inner(p):
|
||||
# self._log.debug(f"schedule inner {p}")
|
||||
self._loop.call_later(self._ttl, self._prune, p)
|
||||
|
||||
self._loop.call_soon_threadsafe(schedule_prune_inner, pair)
|
||||
|
||||
def _prune(self, pair: any):
|
||||
# self._log.debug(f"prune {pair}")
|
||||
with self._lock:
|
||||
self._associations.remove(pair)
|
||||
|
||||
def get(self, key: any) -> any:
|
||||
# self._log.debug(f"get {key}")
|
||||
with self._lock:
|
||||
pair = next(filter(lambda x: x[0] == key, self._associations), None)
|
||||
if not pair:
|
||||
return None
|
||||
return pair[1]
|
||||
|
||||
def put(self, key: any, value: any):
|
||||
# self._log.debug(f"put {key},{value}")
|
||||
with self._lock:
|
||||
pair = (key, value)
|
||||
self._associations.append(pair)
|
||||
self._schedule_prune(pair)
|
||||
|
||||
|
||||
_request_to_link: SelfPruningAssociation = None
|
||||
_link_handle_request_orig = RNS.Link.handle_request
|
||||
|
||||
|
||||
def _link_handle_request(self, request_id, unpacked_request):
|
||||
global _request_to_link
|
||||
_request_to_link.put(request_id, self.link_id)
|
||||
_link_handle_request_orig(self, request_id, unpacked_request)
|
||||
|
||||
|
||||
def request_request_id_hack(new_api_reponse_generator, loop):
|
||||
global _request_to_link
|
||||
if _request_to_link is None:
|
||||
RNS.Link.handle_request = _link_handle_request
|
||||
_request_to_link = SelfPruningAssociation(1.0, loop)
|
||||
|
||||
def listen_request(path, data, request_id, remote_identity, requested_at):
|
||||
link_id = _request_to_link.get(request_id)
|
||||
return new_api_reponse_generator(path, data, request_id, link_id, remote_identity, requested_at)
|
||||
|
||||
return listen_request
|
@ -27,6 +27,9 @@ import time
|
||||
import rnsh.exception as exception
|
||||
import logging as __logging
|
||||
from typing import Callable
|
||||
from contextlib import AbstractContextManager
|
||||
import types
|
||||
import typing
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
|
||||
@ -67,7 +70,7 @@ class RetryStatus:
|
||||
self.retry_callback(self.tag, self.tries)
|
||||
|
||||
|
||||
class RetryThread:
|
||||
class RetryThread(AbstractContextManager):
|
||||
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
|
||||
self._log = module_logger.getChild(self.__class__.__name__)
|
||||
self._loop_period = loop_period
|
||||
@ -176,3 +179,8 @@ class RetryThread:
|
||||
status.completed = True
|
||||
self._log.debug(f"completed {status.tag}")
|
||||
self._statuses.clear()
|
||||
|
||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||
__traceback: types.TracebackType) -> bool:
|
||||
self.close()
|
||||
return False
|
||||
|
258
rnsh/rnsh.py
258
rnsh/rnsh.py
@ -35,12 +35,12 @@ import termios
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
import RNS
|
||||
import rnsh.exception as exception
|
||||
import rnsh.process as process
|
||||
import rnsh.retry as retry
|
||||
import rnsh.rnslogging as rnslogging
|
||||
import rnsh.hacks as hacks
|
||||
from rnsh.__version import __version__
|
||||
|
||||
module_logger = __logging.getLogger(__name__)
|
||||
@ -59,7 +59,7 @@ _allowed_identity_hashes = []
|
||||
_cmd: [str] = None
|
||||
DATA_AVAIL_MSG = "data available"
|
||||
_finished: asyncio.Event | None = None
|
||||
_retry_timer = retry.RetryThread()
|
||||
_retry_timer: retry.RetryThread | None = None
|
||||
_destination: RNS.Destination | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@ -107,8 +107,10 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
|
||||
exit(0)
|
||||
|
||||
|
||||
# hack_goes_here
|
||||
|
||||
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
||||
allowed=None, disable_auth=None, disable_announce=False):
|
||||
allowed=None, disable_auth=None, announce_period=900):
|
||||
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||
log = _get_logger("_listen")
|
||||
_cmd = command
|
||||
@ -147,14 +149,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
if not _allow_all:
|
||||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=_listen_request,
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
allow=RNS.Destination.ALLOW_LIST,
|
||||
allowed_list=_allowed_identity_hashes
|
||||
)
|
||||
else:
|
||||
_destination.register_request_handler(
|
||||
path="data",
|
||||
response_generator=_listen_request,
|
||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||
allow=RNS.Destination.ALLOW_ALL,
|
||||
)
|
||||
|
||||
@ -162,14 +164,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
|
||||
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||
|
||||
if not disable_announce:
|
||||
if announce_period is not None:
|
||||
_destination.announce()
|
||||
|
||||
last = time.time()
|
||||
|
||||
try:
|
||||
while True:
|
||||
if not disable_announce and time.time() - last > 900: # TODO: make parameter
|
||||
if announce_period and 0 < announce_period < time.time() - last:
|
||||
last = time.time()
|
||||
_destination.announce()
|
||||
await _check_finished(1.0)
|
||||
@ -177,7 +179,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
log.warning("Shutting down")
|
||||
for link in list(_destination.links):
|
||||
with exception.permit(SystemExit):
|
||||
proc = ProcessState.get_for_tag(link.link_id)
|
||||
proc = Session.get_for_tag(link.link_id)
|
||||
if proc is not None and proc.process.running:
|
||||
proc.process.terminate()
|
||||
await asyncio.sleep(1)
|
||||
@ -187,17 +189,35 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
||||
link.teardown()
|
||||
|
||||
|
||||
class ProcessState:
|
||||
_processes: [(any, ProcessState)] = []
|
||||
_PROTOCOL_MAGIC = 0xdeadbeef
|
||||
|
||||
|
||||
def _protocol_make_version(version: int):
|
||||
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
|
||||
|
||||
|
||||
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
|
||||
|
||||
|
||||
def _protocol_split_version(version: int):
|
||||
return (version >> 32) & 0xffffffff, version & 0xffffffff
|
||||
|
||||
|
||||
def _protocol_check_magic(value: int):
|
||||
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
|
||||
|
||||
|
||||
class Session:
|
||||
_processes: [(any, Session)] = []
|
||||
_lock = threading.RLock()
|
||||
|
||||
@classmethod
|
||||
def get_for_tag(cls, tag: any) -> ProcessState | None:
|
||||
def get_for_tag(cls, tag: any) -> Session | None:
|
||||
with cls._lock:
|
||||
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
|
||||
|
||||
@classmethod
|
||||
def put_for_tag(cls, tag: any, ps: ProcessState):
|
||||
def put_for_tag(cls, tag: any, ps: Session):
|
||||
with cls._lock:
|
||||
cls.clear_tag(tag)
|
||||
cls._processes.append((tag, ps))
|
||||
@ -234,7 +254,7 @@ class ProcessState:
|
||||
self._pending_receipt: RNS.PacketReceipt | None = None
|
||||
self._process.start()
|
||||
self._term_state: [int] = None
|
||||
ProcessState.put_for_tag(tag, self)
|
||||
Session.put_for_tag(tag, self)
|
||||
|
||||
@property
|
||||
def mdu(self) -> int:
|
||||
@ -302,37 +322,40 @@ class ProcessState:
|
||||
except Exception as e:
|
||||
self._log.debug(f"failed to update winsz: {e}")
|
||||
|
||||
REQUEST_IDX_STDIN = 0
|
||||
REQUEST_IDX_TERM = 1
|
||||
REQUEST_IDX_TIOS = 2
|
||||
REQUEST_IDX_ROWS = 3
|
||||
REQUEST_IDX_COLS = 4
|
||||
REQUEST_IDX_HPIX = 5
|
||||
REQUEST_IDX_VPIX = 6
|
||||
REQUEST_IDX_VERS = 0
|
||||
REQUEST_IDX_STDIN = 1
|
||||
REQUEST_IDX_TERM = 2
|
||||
REQUEST_IDX_TIOS = 3
|
||||
REQUEST_IDX_ROWS = 4
|
||||
REQUEST_IDX_COLS = 5
|
||||
REQUEST_IDX_HPIX = 6
|
||||
REQUEST_IDX_VPIX = 7
|
||||
|
||||
@staticmethod
|
||||
def default_request(stdin_fd: int | None) -> [any]:
|
||||
global _PROTOCOL_VERSION_0
|
||||
request: list[any] = [
|
||||
None, # 0 Stdin
|
||||
None, # 1 TERM variable
|
||||
None, # 2 termios attributes or something
|
||||
None, # 3 terminal rows
|
||||
None, # 4 terminal cols
|
||||
None, # 5 terminal horizontal pixels
|
||||
None, # 6 terminal vertical pixels
|
||||
_PROTOCOL_VERSION_0, # 0 Protocol Version
|
||||
None, # 1 Stdin
|
||||
None, # 2 TERM variable
|
||||
None, # 3 termios attributes or something
|
||||
None, # 4 terminal rows
|
||||
None, # 5 terminal cols
|
||||
None, # 6 terminal horizontal pixels
|
||||
None, # 7 terminal vertical pixels
|
||||
].copy()
|
||||
|
||||
if stdin_fd is not None:
|
||||
request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||
request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
||||
request[ProcessState.REQUEST_IDX_ROWS], \
|
||||
request[ProcessState.REQUEST_IDX_COLS], \
|
||||
request[ProcessState.REQUEST_IDX_HPIX], \
|
||||
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
|
||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
||||
request[Session.REQUEST_IDX_ROWS], \
|
||||
request[Session.REQUEST_IDX_COLS], \
|
||||
request[Session.REQUEST_IDX_HPIX], \
|
||||
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
|
||||
return request
|
||||
|
||||
def process_request(self, data: [any], read_size: int) -> [any]:
|
||||
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
|
||||
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
|
||||
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
|
||||
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
|
||||
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
|
||||
@ -340,10 +363,10 @@ class ProcessState:
|
||||
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
|
||||
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
|
||||
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
|
||||
response = ProcessState.default_response()
|
||||
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1]
|
||||
response = Session.default_response()
|
||||
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
|
||||
|
||||
response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running
|
||||
response[Session.RESPONSE_IDX_RUNNING] = self.process.running
|
||||
if self.process.running:
|
||||
if term_state != self._term_state:
|
||||
self._term_state = term_state
|
||||
@ -351,39 +374,42 @@ class ProcessState:
|
||||
if stdin is not None and len(stdin) > 0:
|
||||
stdin = base64.b64decode(stdin)
|
||||
self.process.write(stdin)
|
||||
response[ProcessState.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
|
||||
|
||||
with self.lock:
|
||||
stdout = self.read(read_size)
|
||||
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
|
||||
|
||||
if stdout is not None and len(stdout) > 0:
|
||||
response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
|
||||
return response
|
||||
|
||||
RESPONSE_IDX_RUNNING = 0
|
||||
RESPONSE_IDX_RETCODE = 1
|
||||
RESPONSE_IDX_RDYBYTE = 2
|
||||
RESPONSE_IDX_STDOUT = 3
|
||||
RESPONSE_IDX_TMSTAMP = 4
|
||||
RESPONSE_IDX_VERSION = 0
|
||||
RESPONSE_IDX_RUNNING = 1
|
||||
RESPONSE_IDX_RETCODE = 2
|
||||
RESPONSE_IDX_RDYBYTE = 3
|
||||
RESPONSE_IDX_STDOUT = 4
|
||||
RESPONSE_IDX_TMSTAMP = 5
|
||||
|
||||
@staticmethod
|
||||
def default_response() -> [any]:
|
||||
global _PROTOCOL_VERSION_0
|
||||
response: list[any] = [
|
||||
False, # 0: Process running
|
||||
None, # 1: Return value
|
||||
0, # 2: Number of outstanding bytes
|
||||
None, # 3: Stdout/Stderr
|
||||
None, # 4: Timestamp
|
||||
_PROTOCOL_VERSION_0, # 0: Protocol version
|
||||
False, # 1: Process running
|
||||
None, # 2: Return value
|
||||
0, # 3: Number of outstanding bytes
|
||||
None, # 4: Stdout/Stderr
|
||||
None, # 5: Timestamp
|
||||
].copy()
|
||||
response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||
return response
|
||||
|
||||
|
||||
def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||
global _retry_timer
|
||||
log = _get_logger("_subproc_data_ready")
|
||||
process_state: ProcessState = ProcessState.get_for_tag(link.link_id)
|
||||
session: Session = Session.get_for_tag(link.link_id)
|
||||
|
||||
def send(timeout: bool, tag: any, tries: int) -> any:
|
||||
# log.debug("send")
|
||||
@ -392,10 +418,10 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||
try:
|
||||
if link.status != RNS.Link.ACTIVE:
|
||||
_retry_timer.complete(link.link_id)
|
||||
process_state.pending_receipt_take()
|
||||
session.pending_receipt_take()
|
||||
return
|
||||
|
||||
pr = process_state.pending_receipt_take()
|
||||
pr = session.pending_receipt_take()
|
||||
log.debug(f"send inner pr: {pr}")
|
||||
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
|
||||
if not timeout:
|
||||
@ -405,12 +431,12 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||
else:
|
||||
if not timeout:
|
||||
log.info(
|
||||
f"Notifying client try {tries} (retcode: {process_state.return_code} " +
|
||||
f"Notifying client try {tries} (retcode: {session.return_code} " +
|
||||
f"chars avail: {chars_available})")
|
||||
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
|
||||
packet.send()
|
||||
pr = packet.receipt
|
||||
process_state.pending_receipt_put(pr)
|
||||
session.pending_receipt_put(pr)
|
||||
else:
|
||||
log.error(f"Retry count exceeded, terminating link {link}")
|
||||
_retry_timer.complete(link.link_id)
|
||||
@ -421,7 +447,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
|
||||
_loop.call_soon_threadsafe(inner)
|
||||
return link.link_id
|
||||
|
||||
with process_state.lock:
|
||||
with session.lock:
|
||||
if not _retry_timer.has_tag(link.link_id):
|
||||
_retry_timer.begin(try_limit=15,
|
||||
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
|
||||
@ -436,7 +462,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||
global _loop
|
||||
log = _get_logger("_subproc_terminated")
|
||||
log.info(f"Subprocess returned {return_code} for link {link}")
|
||||
proc = ProcessState.get_for_tag(link.link_id)
|
||||
proc = Session.get_for_tag(link.link_id)
|
||||
if proc is None:
|
||||
log.debug(f"no proc for link {link}")
|
||||
return
|
||||
@ -449,7 +475,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||
try:
|
||||
link.teardown()
|
||||
finally:
|
||||
ProcessState.clear_tag(link.link_id)
|
||||
Session.clear_tag(link.link_id)
|
||||
|
||||
_loop.call_later(300, inner)
|
||||
_loop.call_soon(_subproc_data_ready, link, 0)
|
||||
@ -458,18 +484,18 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
|
||||
|
||||
|
||||
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
|
||||
loop: asyncio.AbstractEventLoop) -> ProcessState | None:
|
||||
loop: asyncio.AbstractEventLoop) -> Session | None:
|
||||
global _cmd
|
||||
log = _get_logger("_listen_start_proc")
|
||||
try:
|
||||
return ProcessState(tag=link.link_id,
|
||||
cmd=_cmd,
|
||||
term=term,
|
||||
remote_identity=remote_identity,
|
||||
mdu=link.MDU,
|
||||
loop=loop,
|
||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||
return Session(tag=link.link_id,
|
||||
cmd=_cmd,
|
||||
term=term,
|
||||
remote_identity=remote_identity,
|
||||
mdu=link.MDU,
|
||||
loop=loop,
|
||||
data_available_callback=functools.partial(_subproc_data_ready, link),
|
||||
terminated_callback=functools.partial(_subproc_terminated, link))
|
||||
except Exception as e:
|
||||
log.error("Failed to launch process: " + str(e))
|
||||
_subproc_terminated(link, 255)
|
||||
@ -488,7 +514,7 @@ def _listen_link_closed(link: RNS.Link):
|
||||
log = _get_logger("_listen_link_closed")
|
||||
# async def cleanup():
|
||||
log.info("Link " + str(link) + " closed")
|
||||
proc: ProcessState | None = ProcessState.get_for_tag(link.link_id)
|
||||
proc: Session | None = Session.get_for_tag(link.link_id)
|
||||
if proc is None:
|
||||
log.warning(f"No process for link {link}")
|
||||
else:
|
||||
@ -497,7 +523,7 @@ def _listen_link_closed(link: RNS.Link):
|
||||
_retry_timer.complete(link.link_id)
|
||||
except Exception as e:
|
||||
log.error(f"Error closing process for link {link}: {e}")
|
||||
ProcessState.clear_tag(link.link_id)
|
||||
Session.clear_tag(link.link_id)
|
||||
|
||||
|
||||
def _initiator_identified(link, identity):
|
||||
@ -513,34 +539,48 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
||||
global _destination, _retry_timer, _loop
|
||||
log = _get_logger("_listen_request")
|
||||
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
||||
if not hasattr(data, "__len__") or len(data) < 1:
|
||||
raise Exception("Request data invalid")
|
||||
_retry_timer.complete(link_id)
|
||||
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
|
||||
if link is None:
|
||||
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
|
||||
process_state: ProcessState | None = None
|
||||
|
||||
remote_version = data[Session.REQUEST_IDX_VERS]
|
||||
if not _protocol_check_magic(remote_version):
|
||||
raise Exception("Request magic incorrect")
|
||||
|
||||
if not remote_version == _PROTOCOL_VERSION_0:
|
||||
response = Session.default_response()
|
||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||
return response
|
||||
|
||||
session: Session | None = None
|
||||
try:
|
||||
term = data[ProcessState.REQUEST_IDX_TERM]
|
||||
process_state = ProcessState.get_for_tag(link.link_id)
|
||||
if process_state is None:
|
||||
term = data[Session.REQUEST_IDX_TERM]
|
||||
session = Session.get_for_tag(link.link_id)
|
||||
if session is None:
|
||||
log.debug(f"Process not found for link {link}")
|
||||
process_state = _listen_start_proc(link=link,
|
||||
session = _listen_start_proc(link=link,
|
||||
term=term,
|
||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||
loop=_loop)
|
||||
|
||||
# leave significant headroom for metadata and encoding
|
||||
result = process_state.process_request(data, link.MDU * 4 // 3)
|
||||
result = session.process_request(data, link.MDU * 4 // 3)
|
||||
return result
|
||||
# return ProcessState.default_response()
|
||||
except Exception as e:
|
||||
log.error(f"Error procesing request for link {link}: {e}")
|
||||
try:
|
||||
if process_state is not None and process_state.process.running:
|
||||
process_state.process.terminate()
|
||||
if session is not None and session.process.running:
|
||||
session.process.terminate()
|
||||
except Exception as ee:
|
||||
log.debug(f"Error terminating process for link {link}: {ee}")
|
||||
|
||||
return ProcessState.default_response()
|
||||
return Session.default_response()
|
||||
|
||||
|
||||
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
||||
@ -558,7 +598,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
||||
_link: RNS.Link | None = None
|
||||
_remote_exec_grace = 2.0
|
||||
_new_data: asyncio.Event | None = None
|
||||
_tr = process.TTYRestorer(sys.stdin.fileno())
|
||||
_tr: process.TTYRestorer | None = None
|
||||
|
||||
|
||||
def _client_packet_handler(message, packet):
|
||||
@ -632,8 +672,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
|
||||
_link.set_packet_callback(_client_packet_handler)
|
||||
|
||||
request = ProcessState.default_request(sys.stdin.fileno())
|
||||
request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
|
||||
request = Session.default_request(sys.stdin.fileno())
|
||||
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
|
||||
|
||||
# TODO: Tune
|
||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||
@ -671,18 +711,27 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
||||
|
||||
if request_receipt.response is not None:
|
||||
try:
|
||||
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True
|
||||
return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE]
|
||||
ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0
|
||||
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT]
|
||||
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP]
|
||||
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
|
||||
if not _protocol_check_magic(version):
|
||||
raise RemoteExecutionError("Protocol error")
|
||||
elif version != _PROTOCOL_VERSION_0:
|
||||
raise RemoteExecutionError("Protocol version mismatch")
|
||||
|
||||
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
|
||||
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
|
||||
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
|
||||
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
|
||||
if stdout is not None:
|
||||
stdout = base64.b64decode(stdout)
|
||||
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
|
||||
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
|
||||
except RemoteExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RemoteExecutionError(f"Received invalid response") from e
|
||||
|
||||
_tr.raw()
|
||||
if stdout is not None:
|
||||
stdout = base64.b64decode(stdout)
|
||||
# log.debug(f"stdout: {stdout}")
|
||||
os.write(sys.stdout.fileno(), stdout)
|
||||
|
||||
@ -784,7 +833,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
|
||||
return arr, []
|
||||
|
||||
|
||||
async def main():
|
||||
async def _rnsh_cli_main():
|
||||
global _tr, _finished, _loop
|
||||
import docopt
|
||||
log = _get_logger("main")
|
||||
@ -797,8 +846,8 @@ async def main():
|
||||
Usage:
|
||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
|
||||
[--] <program> [<arg>...]
|
||||
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||
[--] <program> [<arg> ...]
|
||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||
rnsh -h
|
||||
@ -810,7 +859,8 @@ Options:
|
||||
-s NAME --service NAME Listen on/connect to specific service name if not default
|
||||
-p --print-identity Print identity information and exit
|
||||
-l --listen Listen (server) mode
|
||||
-b --no-announce Do not announce service
|
||||
-b --announce PERIOD Announce on startup and every PERIOD seconds
|
||||
Specify 0 for PERIOD to announce on startup only.
|
||||
-a HASH --allowed HASH Specify identities allowed to connect
|
||||
-n --no-auth Disable authentication
|
||||
-N --no-id Disable identify on connect
|
||||
@ -837,7 +887,13 @@ Options:
|
||||
args_print_identity = args.get("--print-identity", None) or False
|
||||
args_verbose = args.get("--verbose", None) or 0
|
||||
args_quiet = args.get("--quiet", None) or 0
|
||||
args_no_announce = args.get("--no-announce", None) or False
|
||||
args_announce = args.get("--announce", None)
|
||||
try:
|
||||
if args_announce:
|
||||
args_announce = int(args_announce)
|
||||
except ValueError:
|
||||
print("Invalid value for --announce")
|
||||
return 1
|
||||
args_no_auth = args.get("--no-auth", None) or False
|
||||
args_allowed = args.get("--allowed", None) or []
|
||||
args_program = args.get("<program>", None)
|
||||
@ -868,7 +924,7 @@ Options:
|
||||
quietness=args_quiet,
|
||||
allowed=args_allowed,
|
||||
disable_auth=args_no_auth,
|
||||
disable_announce=args_no_announce,
|
||||
announce_period=args_announce,
|
||||
)
|
||||
|
||||
if args_destination is not None and args_service_name is not None:
|
||||
@ -893,14 +949,14 @@ Options:
|
||||
|
||||
|
||||
def rnsh_cli():
|
||||
try:
|
||||
return_code = asyncio.run(main())
|
||||
finally:
|
||||
with exception.permit(SystemExit):
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
global _tr, _retry_timer
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
|
||||
with retry.RetryThread() as _retry_timer:
|
||||
return_code = asyncio.run(_rnsh_cli_main())
|
||||
|
||||
with exception.permit(SystemExit):
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
_tr.restore()
|
||||
_retry_timer.close()
|
||||
sys.exit(return_code or 255)
|
||||
|
||||
|
||||
|
44
tests/test_hacks.py
Normal file
44
tests/test_hacks.py
Normal file
@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
|
||||
import rnsh.hacks as hacks
|
||||
import pytest
|
||||
import time
|
||||
import logging
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class FakeLink:
|
||||
def __init__(self, link_id):
|
||||
self.link_id = link_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pruning():
|
||||
def listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||
assert path == 1
|
||||
assert data == 2
|
||||
assert request_id == 3
|
||||
assert remote_identity == 4
|
||||
assert requested_at == 5
|
||||
assert link_id == 6
|
||||
return 7
|
||||
|
||||
lhr_called = 0
|
||||
link = FakeLink(6)
|
||||
|
||||
def link_handle_request(self, request_id, unpacked_request):
|
||||
nonlocal lhr_called
|
||||
lhr_called += 1
|
||||
|
||||
old_func = hacks.request_request_id_hack(listen_request, asyncio.get_running_loop())
|
||||
hacks._link_handle_request_orig = link_handle_request
|
||||
hacks._link_handle_request(link, 3, None)
|
||||
link_id = hacks._request_to_link.get(3)
|
||||
assert link_id == link.link_id
|
||||
result = old_func(1, 2, 3, 4, 5)
|
||||
assert result == 7
|
||||
link_id = hacks._request_to_link.get(3)
|
||||
assert link_id == 6
|
||||
await asyncio.sleep(1.5)
|
||||
link_id = hacks._request_to_link.get(3)
|
||||
assert link_id is None
|
15
tests/test_rnsh.py
Normal file
15
tests/test_rnsh.py
Normal file
@ -0,0 +1,15 @@
|
||||
import logging
|
||||
import rnsh.rnsh
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def test_check_magic():
|
||||
magic = rnsh.rnsh._PROTOCOL_VERSION_0
|
||||
# magic for version 0 is generated, make sure it comes out as expected
|
||||
assert magic == 0xdeadbeef00000000
|
||||
# verify the checker thinks it's right
|
||||
assert rnsh.rnsh._protocol_check_magic(magic)
|
||||
# scramble the magic
|
||||
magic = magic | 0x00ffff0000000000
|
||||
# make sure it fails now
|
||||
assert not rnsh.rnsh._protocol_check_magic(magic)
|
Loading…
Reference in New Issue
Block a user