Protocol versioning, announce period option, old API workaround

This commit is contained in:
Aaron Heise 2023-02-11 21:57:13 -06:00
parent 9fe37b57ed
commit da3b390058
6 changed files with 293 additions and 103 deletions

View File

@ -8,9 +8,9 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.9"
rns = {git = "https://github.com/markqvist/Reticulum.git", rev = "3706769"}
docopt = "^0.6.2"
psutil = "^5.9.4"
rns = "^0.4.8"
[tool.poetry.scripts]
rnsh = 'rnsh.rnsh:rnsh_cli'

67
rnsh/hacks.py Normal file
View File

@ -0,0 +1,67 @@
import asyncio
import threading
import RNS
import logging
module_logger = logging.getLogger(__name__)
class SelfPruningAssociation:
def __init__(self, ttl: float, loop: asyncio.AbstractEventLoop):
self._log = module_logger.getChild(self.__class__.__name__)
self._ttl = ttl
self._associations = list()
self._lock = threading.RLock()
self._loop = loop
# self._log.debug("__init__")
def _schedule_prune(self, pair):
# self._log.debug(f"schedule prune {pair}")
def schedule_prune_inner(p):
# self._log.debug(f"schedule inner {p}")
self._loop.call_later(self._ttl, self._prune, p)
self._loop.call_soon_threadsafe(schedule_prune_inner, pair)
def _prune(self, pair: any):
# self._log.debug(f"prune {pair}")
with self._lock:
self._associations.remove(pair)
def get(self, key: any) -> any:
# self._log.debug(f"get {key}")
with self._lock:
pair = next(filter(lambda x: x[0] == key, self._associations), None)
if not pair:
return None
return pair[1]
def put(self, key: any, value: any):
# self._log.debug(f"put {key},{value}")
with self._lock:
pair = (key, value)
self._associations.append(pair)
self._schedule_prune(pair)
_request_to_link: SelfPruningAssociation = None
_link_handle_request_orig = RNS.Link.handle_request
def _link_handle_request(self, request_id, unpacked_request):
global _request_to_link
_request_to_link.put(request_id, self.link_id)
_link_handle_request_orig(self, request_id, unpacked_request)
def request_request_id_hack(new_api_reponse_generator, loop):
global _request_to_link
if _request_to_link is None:
RNS.Link.handle_request = _link_handle_request
_request_to_link = SelfPruningAssociation(1.0, loop)
def listen_request(path, data, request_id, remote_identity, requested_at):
link_id = _request_to_link.get(request_id)
return new_api_reponse_generator(path, data, request_id, link_id, remote_identity, requested_at)
return listen_request

View File

@ -27,6 +27,9 @@ import time
import rnsh.exception as exception
import logging as __logging
from typing import Callable
from contextlib import AbstractContextManager
import types
import typing
module_logger = __logging.getLogger(__name__)
@ -67,7 +70,7 @@ class RetryStatus:
self.retry_callback(self.tag, self.tries)
class RetryThread:
class RetryThread(AbstractContextManager):
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
self._log = module_logger.getChild(self.__class__.__name__)
self._loop_period = loop_period
@ -176,3 +179,8 @@ class RetryThread:
status.completed = True
self._log.debug(f"completed {status.tag}")
self._statuses.clear()
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.close()
return False

View File

@ -35,12 +35,12 @@ import termios
import threading
import time
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.rnslogging as rnslogging
import rnsh.hacks as hacks
from rnsh.__version import __version__
module_logger = __logging.getLogger(__name__)
@ -59,7 +59,7 @@ _allowed_identity_hashes = []
_cmd: [str] = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event | None = None
_retry_timer = retry.RetryThread()
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
@ -107,8 +107,10 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
exit(0)
# hack_goes_here
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
allowed=None, disable_auth=None, disable_announce=False):
allowed=None, disable_auth=None, announce_period=900):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
log = _get_logger("_listen")
_cmd = command
@ -147,14 +149,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
if not _allow_all:
_destination.register_request_handler(
path="data",
response_generator=_listen_request,
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes
)
else:
_destination.register_request_handler(
path="data",
response_generator=_listen_request,
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
allow=RNS.Destination.ALLOW_ALL,
)
@ -162,14 +164,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
if not disable_announce:
if announce_period is not None:
_destination.announce()
last = time.time()
try:
while True:
if not disable_announce and time.time() - last > 900: # TODO: make parameter
if announce_period and 0 < announce_period < time.time() - last:
last = time.time()
_destination.announce()
await _check_finished(1.0)
@ -177,7 +179,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.warning("Shutting down")
for link in list(_destination.links):
with exception.permit(SystemExit):
proc = ProcessState.get_for_tag(link.link_id)
proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running:
proc.process.terminate()
await asyncio.sleep(1)
@ -187,17 +189,35 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
link.teardown()
class ProcessState:
_processes: [(any, ProcessState)] = []
_PROTOCOL_MAGIC = 0xdeadbeef
def _protocol_make_version(version: int):
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
def _protocol_check_magic(value: int):
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock()
@classmethod
def get_for_tag(cls, tag: any) -> ProcessState | None:
def get_for_tag(cls, tag: any) -> Session | None:
with cls._lock:
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
@classmethod
def put_for_tag(cls, tag: any, ps: ProcessState):
def put_for_tag(cls, tag: any, ps: Session):
with cls._lock:
cls.clear_tag(tag)
cls._processes.append((tag, ps))
@ -234,7 +254,7 @@ class ProcessState:
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] = None
ProcessState.put_for_tag(tag, self)
Session.put_for_tag(tag, self)
@property
def mdu(self) -> int:
@ -302,37 +322,40 @@ class ProcessState:
except Exception as e:
self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_STDIN = 0
REQUEST_IDX_TERM = 1
REQUEST_IDX_TIOS = 2
REQUEST_IDX_ROWS = 3
REQUEST_IDX_COLS = 4
REQUEST_IDX_HPIX = 5
REQUEST_IDX_VPIX = 6
REQUEST_IDX_VERS = 0
REQUEST_IDX_STDIN = 1
REQUEST_IDX_TERM = 2
REQUEST_IDX_TIOS = 3
REQUEST_IDX_ROWS = 4
REQUEST_IDX_COLS = 5
REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7
@staticmethod
def default_request(stdin_fd: int | None) -> [any]:
global _PROTOCOL_VERSION_0
request: list[any] = [
None, # 0 Stdin
None, # 1 TERM variable
None, # 2 termios attributes or something
None, # 3 terminal rows
None, # 4 terminal cols
None, # 5 terminal horizontal pixels
None, # 6 terminal vertical pixels
_PROTOCOL_VERSION_0, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
].copy()
if stdin_fd is not None:
request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[ProcessState.REQUEST_IDX_ROWS], \
request[ProcessState.REQUEST_IDX_COLS], \
request[ProcessState.REQUEST_IDX_HPIX], \
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
return request
def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
@ -340,10 +363,10 @@ class ProcessState:
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
response = ProcessState.default_response()
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1]
response = Session.default_response()
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running
response[Session.RESPONSE_IDX_RUNNING] = self.process.running
if self.process.running:
if term_state != self._term_state:
self._term_state = term_state
@ -351,39 +374,42 @@ class ProcessState:
if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
response[ProcessState.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock:
stdout = self.read(read_size)
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
if stdout is not None and len(stdout) > 0:
response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
return response
RESPONSE_IDX_RUNNING = 0
RESPONSE_IDX_RETCODE = 1
RESPONSE_IDX_RDYBYTE = 2
RESPONSE_IDX_STDOUT = 3
RESPONSE_IDX_TMSTAMP = 4
RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_RUNNING = 1
RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_STDOUT = 4
RESPONSE_IDX_TMSTAMP = 5
@staticmethod
def default_response() -> [any]:
global _PROTOCOL_VERSION_0
response: list[any] = [
False, # 0: Process running
None, # 1: Return value
0, # 2: Number of outstanding bytes
None, # 3: Stdout/Stderr
None, # 4: Timestamp
_PROTOCOL_VERSION_0, # 0: Protocol version
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr
None, # 5: Timestamp
].copy()
response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response
def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer
log = _get_logger("_subproc_data_ready")
process_state: ProcessState = ProcessState.get_for_tag(link.link_id)
session: Session = Session.get_for_tag(link.link_id)
def send(timeout: bool, tag: any, tries: int) -> any:
# log.debug("send")
@ -392,10 +418,10 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
try:
if link.status != RNS.Link.ACTIVE:
_retry_timer.complete(link.link_id)
process_state.pending_receipt_take()
session.pending_receipt_take()
return
pr = process_state.pending_receipt_take()
pr = session.pending_receipt_take()
log.debug(f"send inner pr: {pr}")
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
if not timeout:
@ -405,12 +431,12 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else:
if not timeout:
log.info(
f"Notifying client try {tries} (retcode: {process_state.return_code} " +
f"Notifying client try {tries} (retcode: {session.return_code} " +
f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
process_state.pending_receipt_put(pr)
session.pending_receipt_put(pr)
else:
log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id)
@ -421,7 +447,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
_loop.call_soon_threadsafe(inner)
return link.link_id
with process_state.lock:
with session.lock:
if not _retry_timer.has_tag(link.link_id):
_retry_timer.begin(try_limit=15,
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
@ -436,7 +462,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop
log = _get_logger("_subproc_terminated")
log.info(f"Subprocess returned {return_code} for link {link}")
proc = ProcessState.get_for_tag(link.link_id)
proc = Session.get_for_tag(link.link_id)
if proc is None:
log.debug(f"no proc for link {link}")
return
@ -449,7 +475,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
try:
link.teardown()
finally:
ProcessState.clear_tag(link.link_id)
Session.clear_tag(link.link_id)
_loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0)
@ -458,18 +484,18 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
loop: asyncio.AbstractEventLoop) -> ProcessState | None:
loop: asyncio.AbstractEventLoop) -> Session | None:
global _cmd
log = _get_logger("_listen_start_proc")
try:
return ProcessState(tag=link.link_id,
cmd=_cmd,
term=term,
remote_identity=remote_identity,
mdu=link.MDU,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
return Session(tag=link.link_id,
cmd=_cmd,
term=term,
remote_identity=remote_identity,
mdu=link.MDU,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
except Exception as e:
log.error("Failed to launch process: " + str(e))
_subproc_terminated(link, 255)
@ -488,7 +514,7 @@ def _listen_link_closed(link: RNS.Link):
log = _get_logger("_listen_link_closed")
# async def cleanup():
log.info("Link " + str(link) + " closed")
proc: ProcessState | None = ProcessState.get_for_tag(link.link_id)
proc: Session | None = Session.get_for_tag(link.link_id)
if proc is None:
log.warning(f"No process for link {link}")
else:
@ -497,7 +523,7 @@ def _listen_link_closed(link: RNS.Link):
_retry_timer.complete(link.link_id)
except Exception as e:
log.error(f"Error closing process for link {link}: {e}")
ProcessState.clear_tag(link.link_id)
Session.clear_tag(link.link_id)
def _initiator_identified(link, identity):
@ -513,34 +539,48 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
global _destination, _retry_timer, _loop
log = _get_logger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None:
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
process_state: ProcessState | None = None
remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version):
raise Exception("Request magic incorrect")
if not remote_version == _PROTOCOL_VERSION_0:
response = Session.default_response()
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
session: Session | None = None
try:
term = data[ProcessState.REQUEST_IDX_TERM]
process_state = ProcessState.get_for_tag(link.link_id)
if process_state is None:
term = data[Session.REQUEST_IDX_TERM]
session = Session.get_for_tag(link.link_id)
if session is None:
log.debug(f"Process not found for link {link}")
process_state = _listen_start_proc(link=link,
session = _listen_start_proc(link=link,
term=term,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
# leave significant headroom for metadata and encoding
result = process_state.process_request(data, link.MDU * 4 // 3)
result = session.process_request(data, link.MDU * 4 // 3)
return result
# return ProcessState.default_response()
except Exception as e:
log.error(f"Error procesing request for link {link}: {e}")
try:
if process_state is not None and process_state.process.running:
process_state.process.terminate()
if session is not None and session.process.running:
session.process.terminate()
except Exception as ee:
log.debug(f"Error terminating process for link {link}: {ee}")
return ProcessState.default_response()
return Session.default_response()
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
@ -558,7 +598,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
_link: RNS.Link | None = None
_remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None
_tr = process.TTYRestorer(sys.stdin.fileno())
_tr: process.TTYRestorer | None = None
def _client_packet_handler(message, packet):
@ -632,8 +672,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler)
request = ProcessState.default_request(sys.stdin.fileno())
request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
request = Session.default_request(sys.stdin.fileno())
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
@ -671,18 +711,27 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if request_receipt.response is not None:
try:
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT]
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP]
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_0:
raise RemoteExecutionError("Protocol version mismatch")
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
if stdout is not None:
stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except RemoteExecutionError:
raise
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
_tr.raw()
if stdout is not None:
stdout = base64.b64decode(stdout)
# log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout)
@ -784,7 +833,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
return arr, []
async def main():
async def _rnsh_cli_main():
global _tr, _finished, _loop
import docopt
log = _get_logger("main")
@ -797,8 +846,8 @@ async def main():
Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
[--] <program> [<arg>...]
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h
@ -810,7 +859,8 @@ Options:
-s NAME --service NAME Listen on/connect to specific service name if not default
-p --print-identity Print identity information and exit
-l --listen Listen (server) mode
-b --no-announce Do not announce service
-b --announce PERIOD Announce on startup and every PERIOD seconds
Specify 0 for PERIOD to announce on startup only.
-a HASH --allowed HASH Specify identities allowed to connect
-n --no-auth Disable authentication
-N --no-id Disable identify on connect
@ -837,7 +887,13 @@ Options:
args_print_identity = args.get("--print-identity", None) or False
args_verbose = args.get("--verbose", None) or 0
args_quiet = args.get("--quiet", None) or 0
args_no_announce = args.get("--no-announce", None) or False
args_announce = args.get("--announce", None)
try:
if args_announce:
args_announce = int(args_announce)
except ValueError:
print("Invalid value for --announce")
return 1
args_no_auth = args.get("--no-auth", None) or False
args_allowed = args.get("--allowed", None) or []
args_program = args.get("<program>", None)
@ -868,7 +924,7 @@ Options:
quietness=args_quiet,
allowed=args_allowed,
disable_auth=args_no_auth,
disable_announce=args_no_announce,
announce_period=args_announce,
)
if args_destination is not None and args_service_name is not None:
@ -893,14 +949,14 @@ Options:
def rnsh_cli():
try:
return_code = asyncio.run(main())
finally:
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
global _tr, _retry_timer
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
with retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
_tr.restore()
_retry_timer.close()
sys.exit(return_code or 255)

44
tests/test_hacks.py Normal file
View File

@ -0,0 +1,44 @@
import asyncio
import rnsh.hacks as hacks
import pytest
import time
import logging
logging.getLogger().setLevel(logging.DEBUG)
class FakeLink:
def __init__(self, link_id):
self.link_id = link_id
@pytest.mark.asyncio
async def test_pruning():
def listen_request(path, data, request_id, link_id, remote_identity, requested_at):
assert path == 1
assert data == 2
assert request_id == 3
assert remote_identity == 4
assert requested_at == 5
assert link_id == 6
return 7
lhr_called = 0
link = FakeLink(6)
def link_handle_request(self, request_id, unpacked_request):
nonlocal lhr_called
lhr_called += 1
old_func = hacks.request_request_id_hack(listen_request, asyncio.get_running_loop())
hacks._link_handle_request_orig = link_handle_request
hacks._link_handle_request(link, 3, None)
link_id = hacks._request_to_link.get(3)
assert link_id == link.link_id
result = old_func(1, 2, 3, 4, 5)
assert result == 7
link_id = hacks._request_to_link.get(3)
assert link_id == 6
await asyncio.sleep(1.5)
link_id = hacks._request_to_link.get(3)
assert link_id is None

15
tests/test_rnsh.py Normal file
View File

@ -0,0 +1,15 @@
import logging
import rnsh.rnsh
logging.getLogger().setLevel(logging.DEBUG)
def test_check_magic():
magic = rnsh.rnsh._PROTOCOL_VERSION_0
# magic for version 0 is generated, make sure it comes out as expected
assert magic == 0xdeadbeef00000000
# verify the checker thinks it's right
assert rnsh.rnsh._protocol_check_magic(magic)
# scramble the magic
magic = magic | 0x00ffff0000000000
# make sure it fails now
assert not rnsh.rnsh._protocol_check_magic(magic)