Protocol versioning, announce period option, old API workaround

This commit is contained in:
Aaron Heise 2023-02-11 21:57:13 -06:00
parent 9fe37b57ed
commit da3b390058
6 changed files with 293 additions and 103 deletions

View file

@ -35,12 +35,12 @@ import termios
import threading
import time
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.rnslogging as rnslogging
import rnsh.hacks as hacks
from rnsh.__version import __version__
module_logger = __logging.getLogger(__name__)
@ -59,7 +59,7 @@ _allowed_identity_hashes = []
_cmd: [str] = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event | None = None
_retry_timer = retry.RetryThread()
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
@ -107,8 +107,10 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
exit(0)
# hack_goes_here
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
allowed=None, disable_auth=None, disable_announce=False):
allowed=None, disable_auth=None, announce_period=900):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
log = _get_logger("_listen")
_cmd = command
@ -147,14 +149,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
if not _allow_all:
_destination.register_request_handler(
path="data",
response_generator=_listen_request,
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
allow=RNS.Destination.ALLOW_LIST,
allowed_list=_allowed_identity_hashes
)
else:
_destination.register_request_handler(
path="data",
response_generator=_listen_request,
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
allow=RNS.Destination.ALLOW_ALL,
)
@ -162,14 +164,14 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
if not disable_announce:
if announce_period is not None:
_destination.announce()
last = time.time()
try:
while True:
if not disable_announce and time.time() - last > 900: # TODO: make parameter
if announce_period and 0 < announce_period < time.time() - last:
last = time.time()
_destination.announce()
await _check_finished(1.0)
@ -177,7 +179,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
log.warning("Shutting down")
for link in list(_destination.links):
with exception.permit(SystemExit):
proc = ProcessState.get_for_tag(link.link_id)
proc = Session.get_for_tag(link.link_id)
if proc is not None and proc.process.running:
proc.process.terminate()
await asyncio.sleep(1)
@ -187,17 +189,35 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
link.teardown()
class ProcessState:
_processes: [(any, ProcessState)] = []
_PROTOCOL_MAGIC = 0xdeadbeef
def _protocol_make_version(version: int):
return (_PROTOCOL_MAGIC << 32) & 0xffffffff00000000 | (0xffffffff & version)
_PROTOCOL_VERSION_0 = _protocol_make_version(0)
def _protocol_split_version(version: int):
return (version >> 32) & 0xffffffff, version & 0xffffffff
def _protocol_check_magic(value: int):
return _protocol_split_version(value)[0] == _PROTOCOL_MAGIC
class Session:
_processes: [(any, Session)] = []
_lock = threading.RLock()
@classmethod
def get_for_tag(cls, tag: any) -> ProcessState | None:
def get_for_tag(cls, tag: any) -> Session | None:
with cls._lock:
return next(map(lambda p: p[1], filter(lambda p: p[0] == tag, cls._processes)), None)
@classmethod
def put_for_tag(cls, tag: any, ps: ProcessState):
def put_for_tag(cls, tag: any, ps: Session):
with cls._lock:
cls.clear_tag(tag)
cls._processes.append((tag, ps))
@ -234,7 +254,7 @@ class ProcessState:
self._pending_receipt: RNS.PacketReceipt | None = None
self._process.start()
self._term_state: [int] = None
ProcessState.put_for_tag(tag, self)
Session.put_for_tag(tag, self)
@property
def mdu(self) -> int:
@ -302,37 +322,40 @@ class ProcessState:
except Exception as e:
self._log.debug(f"failed to update winsz: {e}")
REQUEST_IDX_STDIN = 0
REQUEST_IDX_TERM = 1
REQUEST_IDX_TIOS = 2
REQUEST_IDX_ROWS = 3
REQUEST_IDX_COLS = 4
REQUEST_IDX_HPIX = 5
REQUEST_IDX_VPIX = 6
REQUEST_IDX_VERS = 0
REQUEST_IDX_STDIN = 1
REQUEST_IDX_TERM = 2
REQUEST_IDX_TIOS = 3
REQUEST_IDX_ROWS = 4
REQUEST_IDX_COLS = 5
REQUEST_IDX_HPIX = 6
REQUEST_IDX_VPIX = 7
@staticmethod
def default_request(stdin_fd: int | None) -> [any]:
global _PROTOCOL_VERSION_0
request: list[any] = [
None, # 0 Stdin
None, # 1 TERM variable
None, # 2 termios attributes or something
None, # 3 terminal rows
None, # 4 terminal cols
None, # 5 terminal horizontal pixels
None, # 6 terminal vertical pixels
_PROTOCOL_VERSION_0, # 0 Protocol Version
None, # 1 Stdin
None, # 2 TERM variable
None, # 3 termios attributes or something
None, # 4 terminal rows
None, # 5 terminal cols
None, # 6 terminal horizontal pixels
None, # 7 terminal vertical pixels
].copy()
if stdin_fd is not None:
request[ProcessState.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[ProcessState.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[ProcessState.REQUEST_IDX_ROWS], \
request[ProcessState.REQUEST_IDX_COLS], \
request[ProcessState.REQUEST_IDX_HPIX], \
request[ProcessState.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
request[Session.REQUEST_IDX_ROWS], \
request[Session.REQUEST_IDX_COLS], \
request[Session.REQUEST_IDX_HPIX], \
request[Session.REQUEST_IDX_VPIX] = process.tty_get_winsize(stdin_fd)
return request
def process_request(self, data: [any], read_size: int) -> [any]:
stdin = data[ProcessState.REQUEST_IDX_STDIN] # Data passed to stdin
stdin = data[Session.REQUEST_IDX_STDIN] # Data passed to stdin
# term = data[ProcessState.REQUEST_IDX_TERM] # TERM environment variable
# tios = data[ProcessState.REQUEST_IDX_TIOS] # termios attr
# rows = data[ProcessState.REQUEST_IDX_ROWS] # window rows
@ -340,10 +363,10 @@ class ProcessState:
# hpix = data[ProcessState.REQUEST_IDX_HPIX] # window horizontal pixels
# vpix = data[ProcessState.REQUEST_IDX_VPIX] # window vertical pixels
# term_state = data[ProcessState.REQUEST_IDX_ROWS:ProcessState.REQUEST_IDX_VPIX+1]
response = ProcessState.default_response()
term_state = data[ProcessState.REQUEST_IDX_TIOS:ProcessState.REQUEST_IDX_VPIX + 1]
response = Session.default_response()
term_state = data[Session.REQUEST_IDX_TIOS:Session.REQUEST_IDX_VPIX + 1]
response[ProcessState.RESPONSE_IDX_RUNNING] = self.process.running
response[Session.RESPONSE_IDX_RUNNING] = self.process.running
if self.process.running:
if term_state != self._term_state:
self._term_state = term_state
@ -351,39 +374,42 @@ class ProcessState:
if stdin is not None and len(stdin) > 0:
stdin = base64.b64decode(stdin)
self.process.write(stdin)
response[ProcessState.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
response[Session.RESPONSE_IDX_RETCODE] = None if self.process.running else self.return_code
with self.lock:
stdout = self.read(read_size)
response[ProcessState.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
response[Session.RESPONSE_IDX_RDYBYTE] = len(self._data_buffer)
if stdout is not None and len(stdout) > 0:
response[ProcessState.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(stdout).decode("utf-8")
return response
RESPONSE_IDX_RUNNING = 0
RESPONSE_IDX_RETCODE = 1
RESPONSE_IDX_RDYBYTE = 2
RESPONSE_IDX_STDOUT = 3
RESPONSE_IDX_TMSTAMP = 4
RESPONSE_IDX_VERSION = 0
RESPONSE_IDX_RUNNING = 1
RESPONSE_IDX_RETCODE = 2
RESPONSE_IDX_RDYBYTE = 3
RESPONSE_IDX_STDOUT = 4
RESPONSE_IDX_TMSTAMP = 5
@staticmethod
def default_response() -> [any]:
global _PROTOCOL_VERSION_0
response: list[any] = [
False, # 0: Process running
None, # 1: Return value
0, # 2: Number of outstanding bytes
None, # 3: Stdout/Stderr
None, # 4: Timestamp
_PROTOCOL_VERSION_0, # 0: Protocol version
False, # 1: Process running
None, # 2: Return value
0, # 3: Number of outstanding bytes
None, # 4: Stdout/Stderr
None, # 5: Timestamp
].copy()
response[ProcessState.RESPONSE_IDX_TMSTAMP] = time.time()
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
return response
def _subproc_data_ready(link: RNS.Link, chars_available: int):
global _retry_timer
log = _get_logger("_subproc_data_ready")
process_state: ProcessState = ProcessState.get_for_tag(link.link_id)
session: Session = Session.get_for_tag(link.link_id)
def send(timeout: bool, tag: any, tries: int) -> any:
# log.debug("send")
@ -392,10 +418,10 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
try:
if link.status != RNS.Link.ACTIVE:
_retry_timer.complete(link.link_id)
process_state.pending_receipt_take()
session.pending_receipt_take()
return
pr = process_state.pending_receipt_take()
pr = session.pending_receipt_take()
log.debug(f"send inner pr: {pr}")
if pr is not None and pr.status == RNS.PacketReceipt.DELIVERED:
if not timeout:
@ -405,12 +431,12 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
else:
if not timeout:
log.info(
f"Notifying client try {tries} (retcode: {process_state.return_code} " +
f"Notifying client try {tries} (retcode: {session.return_code} " +
f"chars avail: {chars_available})")
packet = RNS.Packet(link, DATA_AVAIL_MSG.encode("utf-8"))
packet.send()
pr = packet.receipt
process_state.pending_receipt_put(pr)
session.pending_receipt_put(pr)
else:
log.error(f"Retry count exceeded, terminating link {link}")
_retry_timer.complete(link.link_id)
@ -421,7 +447,7 @@ def _subproc_data_ready(link: RNS.Link, chars_available: int):
_loop.call_soon_threadsafe(inner)
return link.link_id
with process_state.lock:
with session.lock:
if not _retry_timer.has_tag(link.link_id):
_retry_timer.begin(try_limit=15,
wait_delay=max(link.rtt * 5 if link.rtt is not None else 1, 1),
@ -436,7 +462,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
global _loop
log = _get_logger("_subproc_terminated")
log.info(f"Subprocess returned {return_code} for link {link}")
proc = ProcessState.get_for_tag(link.link_id)
proc = Session.get_for_tag(link.link_id)
if proc is None:
log.debug(f"no proc for link {link}")
return
@ -449,7 +475,7 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
try:
link.teardown()
finally:
ProcessState.clear_tag(link.link_id)
Session.clear_tag(link.link_id)
_loop.call_later(300, inner)
_loop.call_soon(_subproc_data_ready, link, 0)
@ -458,18 +484,18 @@ def _subproc_terminated(link: RNS.Link, return_code: int):
def _listen_start_proc(link: RNS.Link, remote_identity: str | None, term: str,
loop: asyncio.AbstractEventLoop) -> ProcessState | None:
loop: asyncio.AbstractEventLoop) -> Session | None:
global _cmd
log = _get_logger("_listen_start_proc")
try:
return ProcessState(tag=link.link_id,
cmd=_cmd,
term=term,
remote_identity=remote_identity,
mdu=link.MDU,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
return Session(tag=link.link_id,
cmd=_cmd,
term=term,
remote_identity=remote_identity,
mdu=link.MDU,
loop=loop,
data_available_callback=functools.partial(_subproc_data_ready, link),
terminated_callback=functools.partial(_subproc_terminated, link))
except Exception as e:
log.error("Failed to launch process: " + str(e))
_subproc_terminated(link, 255)
@ -488,7 +514,7 @@ def _listen_link_closed(link: RNS.Link):
log = _get_logger("_listen_link_closed")
# async def cleanup():
log.info("Link " + str(link) + " closed")
proc: ProcessState | None = ProcessState.get_for_tag(link.link_id)
proc: Session | None = Session.get_for_tag(link.link_id)
if proc is None:
log.warning(f"No process for link {link}")
else:
@ -497,7 +523,7 @@ def _listen_link_closed(link: RNS.Link):
_retry_timer.complete(link.link_id)
except Exception as e:
log.error(f"Error closing process for link {link}: {e}")
ProcessState.clear_tag(link.link_id)
Session.clear_tag(link.link_id)
def _initiator_identified(link, identity):
@ -513,34 +539,48 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
global _destination, _retry_timer, _loop
log = _get_logger("_listen_request")
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
if not hasattr(data, "__len__") or len(data) < 1:
raise Exception("Request data invalid")
_retry_timer.complete(link_id)
link: RNS.Link = next(filter(lambda l: l.link_id == link_id, _destination.links), None)
if link is None:
raise Exception(f"Invalid request {request_id}, no link found with id {link_id}")
process_state: ProcessState | None = None
remote_version = data[Session.REQUEST_IDX_VERS]
if not _protocol_check_magic(remote_version):
raise Exception("Request magic incorrect")
if not remote_version == _PROTOCOL_VERSION_0:
response = Session.default_response()
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
response[Session.RESPONSE_IDX_RETCODE] = 255
response[Session.RESPONSE_IDX_RDYBYTE] = 0
return response
session: Session | None = None
try:
term = data[ProcessState.REQUEST_IDX_TERM]
process_state = ProcessState.get_for_tag(link.link_id)
if process_state is None:
term = data[Session.REQUEST_IDX_TERM]
session = Session.get_for_tag(link.link_id)
if session is None:
log.debug(f"Process not found for link {link}")
process_state = _listen_start_proc(link=link,
session = _listen_start_proc(link=link,
term=term,
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
loop=_loop)
# leave significant headroom for metadata and encoding
result = process_state.process_request(data, link.MDU * 4 // 3)
result = session.process_request(data, link.MDU * 4 // 3)
return result
# return ProcessState.default_response()
except Exception as e:
log.error(f"Error procesing request for link {link}: {e}")
try:
if process_state is not None and process_state.process.running:
process_state.process.terminate()
if session is not None and session.process.running:
session.process.terminate()
except Exception as ee:
log.debug(f"Error terminating process for link {link}: {ee}")
return ProcessState.default_response()
return Session.default_response()
async def _spin(until: callable = None, timeout: float | None = None) -> bool:
@ -558,7 +598,7 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
_link: RNS.Link | None = None
_remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None
_tr = process.TTYRestorer(sys.stdin.fileno())
_tr: process.TTYRestorer | None = None
def _client_packet_handler(message, packet):
@ -632,8 +672,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
_link.set_packet_callback(_client_packet_handler)
request = ProcessState.default_request(sys.stdin.fileno())
request[ProcessState.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
request = Session.default_request(sys.stdin.fileno())
request[Session.REQUEST_IDX_STDIN] = (base64.b64encode(stdin) if stdin is not None else None)
# TODO: Tune
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
@ -671,18 +711,27 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
if request_receipt.response is not None:
try:
running = request_receipt.response[ProcessState.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[ProcessState.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[ProcessState.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[ProcessState.RESPONSE_IDX_STDOUT]
timestamp = request_receipt.response[ProcessState.RESPONSE_IDX_TMSTAMP]
version = request_receipt.response[Session.RESPONSE_IDX_VERSION] or 0
if not _protocol_check_magic(version):
raise RemoteExecutionError("Protocol error")
elif version != _PROTOCOL_VERSION_0:
raise RemoteExecutionError("Protocol version mismatch")
running = request_receipt.response[Session.RESPONSE_IDX_RUNNING] or True
return_code = request_receipt.response[Session.RESPONSE_IDX_RETCODE]
ready_bytes = request_receipt.response[Session.RESPONSE_IDX_RDYBYTE] or 0
stdout = request_receipt.response[Session.RESPONSE_IDX_STDOUT]
if stdout is not None:
stdout = base64.b64decode(stdout)
timestamp = request_receipt.response[Session.RESPONSE_IDX_TMSTAMP]
# log.debug("data: " + (stdout.decode("utf-8") if stdout is not None else ""))
except RemoteExecutionError:
raise
except Exception as e:
raise RemoteExecutionError(f"Received invalid response") from e
_tr.raw()
if stdout is not None:
stdout = base64.b64decode(stdout)
# log.debug(f"stdout: {stdout}")
os.write(sys.stdout.fileno(), stdout)
@ -784,7 +833,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
return arr, []
async def main():
async def _rnsh_cli_main():
global _tr, _finished, _loop
import docopt
log = _get_logger("main")
@ -797,8 +846,8 @@ async def main():
Usage:
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
[--] <program> [<arg>...]
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
[--] <program> [<arg> ...]
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
rnsh -h
@ -810,7 +859,8 @@ Options:
-s NAME --service NAME Listen on/connect to specific service name if not default
-p --print-identity Print identity information and exit
-l --listen Listen (server) mode
-b --no-announce Do not announce service
-b --announce PERIOD Announce on startup and every PERIOD seconds
Specify 0 for PERIOD to announce on startup only.
-a HASH --allowed HASH Specify identities allowed to connect
-n --no-auth Disable authentication
-N --no-id Disable identify on connect
@ -837,7 +887,13 @@ Options:
args_print_identity = args.get("--print-identity", None) or False
args_verbose = args.get("--verbose", None) or 0
args_quiet = args.get("--quiet", None) or 0
args_no_announce = args.get("--no-announce", None) or False
args_announce = args.get("--announce", None)
try:
if args_announce:
args_announce = int(args_announce)
except ValueError:
print("Invalid value for --announce")
return 1
args_no_auth = args.get("--no-auth", None) or False
args_allowed = args.get("--allowed", None) or []
args_program = args.get("<program>", None)
@ -868,7 +924,7 @@ Options:
quietness=args_quiet,
allowed=args_allowed,
disable_auth=args_no_auth,
disable_announce=args_no_announce,
announce_period=args_announce,
)
if args_destination is not None and args_service_name is not None:
@ -893,14 +949,14 @@ Options:
def rnsh_cli():
try:
return_code = asyncio.run(main())
finally:
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
global _tr, _retry_timer
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
with retry.RetryThread() as _retry_timer:
return_code = asyncio.run(_rnsh_cli_main())
with exception.permit(SystemExit):
process.tty_unset_reader_callbacks(sys.stdin.fileno())
_tr.restore()
_retry_timer.close()
sys.exit(return_code or 255)