Separate listener and initiator code into separate modules

This commit is contained in:
Aaron Heise 2023-02-23 06:36:55 -06:00
parent 458a2391df
commit a06d27ec15
No known key found for this signature in database
GPG key ID: 6BA54088C41DE8BF
5 changed files with 676 additions and 517 deletions

View file

@ -52,6 +52,9 @@ import rnsh.args
import pwd
import rnsh.protocol as protocol
import rnsh.helpers as helpers
import rnsh.loop
import rnsh.listener as listener
import rnsh.initiator as initiator
module_logger = __logging.getLogger(__name__)
@ -62,557 +65,99 @@ def _get_logger(name: str):
APP_NAME = "rnsh"
_identity = None
_reticulum = None
_allow_all = False
_allowed_identity_hashes = []
_cmd: [str] | None = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
_no_remote_command = True
_remote_cmd_as_args = False
async def _check_finished(timeout: float = 0):
return await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, frame):
global _finished
log = _get_logger("_sigint_handler")
log.debug(signal.Signals(sig).name)
if _finished is not None:
_finished.set()
else:
raise KeyboardInterrupt()
signal.signal(signal.SIGINT, _sigint_handler)
loop: asyncio.AbstractEventLoop | None = None
def _sanitize_service_name(service_name:str) -> str:
return re.sub(r'\W+', '', service_name)
def _prepare_identity(identity_path, service_name: str = None):
global _identity
def prepare_identity(identity_path, service_name: str = None) -> tuple[RNS.Identity]:
log = _get_logger("_prepare_identity")
service_name = _sanitize_service_name(service_name or "")
if identity_path is None:
identity_path = RNS.Reticulum.identitypath + "/" + APP_NAME + \
(f".{service_name}" if service_name and len(service_name) > 0 else "")
identity = None
if os.path.isfile(identity_path):
_identity = RNS.Identity.from_file(identity_path)
identity = RNS.Identity.from_file(identity_path)
if _identity is None:
if identity is None:
log.info("No valid saved identity found, creating new...")
_identity = RNS.Identity()
_identity.to_file(identity_path)
identity = RNS.Identity()
identity.to_file(identity_path)
return identity
def _print_identity(configdir, identitypath, service_name, include_destination: bool):
global _reticulum
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO)
def print_identity(configdir, identitypath, service_name, include_destination: bool):
reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO)
if service_name and len(service_name) > 0:
print(f"Using service name \"{service_name}\"")
_prepare_identity(identitypath, service_name)
destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME)
print("Identity : " + str(_identity))
identity = prepare_identity(identitypath, service_name)
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME)
print("Identity : " + str(identity))
if include_destination:
print("Listening on : " + RNS.prettyhexrep(destination.hash))
exit(0)
async def _listen(configdir, command, identitypath=None, service_name=None, verbosity=0, quietness=0, allowed=None,
disable_auth=None, announce_period=900, no_remote_command=True, remote_cmd_as_args=False):
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination, _no_remote_command
global _remote_cmd_as_args
log = _get_logger("_listen")
if service_name is None or len(service_name) == 0:
service_name = "default"
log.info(f"Using service name {service_name}")
targetloglevel = RNS.LOG_INFO + verbosity - quietness
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
_prepare_identity(identitypath, service_name)
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME)
_cmd = command
if _cmd is None or len(_cmd) == 0:
shell = None
try:
shell = pwd.getpwuid(os.getuid()).pw_shell
except Exception as e:
log.error(f"Error looking up shell: {e}")
log.info(f"Using {shell} for default command.")
_cmd = [shell] if shell else None
else:
log.info(f"Using command {shlex.join(_cmd)}")
_no_remote_command = no_remote_command
session.ListenerSession.allow_remote_command = not no_remote_command
_remote_cmd_as_args = remote_cmd_as_args
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
and (_no_remote_command or _remote_cmd_as_args):
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
session.ListenerSession.default_command = _cmd
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
if disable_auth:
_allow_all = True
session.ListenerSession.allow_all = True
else:
if allowed is not None:
for a in allowed:
try:
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(a) != dest_len:
raise ValueError(
"Allowed destination length is invalid, must be {hex} hexadecimal " +
"characters ({byte} bytes).".format(
hex=dest_len, byte=dest_len // 2))
try:
destination_hash = bytes.fromhex(a)
_allowed_identity_hashes.append(destination_hash)
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
except Exception:
raise ValueError("Invalid destination entered. Check your input.")
except Exception as e:
log.error(str(e))
exit(1)
if len(_allowed_identity_hashes) < 1 and not disable_auth:
log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")
def link_established(lnk: RNS.Link):
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), _loop)
_destination.set_link_established_callback(link_established)
if await _check_finished():
return
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
if announce_period is not None:
_destination.announce()
last_announce = time.time()
sleeper = helpers.SleepRate(0.01)
try:
while not await _check_finished():
if announce_period and 0 < announce_period < time.time() - last_announce:
last_announce = time.time()
_destination.announce()
if len(session.ListenerSession.sessions) > 0:
# no sleep if there's work to do
if not await session.ListenerSession.pump_all():
await sleeper.sleep_async()
else:
await asyncio.sleep(0.25)
finally:
log.warning("Shutting down")
await session.ListenerSession.terminate_all("Shutting down")
await asyncio.sleep(1)
session.ListenerSession.messenger.shutdown()
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status not in [RNS.Link.CLOSED]:
link.teardown()
await asyncio.sleep(0.01)
async def _spin_tty(until=None, msg=None, timeout=None):
i = 0
syms = "⢄⢂⢁⡁⡈⡐⡠"
if timeout != None:
timeout = time.time()+timeout
print(msg+" ", end=" ")
while (timeout == None or time.time()<timeout) and not until():
await asyncio.sleep(0.1)
print(("\b\b"+syms[i]+" "), end="")
sys.stdout.flush()
i = (i+1)%len(syms)
print("\r"+" "*len(msg)+" \r", end="")
if timeout != None and time.time() > timeout:
return False
else:
return True
async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = None) -> bool:
if timeout is not None:
timeout += time.time()
while (timeout is None or time.time() < timeout) and not until():
if await _check_finished(0.1):
raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout:
return False
else:
return True
async def _spin(until: callable = None, msg=None, timeout: float | None = None) -> bool:
if os.isatty(1):
return await _spin_tty(until, msg, timeout)
else:
return await _spin_pipe(until, msg, timeout)
_link: RNS.Link | None = None
_remote_exec_grace = 2.0
_new_data: asyncio.Event | None = None
_tr: process.TTYRestorer | None = None
_pq = queue.Queue()
class InitiatorState(enum.IntEnum):
IS_INITIAL = 0
IS_LINKED = 1
IS_WAIT_VERS = 2
IS_RUNNING = 3
IS_TERMINATE = 4
IS_TEARDOWN = 5
def _client_link_closed(link):
log = _get_logger("_client_link_closed")
_finished.set()
def _client_packet_handler(message, packet):
global _new_data
log = _get_logger("_client_packet_handler")
packet.prove()
_pq.put(message)
class RemoteExecutionError(Exception):
def __init__(self, msg):
self.msg = msg
async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
global _identity, _reticulum, _link, _destination, _remote_exec_grace, _tr, _new_data
log = _get_logger("_initiate_link")
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(destination) != dest_len:
raise RemoteExecutionError(
"Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(
hex=dest_len, byte=dest_len // 2))
try:
destination_hash = bytes.fromhex(destination)
except Exception as e:
raise RemoteExecutionError("Invalid destination entered. Check your input.")
if _reticulum is None:
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
if _identity is None:
_prepare_identity(identitypath)
if not RNS.Transport.has_path(destination_hash):
RNS.Transport.request_path(destination_hash)
log.info(f"Requesting path...")
if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...",
timeout=timeout):
raise RemoteExecutionError("Path not found")
if _destination is None:
listener_identity = RNS.Identity.recall(destination_hash)
_destination = RNS.Destination(
listener_identity,
RNS.Destination.OUT,
RNS.Destination.SINGLE,
APP_NAME
)
if _link is None or _link.status == RNS.Link.PENDING:
log.debug("No link")
_link = RNS.Link(_destination)
_link.did_identify = False
_link.set_link_closed_callback(_client_link_closed)
log.info(f"Establishing link...")
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...",
timeout=timeout):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
log.debug("Have link")
if not noid and not _link.did_identify:
_link.identify(_identity)
_link.did_identify = True
_link.set_packet_callback(_client_packet_handler)
async def _handle_error(errmsg: protocol.Message):
if isinstance(errmsg, protocol.ErrorMessage):
with contextlib.suppress(Exception):
if _link and _link.status == RNS.Link.ACTIVE:
_link.teardown()
await asyncio.sleep(0.1)
raise RemoteExecutionError(f"Remote error: {errmsg.msg}")
async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
timeout: float, command: [str] | None = None):
global _new_data, _finished, _tr, _cmd, _pre_input
log = _get_logger("_initiate")
loop = asyncio.get_running_loop()
_new_data = asyncio.Event()
state = InitiatorState.IS_INITIAL
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
await _initiate_link(
configdir=configdir,
identitypath=identitypath,
verbosity=verbosity,
quietness=quietness,
noid=noid,
destination=destination,
timeout=timeout,
)
if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
_finished.set()
return 255
state = InitiatorState.IS_LINKED
outlet = session.RNSOutlet(_link)
with protocol.Messenger(retry_delay_min=5) as messenger:
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# print("Error bringing up link")
# return 253
messenger.send(outlet, protocol.VersionInfoMessage())
try:
vp = _pq.get(timeout=max(outlet.rtt * 20, 5))
vm = messenger.receive(vp)
await _handle_error(vm)
if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received")
log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}")
state = InitiatorState.IS_RUNNING
except queue.Empty:
print("Protocol error")
return 254
winch = False
def sigwinch_handler():
nonlocal winch
# log.debug("WindowChanged")
winch = True
# if _new_data is not None:
# _new_data.set()
stdin_eof = False
def stdin():
nonlocal stdin_eof
try:
data = process.tty_read(sys.stdin.fileno())
log.debug(f"stdin {data}")
if data is not None:
data_buffer.extend(data)
_new_data.set()
except EOFError:
if os.isatty(0):
data_buffer.extend(process.CTRL_D)
stdin_eof = True
process.tty_unset_reader_callbacks(sys.stdin.fileno())
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
await _check_finished()
tcattr = None
rows, cols, hpix, vpix = (None, None, None, None)
try:
tcattr = termios.tcgetattr(0)
rows, cols, hpix, vpix = process.tty_get_winsize(0)
except:
try:
tcattr = termios.tcgetattr(1)
rows, cols, hpix, vpix = process.tty_get_winsize(1)
except:
try:
tcattr = termios.tcgetattr(2)
rows, cols, hpix, vpix = process.tty_get_winsize(2)
except:
pass
messenger.send(outlet, protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1),
pipe_stderr=not os.isatty(2),
tcflags=tcattr,
term=os.environ.get("TERM", None),
rows=rows,
cols=cols,
hpix=hpix,
vpix=vpix))
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
mdu = _link.MDU - 16
sent_eof = False
last_winch = time.time()
sleeper = helpers.SleepRate(0.01)
processed = False
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
try:
try:
packet = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
message = messenger.receive(packet)
await _handle_error(message)
processed = True
if isinstance(message, protocol.StreamDataMessage):
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
if message.data and len(message.data) > 0:
_tr.raw()
log.debug(f"stdout: {message.data}")
os.write(1, message.data)
sys.stdout.flush()
if message.eof:
os.close(1)
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
if message.data and len(message.data) > 0:
_tr.raw()
log.debug(f"stdout: {message.data}")
os.write(2, message.data)
sys.stderr.flush()
if message.eof:
os.close(2)
elif isinstance(message, protocol.CommandExitedMessage):
log.debug(f"received return code {message.return_code}, exiting")
with exception.permit(SystemExit, KeyboardInterrupt):
_link.teardown()
return message.return_code
elif isinstance(message, protocol.ErrorMessage):
log.error(message.data)
if message.fatal:
_link.teardown()
return 200
except queue.Empty:
processed = False
if messenger.is_outlet_ready(outlet):
stdin = data_buffer[:mdu]
data_buffer = data_buffer[mdu:]
eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof:
messenger.send(outlet, protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN,
stdin, eof))
sent_eof = eof
processed = True
# send window change, but rate limited
if winch and time.time() - last_winch > _link.rtt * 25:
last_winch = time.time()
winch = False
with contextlib.suppress(Exception):
r, c, h, v = process.tty_get_winsize(0)
messenger.send(outlet, protocol.WindowSizeMessage(r, c, h, v))
processed = True
except RemoteExecutionError as e:
print(e.msg)
return 255
except Exception as ex:
print(f"Client exception: {ex}")
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 127
# await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
# await sleeper.sleep_async()
log.debug("after main loop")
return 0
def _loop_set_signal(sig, loop):
loop.remove_signal_handler(sig)
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
async def _rnsh_cli_main():
global _tr, _finished, _loop
import docopt
log = _get_logger("main")
_loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop)
_finished = asyncio.Event()
_loop_set_signal(signal.SIGINT, _loop)
_loop_set_signal(signal.SIGTERM, _loop)
#with contextlib.suppress(KeyboardInterrupt, SystemExit):
import docopt
log = _get_logger("main")
_loop = asyncio.get_running_loop()
rnslogging.set_main_loop(_loop)
args = rnsh.args.Args(sys.argv)
args = rnsh.args.Args(sys.argv)
if args.print_identity:
print_identity(args.config, args.identity, args.service_name, args.listen)
return 0
if args.print_identity:
_print_identity(args.config, args.identity, args.service_name, args.listen)
return 0
if args.listen:
# log.info("command " + args.command)
await listener.listen(configdir=args.config,
command=args.command_line,
identitypath=args.identity,
service_name=args.service_name,
verbosity=args.verbose,
quietness=args.quiet,
allowed=args.allowed,
disable_auth=args.no_auth,
announce_period=args.announce,
no_remote_command=args.no_remote_cmd,
remote_cmd_as_args=args.remote_cmd_as_args)
return 0
if args.listen:
# log.info("command " + args.command)
await _listen(configdir=args.config,
command=args.command_line,
identitypath=args.identity,
service_name=args.service_name,
verbosity=args.verbose,
quietness=args.quiet,
allowed=args.allowed,
disable_auth=args.no_auth,
announce_period=args.announce,
no_remote_command=args.no_remote_cmd,
remote_cmd_as_args=args.remote_cmd_as_args)
return 0
if args.destination is not None:
try:
return_code = await _initiate(
configdir=args.config,
identitypath=args.identity,
verbosity=args.verbose,
quietness=args.quiet,
noid=args.no_id,
destination=args.destination,
timeout=args.timeout,
command=args.command_line
if args.destination is not None:
return_code = await initiator.initiate(configdir=args.config,
identitypath=args.identity,
verbosity=args.verbose,
quietness=args.quiet,
noid=args.no_id,
destination=args.destination,
timeout=args.timeout,
command=args.command_line
)
return return_code if args.mirror else 0
except Exception as ex:
print(f"{ex}")
return 255;
else:
print("")
print(rnsh.args.usage)
print("")
return 1
else:
print("")
print(rnsh.args.usage)
print("")
return 1
def rnsh_cli():
global _tr, _retry_timer, _pre_input
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
return_code = 1
try:
return_code = asyncio.run(_rnsh_cli_main())
process.tty_unset_reader_callbacks(sys.stdin.fileno())
except SystemExit:
pass
except KeyboardInterrupt:
pass
except Exception as ex:
print(f"Unhandled exception: {ex}")
process.tty_unset_reader_callbacks(0)
sys.exit(return_code if return_code is not None else 255)