Improve pipe management

Fixes an issue where terminating a session with ~. prevents another session from connecting for some amount of time.
This commit is contained in:
Aaron Heise 2023-12-17 13:09:56 -06:00
parent 94b8d8ed66
commit ab7f43e910
No known key found for this signature in database
GPG Key ID: 6BA54088C41DE8BF
5 changed files with 35 additions and 26 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "rnsh"
version = "0.1.2"
version = "0.1.3"
description = "Shell over Reticulum"
authors = ["acehoss <acehoss@acehoss.net>"]
license = "MIT"

View File

@ -73,7 +73,7 @@ _loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0):
return await process.event_wait(_finished, timeout=timeout)
return _finished is not None and await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, loop):
@ -120,8 +120,8 @@ async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = N
return True
async def _spin(until: callable = None, msg=None, timeout: float | None = None) -> bool:
if os.isatty(1):
async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool:
if not quiet and os.isatty(1):
return await _spin_tty(until, msg, timeout)
else:
return await _spin_pipe(until, msg, timeout)
@ -184,7 +184,7 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
RNS.Transport.request_path(destination_hash)
log.info(f"Requesting path...")
if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...",
timeout=timeout):
timeout=timeout, quiet=quietness > 0):
raise RemoteExecutionError("Path not found")
if _destination is None:
@ -205,7 +205,7 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
log.info(f"Establishing link...")
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...",
timeout=timeout):
timeout=timeout, quiet=quietness > 0):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
log.debug("Have link")
@ -240,7 +240,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
quietness=quietness,
noid=noid,
destination=destination,
timeout=timeout,
timeout=timeout
)
if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
@ -253,7 +253,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
channel.add_message_handler(_client_message_handler)
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0):
# print("Error bringing up link")
# return 253
@ -374,7 +374,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
except:
pass
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1)
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0)
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1),

View File

@ -420,6 +420,7 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
class CallbackSubprocess:
# time between checks of child process
PROCESS_POLL_TIME: float = 0.1
PROCESS_PIPE_TIME: int = 60
def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
@ -456,20 +457,26 @@ class CallbackSubprocess:
self._stderr_is_pipe = stderr_is_pipe
def _ensure_pipes_closed(self):
self._log.debug("Ensuring pipes are closed")
stdin = self._child_stdin
stdout = self._child_stdout
stderr = self._child_stderr
fds = list(filter(lambda x: x is not None, list({stdin, stdout, stderr})))
for fd in fds:
self._log.debug(f"Closing fd {fd}")
with contextlib.suppress(OSError):
os.close(self._child_stdin)
with contextlib.suppress(OSError):
tty_unset_reader_callbacks(fd)
self._child_stdin = None
self._child_stdout = None
self._child_stderr = None
fds = set(filter(lambda x: x is not None, list({stdin, stdout, stderr})))
self._log.debug(f"Queuing close of pipes for ended process (fds: {fds})")
def ensure_pipes_closed_inner():
self._log.debug(f"Ensuring pipes are closed (fds: {fds})")
for fd in fds:
self._log.debug(f"Closing fd {fd}")
with contextlib.suppress(OSError):
tty_unset_reader_callbacks(fd)
with contextlib.suppress(OSError):
os.close(fd)
self._child_stdin = None
self._child_stdout = None
self._child_stderr = None
self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner)
def terminate(self, kill_delay: float = 1.0):
"""
@ -597,6 +604,7 @@ class CallbackSubprocess:
self._child_stdout, \
self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
self._stderr_is_pipe)
self._log.debug("Started pid %d, fds: %d, %d, %d", self.pid, self._child_stdin, self._child_stdout, self._child_stderr)
def poll():
# self.log.debug("poll")

View File

@ -147,10 +147,11 @@ class ListenerSession:
self.send(protocol.ErrorMessage(error, True))
self.state = LSState.LSSTATE_ERROR
self._terminate_process()
self._call(self._prune, max(self.outlet.rtt * 3, 5))
self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5))
def _prune(self):
self.state = LSState.LSSTATE_TEARDOWN
self._log.debug("Pruning session")
with contextlib.suppress(ValueError):
self.sessions.remove(self)
with contextlib.suppress(Exception):

View File

@ -126,13 +126,13 @@ async def do_connected_test(listener_args: str, initiator_args: str, test: calla
initiator_args = initiator_args.replace("dh", dh)
listener_args = listener_args.replace("iih", iih)
with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \
tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {initiator_args}")) as initiator:
tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -q -c \"{td}\" {initiator_args}")) as initiator:
# listener startup
listener.start()
await asyncio.sleep(0.1)
assert listener.process.running
# wait for process to start up
await asyncio.sleep(2)
await asyncio.sleep(5)
# read the output
text = listener.read().decode("utf-8")
assert text.index(dh) is not None
@ -166,7 +166,7 @@ async def test_rnsh_get_echo_through():
while initiator.return_code is None and time.time() - start_time < 3:
await asyncio.sleep(0.1)
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
assert text[len(text)-len(cwd):] == cwd
assert text == cwd
await do_connected_test("-n -C -- /bin/pwd", "dh", test)
@ -182,7 +182,7 @@ async def test_rnsh_no_ident():
while initiator.return_code is None and time.time() - start_time < 3:
await asyncio.sleep(0.1)
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
assert text[len(text)-len(cwd):] == cwd
assert text == cwd
await do_connected_test("-n -C -- /bin/pwd", "-N dh", test)
@ -214,7 +214,7 @@ async def test_rnsh_valid_ident():
while initiator.return_code is None and time.time() - start_time < 3:
await asyncio.sleep(0.1)
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
assert text[len(text)-len(cwd):] == cwd
assert (text == cwd)
await do_connected_test("-a iih -C -- /bin/pwd", "dh", test)