From ab7f43e9106ff74303a9add671203e8899f0e1ef Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 17 Dec 2023 13:09:56 -0600 Subject: [PATCH] Improve pipe management Fixes an issue where terminating a session with ~. prevents another session from connecting for some amount of time. --- pyproject.toml | 2 +- rnsh/initiator.py | 16 ++++++++-------- rnsh/process.py | 30 +++++++++++++++++++----------- rnsh/session.py | 3 ++- tests/test_rnsh.py | 10 +++++----- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 08ea7a1..6cd4ce9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rnsh" -version = "0.1.2" +version = "0.1.3" description = "Shell over Reticulum" authors = ["acehoss "] license = "MIT" diff --git a/rnsh/initiator.py b/rnsh/initiator.py index 085a8af..3ec6bba 100644 --- a/rnsh/initiator.py +++ b/rnsh/initiator.py @@ -73,7 +73,7 @@ _loop: asyncio.AbstractEventLoop | None = None async def _check_finished(timeout: float = 0): - return await process.event_wait(_finished, timeout=timeout) + return _finished is not None and await process.event_wait(_finished, timeout=timeout) def _sigint_handler(sig, loop): @@ -120,8 +120,8 @@ async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = N return True -async def _spin(until: callable = None, msg=None, timeout: float | None = None) -> bool: - if os.isatty(1): +async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool: + if not quiet and os.isatty(1): return await _spin_tty(until, msg, timeout) else: return await _spin_pipe(until, msg, timeout) @@ -184,7 +184,7 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, RNS.Transport.request_path(destination_hash) log.info(f"Requesting path...") if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...", - timeout=timeout): + timeout=timeout, quiet=quietness > 0): raise RemoteExecutionError("Path not found") if _destination is None: @@ -205,7 +205,7 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0, log.info(f"Establishing link...") if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...", - timeout=timeout): + timeout=timeout, quiet=quietness > 0): raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash)) log.debug("Have link") @@ -240,7 +240,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: quietness=quietness, noid=noid, destination=destination, - timeout=timeout, + timeout=timeout ) if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]: @@ -253,7 +253,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: channel.add_message_handler(_client_message_handler) # Next step after linking and identifying: send version - # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5): + # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0): # print("Error bringing up link") # return 253 @@ -374,7 +374,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: except: pass - await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1) + await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0) channel.send(protocol.ExecuteCommandMesssage(cmdline=command, pipe_stdin=not os.isatty(0), pipe_stdout=not os.isatty(1), diff --git a/rnsh/process.py b/rnsh/process.py index ffbee79..f583cb7 100644 --- a/rnsh/process.py +++ b/rnsh/process.py @@ -420,6 +420,7 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, class CallbackSubprocess: # time between checks of child process PROCESS_POLL_TIME: float = 0.1 + PROCESS_PIPE_TIME: int = 60 def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable, stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool, @@ -456,20 +457,26 @@ class CallbackSubprocess: self._stderr_is_pipe = stderr_is_pipe def _ensure_pipes_closed(self): - self._log.debug("Ensuring pipes are closed") stdin = self._child_stdin stdout = self._child_stdout stderr = self._child_stderr - fds = list(filter(lambda x: x is not None, list({stdin, stdout, stderr}))) - for fd in fds: - self._log.debug(f"Closing fd {fd}") - with contextlib.suppress(OSError): - os.close(self._child_stdin) - with contextlib.suppress(OSError): - tty_unset_reader_callbacks(fd) - self._child_stdin = None - self._child_stdout = None - self._child_stderr = None + fds = set(filter(lambda x: x is not None, list({stdin, stdout, stderr}))) + self._log.debug(f"Queuing close of pipes for ended process (fds: {fds})") + + def ensure_pipes_closed_inner(): + self._log.debug(f"Ensuring pipes are closed (fds: {fds})") + for fd in fds: + self._log.debug(f"Closing fd {fd}") + with contextlib.suppress(OSError): + tty_unset_reader_callbacks(fd) + with contextlib.suppress(OSError): + os.close(fd) + + self._child_stdin = None + self._child_stdout = None + self._child_stderr = None + + self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner) def terminate(self, kill_delay: float = 1.0): """ @@ -597,6 +604,7 @@ class CallbackSubprocess: self._child_stdout, \ self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe, self._stderr_is_pipe) + self._log.debug("Started pid %d, fds: %d, %d, %d", self.pid, self._child_stdin, self._child_stdout, self._child_stderr) def poll(): # self.log.debug("poll") diff --git a/rnsh/session.py b/rnsh/session.py index 1b4b906..a56d971 100644 --- a/rnsh/session.py +++ b/rnsh/session.py @@ -147,10 +147,11 @@ class ListenerSession: self.send(protocol.ErrorMessage(error, True)) self.state = LSState.LSSTATE_ERROR self._terminate_process() - self._call(self._prune, max(self.outlet.rtt * 3, 5)) + self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5)) def _prune(self): self.state = LSState.LSSTATE_TEARDOWN + self._log.debug("Pruning session") with contextlib.suppress(ValueError): self.sessions.remove(self) with contextlib.suppress(Exception): diff --git a/tests/test_rnsh.py b/tests/test_rnsh.py index 6132268..6cde405 100644 --- a/tests/test_rnsh.py +++ b/tests/test_rnsh.py @@ -126,13 +126,13 @@ async def do_connected_test(listener_args: str, initiator_args: str, test: calla initiator_args = initiator_args.replace("dh", dh) listener_args = listener_args.replace("iih", iih) with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \ - tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {initiator_args}")) as initiator: + tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -q -c \"{td}\" {initiator_args}")) as initiator: # listener startup listener.start() await asyncio.sleep(0.1) assert listener.process.running # wait for process to start up - await asyncio.sleep(2) + await asyncio.sleep(5) # read the output text = listener.read().decode("utf-8") assert text.index(dh) is not None @@ -166,7 +166,7 @@ async def test_rnsh_get_echo_through(): while initiator.return_code is None and time.time() - start_time < 3: await asyncio.sleep(0.1) text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text[len(text)-len(cwd):] == cwd + assert text == cwd await do_connected_test("-n -C -- /bin/pwd", "dh", test) @@ -182,7 +182,7 @@ async def test_rnsh_no_ident(): while initiator.return_code is None and time.time() - start_time < 3: await asyncio.sleep(0.1) text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text[len(text)-len(cwd):] == cwd + assert text == cwd await do_connected_test("-n -C -- /bin/pwd", "-N dh", test) @@ -214,7 +214,7 @@ async def test_rnsh_valid_ident(): while initiator.return_code is None and time.time() - start_time < 3: await asyncio.sleep(0.1) text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") - assert text[len(text)-len(cwd):] == cwd + assert (text == cwd) await do_connected_test("-a iih -C -- /bin/pwd", "dh", test)