Protocol versioning, announce period option, old API workaround

This commit is contained in:
Aaron Heise 2023-02-11 21:57:13 -06:00
parent 9fe37b57ed
commit da3b390058
6 changed files with 293 additions and 103 deletions

View File

@ -8,9 +8,9 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
rns = {git = "https://github.com/markqvist/Reticulum.git", rev = "3706769"}
docopt = "^0.6.2" docopt = "^0.6.2"
psutil = "^5.9.4" psutil = "^5.9.4"
rns = "^0.4.8"
[tool.poetry.scripts] [tool.poetry.scripts]
rnsh = 'rnsh.rnsh:rnsh_cli' rnsh = 'rnsh.rnsh:rnsh_cli'

67
rnsh/hacks.py Normal file
View File

@ -0,0 +1,67 @@
import asyncio
import threading
import RNS
import logging
module_logger = logging.getLogger(__name__)
class SelfPruningAssociation:
def __init__(self, ttl: float, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._ttl = ttl
self._associations = list()
self._lock = threading.RLock()
self._loop = loop
# self._log.debug("__init__")
def _schedule_prune(self, pair):
# self._log.debug(f"schedule prune {pair}")
def schedule_prune_inner(p):
# self._log.debug(f"schedule inner {p}")
self._loop.call_later(self._ttl, self._prune, p)
self._loop.call_soon_threadsafe(schedule_prune_inner, pair)
def _prune(self, pair: any):
# self._log.debug(f"prune {pair}")
with self._lock:
self._associations.remove(pair)
def get(self, key: any) -> any:
# self._log.debug(f"get {key}")
with self._lock:
pair = next(filter(lambda x: x[0] == key, self._associations), None)
if not pair:
return None
return pair[1]
def put(self, key: any, value: any):
# self._log.debug(f"put {key},{value}")
with self._lock:
pair = (key, value)
self._associations.append(pair)
self._schedule_prune(pair)
_request_to_link: SelfPruningAssociation = None
_link_handle_request_orig = RNS.Link.handle_request
def _link_handle_request(self, request_id, unpacked_request):
global _request_to_link
_request_to_link.put(request_id, self.link_id)
_link_handle_request_orig(self, request_id, unpacked_request)
def request_request_id_hack(new_api_reponse_generator, loop):
global _request_to_link
if _request_to_link is None:
RNS.Link.handle_request = _link_handle_request
_request_to_link = SelfPruningAssociation(1.0, loop)
def listen_request(path, data, request_id, remote_identity, requested_at):
link_id = _request_to_link.get(request_id)
return new_api_reponse_generator(path, data, request_id, link_id, remote_identity, requested_at)
return listen_request

View File

@ -27,6 +27,9 @@ import time
import rnsh.exception as exception import rnsh.exception as exception
import logging as __logging import logging as __logging
from typing import Callable from typing import Callable
from contextlib import AbstractContextManager
import types
import typing
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -67,7 +70,7 @@ class RetryStatus:
self.retry_callback(self.tag, self.tries) self.retry_callback(self.tag, self.tries)
class RetryThread: class RetryThread(AbstractContextManager):
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"): def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
self._log = module_logger.getChild(self.__class__.__name__) self._log = module_logger.getChild(self.__class__.__name__)
self._loop_period = loop_period self._loop_period = loop_period
@ -176,3 +179,8 @@ class RetryThread:
status.completed = True status.completed = True
self._log.debug(f"completed {status.tag}") self._log.debug(f"completed {status.tag}")
self._statuses.clear() self._statuses.clear()
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.close()
return False

View File

@ -35,12 +35,12 @@ import termios
import threading import threading
import time import time
from typing import Callable, TypeVar from typing import Callable, TypeVar
import RNS import RNS
import rnsh.exception as exception import rnsh.exception as exception
import rnsh.process as process import rnsh.process as process
import rnsh.retry as retry import rnsh.retry as retry
import rnsh.rnslogging as rnslogging import rnsh.rnslogging as rnslogging
import rnsh.hacks as hacks
from rnsh.__version import __version__ from rnsh.__version import __version__
module_logger = __logging.getLogger(__name__) module_logger = __logging.getLogger(__name__)
@ -59,7 +59,7 @@ _allowed_identity_hashes = []
_cmd: [str] = None _cmd: [str] = None
DATA_AVAIL_MSG = "data available" DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event | None = None _finished: asyncio.Event | None = None
_retry_timer = retry.RetryThread() _retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None _destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None _loop: asyncio.AbstractEventLoop | None = None
@ -107,8 +107,10 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
exit(0) exit(0)
# hack_goes_here
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0, async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
allowed=None, disable_auth=None, disable_announce=False): allowed=None, disable_auth=None, announce_period=900):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
log = _get_logger("_listen") log = _get_logger("_listen")
_cmd = command _cmd = command
@ -147,14 +149,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
if not _allow_all: if not _allow_all:
_destination.register_request_handler( _destination.register_request_handler(
path="data", path="data",
response_generator=_listen_request, response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
allow=RNS.Destination.ALLOW_LIST, allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes allowed_list=_allowed_identity_hashes
) )
else: else:
_destination.register_request_handler( _destination.register_request_handler(
path="data", path="data",
response_generator=_listen_request, response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
allow=RNS.Destination.ALLOW_ALL, allow=RNS.Destination.ALLOW_ALL,
) )
@ -162,14 +164,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash)) log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
if not disable_announce: if announce_period is not None:
_destination.announce() _destination.announce()
last = time.time() last = time.time()
try: try:
while True: while True:
if not disable_announce and time.time() - last > 900: # TODO: make parameter if announce_period and 0 < announce_period < time.time() - last:
last = time.time() last = time.time()
_destination.announce() _destination.announce()
await _check_finished(1.0) await _check_finished(1.0)
@ -177,7 +179,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.warning("Shutting down") log.warning("Shutting down")
for link in list(_destination.links): for link in list(_destination.links):
with exception.permit(SystemExit): with exception.permit(SystemExit):
proc = ProcessState.get_for_tag(link.link_id) proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running: if proc is not None and proc.process.running:
proc.process.terminate() proc.process.terminate()
await asyncio.sleep(1) await asyncio.sleep(1)
@ -187,17 +189,35 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
link.teardown() link.teardown()
class ProcessState: _PROTOCOL_MAGIC = 0xdeadbeef
_processes: [(any, ProcessState)] = []
def _protocol_make_version(version: int):
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
def _protocol_check_magic(value: int):
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock() _lock = threading.RLock()
@classmethod @classmethod
def get_for_tag(cls, tag: any) -> ProcessState | None: def get_for_tag(cls, tag: any) -> Session | None:
with cls._lock: with cls._lock:
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None) return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
@classmethod @classmethod
def put_for_tag(cls, tag: any, ps: ProcessState): def put_for_tag(cls, tag: any, ps: Session):
with cls._lock: with cls._lock:
cls.clear_tag(tag) cls.clear_tag(tag)
cls._processes.append((tag, ps)) cls._processes.append((tag, ps))
@ -234,7 +254,7 @@ class ProcessState:
self._pending_receipt: RNS.PacketReceipt | None = None self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start() self._process.start()
self._term_state: [int] = None self._term_state: [int] = None
ProcessState.put_for_tag(tag, self) Session.put_for_tag(tag, self)
@property @property
def mdu(self) -> int: def mdu(self) -> int:
@ -302,37 +322,40 @@ class ProcessState:
except Exception as e: except Exception as e:
self._log.debug(f"failed to update winsz: {e}") self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_STDIN = 0 REQUEST_IDX_VERS = 0
REQUEST_IDX_TERM = 1 REQUEST_IDX_STDIN = 1
REQUEST_IDX_TIOS = 2 REQUEST_IDX_TERM = 2
REQUEST_IDX_ROWS = 3 REQUEST_IDX_TIOS = 3
REQUEST_IDX_COLS = 4 REQUEST_IDX_ROWS = 4
REQUEST_IDX_HPIX = 5 REQUEST_IDX_COLS = 5
REQUEST_IDX_VPIX = 6 REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7
@staticmethod @staticmethod
def default_request(stdin_fd: int | None) -> [any]: def default_request(stdin_fd: int | None) -> [any]:
global _PROTOCOL_VERSION_0
request: list[any] = [ request: list[any] = [
None, # 0 Stdin _PROTOCOL_VERSION_0, # 0 Protocol Version
None, # 1 TERM variable None, # 1 Stdin
None, # 2 termios attributes or something None, # 2 TERM variable
None, # 3 terminal rows None, # 3 termios attributes or something
None, # 4 terminal cols None, # 4 terminal rows
None, # 5 terminal horizontal pixels None, # 5 terminal cols
None, # 6 terminal vertical pixels None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
].copy() ].copy()
if stdin_fd is not None: if stdin_fd is not None:
request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None) request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd) request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[ProcessState.REQUEST_IDX_ROWS], \ request[Session.REQUEST_IDX_ROWS], \
request[ProcessState.REQUEST_IDX_COLS], \ request[Session.REQUEST_IDX_COLS], \
request[ProcessState.REQUEST_IDX_HPIX], \ request[Session.REQUEST_IDX_HPIX], \
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd) request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
return request return request
def process_request(self, data: [any], read_size: int) -> [any]: def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable # term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr # tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows # rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
@ -340,10 +363,10 @@ class ProcessState:
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels # hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels # vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1] # term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
response = ProcessState.default_response() response = Session.default_response()
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1] term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running response[Session.RESPONSE_IDX_RUNNING] = self.process.running
if self.process.running: if self.process.running:
if term_state != self._term_state: if term_state != self._term_state:
self._term_state = term_state self._term_state = term_state
@ -351,39 +374,42 @@ class ProcessState:
if stdin is not None and len(stdin) > 0: if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin) stdin = base64.b64decode(stdin)
self.process.write(stdin) self.process.write(stdin)
response[ProcessState.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock: with self.lock:
stdout = self.read(read_size) stdout = self.read(read_size)
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer) response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
if stdout is not None and len(stdout) > 0: if stdout is not None and len(stdout) > 0:
response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8") response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
return response return response
RESPONSE_IDX_RUNNING = 0 RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_RETCODE = 1 RESPONSE_IDX_RUNNING = 1
RESPONSE_IDX_RDYBYTE = 2 RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_STDOUT = 3 RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_TMSTAMP = 4 RESPONSE_IDX_STDOUT = 4
RESPONSE_IDX_TMSTAMP = 5
@staticmethod @staticmethod
def default_response() -> [any]: def default_response() -> [any]:
global _PROTOCOL_VERSION_0
response: list[any] = [ response: list[any] = [
False, # 0: Process running _PROTOCOL_VERSION_0, # 0: Protocol version
None, # 1: Return value False, # 1: Process running
0, # 2: Number of outstanding bytes None, # 2: Return value
None, # 3: Stdout/Stderr 0, # 3: Number of outstanding bytes
None, # 4: Timestamp None, # 4: Stdout/Stderr
None, # 5: Timestamp
].copy() ].copy()
response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time() response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response return response
def _subproc_data_ready(link: RNS.Link, chars_available: int): def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer global _retry_timer
log = _get_logger("_subproc_data_ready") log = _get_logger("_subproc_data_ready")
process_state: ProcessState = ProcessState.get_for_tag(link.link_id) session: Session = Session.get_for_tag(link.link_id)
def send(timeout: bool, tag: any, tries: int) -> any: def send(timeout: bool, tag: any, tries: int) -> any:
# log.debug("send") # log.debug("send")
@ -392,10 +418,10 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
try: try:
if link.status != RNS.Link.ACTIVE: if link.status != RNS.Link.ACTIVE:
_retry_timer.complete(link.link_id) _retry_timer.complete(link.link_id)
process_state.pending_receipt_take() session.pending_receipt_take()
return return
pr = process_state.pending_receipt_take() pr = session.pending_receipt_take()
log.debug(f"send inner pr: {pr}") log.debug(f"send inner pr: {pr}")
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED: if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
if not timeout: if not timeout:
@ -405,12 +431,12 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else: else:
if not timeout: if not timeout:
log.info( log.info(
f"Notifying client try {tries} (retcode: {process_state.return_code} " + f"Notifying client try {tries} (retcode: {session.return_code} " +
f"chars avail: {chars_available})") f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8")) packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send() packet.send()
pr = packet.receipt pr = packet.receipt
process_state.pending_receipt_put(pr) session.pending_receipt_put(pr)
else: else:
log.error(f"Retry count exceeded, terminating link {link}") log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id) _retry_timer.complete(link.link_id)
@ -421,7 +447,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
_loop.call_soon_threadsafe(inner) _loop.call_soon_threadsafe(inner)
return link.link_id return link.link_id
with process_state.lock: with session.lock:
if not _retry_timer.has_tag(link.link_id): if not _retry_timer.has_tag(link.link_id):
_retry_timer.begin(try_limit=15, _retry_timer.begin(try_limit=15,
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1), wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
@ -436,7 +462,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop global _loop
log = _get_logger("_subproc_terminated") log = _get_logger("_subproc_terminated")
log.info(f"Subprocess returned {return_code} for link {link}") log.info(f"Subprocess returned {return_code} for link {link}")
proc = ProcessState.get_for_tag(link.link_id) proc = Session.get_for_tag(link.link_id)
if proc is None: if proc is None:
log.debug(f"no proc for link {link}") log.debug(f"no proc for link {link}")
return return
@ -449,7 +475,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
try: try:
link.teardown() link.teardown()
finally: finally:
ProcessState.clear_tag(link.link_id) Session.clear_tag(link.link_id)
_loop.call_later(300, inner) _loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0) _loop.call_soon(_subproc_data_ready, link, 0)
@ -458,11 +484,11 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str, def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
loop: asyncio.AbstractEventLoop) -> ProcessState | None: loop: asyncio.AbstractEventLoop) -> Session | None:
global _cmd global _cmd
log = _get_logger("_listen_start_proc") log = _get_logger("_listen_start_proc")
try: try:
return ProcessState(tag=link.link_id, return Session(tag=link.link_id,
cmd=_cmd, cmd=_cmd,
term=term, term=term,
remote_identity=remote_identity, remote_identity=remote_identity,
@ -488,7 +514,7 @@ def _listen_link_closed(link: RNS.Link):
log = _get_logger("_listen_link_closed") log = _get_logger("_listen_link_closed")
# async def cleanup(): # async def cleanup():
log.info("Link " + str(link) + " closed") log.info("Link " + str(link) + " closed")
proc: ProcessState | None = ProcessState.get_for_tag(link.link_id) proc: Session | None = Session.get_for_tag(link.link_id)
if proc is None: if proc is None:
log.warning(f"No process for link {link}") log.warning(f"No process for link {link}")
else: else:
@ -497,7 +523,7 @@ def _listen_link_closed(link: RNS.Link):
_retry_timer.complete(link.link_id) _retry_timer.complete(link.link_id)
except Exception as e: except Exception as e:
log.error(f"Error closing process for link {link}: {e}") log.error(f"Error closing process for link {link}: {e}")
ProcessState.clear_tag(link.link_id) Session.clear_tag(link.link_id)
def _initiator_identified(link, identity): def _initiator_identified(link, identity):
@ -513,34 +539,48 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
global _destination, _retry_timer, _loop global _destination, _retry_timer, _loop
log = _get_logger("_listen_request") log = _get_logger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}") log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id) _retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None) link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None: if link is None:
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}") raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
process_state: ProcessState | None = None
remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version):
raise Exception("Request magic incorrect")
if not remote_version == _PROTOCOL_VERSION_0:
response = Session.default_response()
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
session: Session | None = None
try: try:
term = data[ProcessState.REQUEST_IDX_TERM] term = data[Session.REQUEST_IDX_TERM]
process_state = ProcessState.get_for_tag(link.link_id) session = Session.get_for_tag(link.link_id)
if process_state is None: if session is None:
log.debug(f"Process not found for link {link}") log.debug(f"Process not found for link {link}")
process_state = _listen_start_proc(link=link, session = _listen_start_proc(link=link,
term=term, term=term,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""), remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop) loop=_loop)
# leave significant headroom for metadata and encoding # leave significant headroom for metadata and encoding
result = process_state.process_request(data, link.MDU * 4 // 3) result = session.process_request(data, link.MDU * 4 // 3)
return result return result
# return ProcessState.default_response() # return ProcessState.default_response()
except Exception as e: except Exception as e:
log.error(f"Error procesing request for link {link}: {e}") log.error(f"Error procesing request for link {link}: {e}")
try: try:
if process_state is not None and process_state.process.running: if session is not None and session.process.running:
process_state.process.terminate() session.process.terminate()
except Exception as ee: except Exception as ee:
log.debug(f"Error terminating process for link {link}: {ee}") log.debug(f"Error terminating process for link {link}: {ee}")
return ProcessState.default_response() return Session.default_response()
async def _spin(until: callable = None, timeout: float | None = None) -> bool: async def _spin(until: callable = None, timeout: float | None = None) -> bool:
@ -558,7 +598,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
_link: RNS.Link | None = None _link: RNS.Link | None = None
_remote_exec_grace = 2.0 _remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None _new_data: asyncio.Event | None = None
_tr = process.TTYRestorer(sys.stdin.fileno()) _tr: process.TTYRestorer | None = None
def _client_packet_handler(message, packet): def _client_packet_handler(message, packet):
@ -632,8 +672,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler) _link.set_packet_callback(_client_packet_handler)
request = ProcessState.default_request(sys.stdin.fileno()) request = Session.default_request(sys.stdin.fileno())
request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None) request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
# TODO: Tune # TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace timeout = timeout + _link.rtt * 4 + _remote_exec_grace
@ -671,18 +711,27 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if request_receipt.response is not None: if request_receipt.response is not None:
try: try:
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE] if not _protocol_check_magic(version):
ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0 raise RemoteExecutionError("Protocol error")
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT] elif version != _PROTOCOL_VERSION_0:
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP] raise RemoteExecutionError("Protocol version mismatch")
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
if stdout is not None:
stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else "")) # log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except RemoteExecutionError:
raise
except Exception as e: except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e raise RemoteExecutionError(f"Received invalid response") from e
_tr.raw() _tr.raw()
if stdout is not None: if stdout is not None:
stdout = base64.b64decode(stdout)
# log.debug(f"stdout: {stdout}") # log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout) os.write(sys.stdout.fileno(), stdout)
@ -784,7 +833,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
return arr, [] return arr, []
async def main(): async def _rnsh_cli_main():
global _tr, _finished, _loop global _tr, _finished, _loop
import docopt import docopt
log = _get_logger("main") log = _get_logger("main")
@ -797,8 +846,8 @@ async def main():
Usage: Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>] rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...) [-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg>...] [--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>] rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash> [-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h rnsh -h
@ -810,7 +859,8 @@ Options:
-s NAME --service NAME Listen on/connect to specific service name if not default -s NAME --service NAME Listen on/connect to specific service name if not default
-p --print-identity Print identity information and exit -p --print-identity Print identity information and exit
-l --listen Listen (server) mode -l --listen Listen (server) mode
-b --no-announce Do not announce service -b --announce PERIOD Announce on startup and every PERIOD seconds
Specify 0 for PERIOD to announce on startup only.
-a HASH --allowed HASH Specify identities allowed to connect -a HASH --allowed HASH Specify identities allowed to connect
-n --no-auth Disable authentication -n --no-auth Disable authentication
-N --no-id Disable identify on connect -N --no-id Disable identify on connect
@ -837,7 +887,13 @@ Options:
args_print_identity = args.get("--print-identity", None) or False args_print_identity = args.get("--print-identity", None) or False
args_verbose = args.get("--verbose", None) or 0 args_verbose = args.get("--verbose", None) or 0
args_quiet = args.get("--quiet", None) or 0 args_quiet = args.get("--quiet", None) or 0
args_no_announce = args.get("--no-announce", None) or False args_announce = args.get("--announce", None)
try:
if args_announce:
args_announce = int(args_announce)
except ValueError:
print("Invalid value for --announce")
return 1
args_no_auth = args.get("--no-auth", None) or False args_no_auth = args.get("--no-auth", None) or False
args_allowed = args.get("--allowed", None) or [] args_allowed = args.get("--allowed", None) or []
args_program = args.get("<program>", None) args_program = args.get("<program>", None)
@ -868,7 +924,7 @@ Options:
quietness=args_quiet, quietness=args_quiet,
allowed=args_allowed, allowed=args_allowed,
disable_auth=args_no_auth, disable_auth=args_no_auth,
disable_announce=args_no_announce, announce_period=args_announce,
) )
if args_destination is not None and args_service_name is not None: if args_destination is not None and args_service_name is not None:
@ -893,14 +949,14 @@ Options:
def rnsh_cli(): def rnsh_cli():
try: global _tr, _retry_timer
return_code = asyncio.run(main()) with process.TTYRestorer(sys.stdin.fileno()) as _tr:
finally: with retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())
with exception.permit(SystemExit): with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno()) process.tty_unset_reader_callbacks(sys.stdin.fileno())
_tr.restore()
_retry_timer.close()
sys.exit(return_code or 255) sys.exit(return_code or 255)

44
tests/test_hacks.py Normal file
View File

@ -0,0 +1,44 @@
import asyncio
import rnsh.hacks as hacks
import pytest
import time
import logging
logging.getLogger().setLevel(logging.DEBUG)
class FakeLink:
def __init__(self, link_id):
self.link_id = link_id
@pytest.mark.asyncio
async def test_pruning():
def listen_request(path, data, request_id, link_id, remote_identity, requested_at):
assert path == 1
assert data == 2
assert request_id == 3
assert remote_identity == 4
assert requested_at == 5
assert link_id == 6
return 7
lhr_called = 0
link = FakeLink(6)
def link_handle_request(self, request_id, unpacked_request):
nonlocal lhr_called
lhr_called += 1
old_func = hacks.request_request_id_hack(listen_request, asyncio.get_running_loop())
hacks._link_handle_request_orig = link_handle_request
hacks._link_handle_request(link, 3, None)
link_id = hacks._request_to_link.get(3)
assert link_id == link.link_id
result = old_func(1, 2, 3, 4, 5)
assert result == 7
link_id = hacks._request_to_link.get(3)
assert link_id == 6
await asyncio.sleep(1.5)
link_id = hacks._request_to_link.get(3)
assert link_id is None

15
tests/test_rnsh.py Normal file
View File

@ -0,0 +1,15 @@
import logging
import rnsh.rnsh
logging.getLogger().setLevel(logging.DEBUG)
def test_check_magic():
magic = rnsh.rnsh._PROTOCOL_VERSION_0
# magic for version 0 is generated, make sure it comes out as expected
assert magic == 0xdeadbeef00000000
# verify the checker thinks it's right
assert rnsh.rnsh._protocol_check_magic(magic)
# scramble the magic
magic = magic | 0x00ffff0000000000
# make sure it fails now
assert not rnsh.rnsh._protocol_check_magic(magic)