mirror of
https://github.com/markqvist/rnsh.git
synced 2024-10-01 01:15:37 -04:00
Improve pipe management
Fixes an issue where terminating a session with ~. prevents another session from connecting for some amount of time.
This commit is contained in:
parent
94b8d8ed66
commit
ab7f43e910
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "rnsh"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
description = "Shell over Reticulum"
|
||||
authors = ["acehoss <acehoss@acehoss.net>"]
|
||||
license = "MIT"
|
||||
|
@ -73,7 +73,7 @@ _loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
async def _check_finished(timeout: float = 0):
|
||||
return await process.event_wait(_finished, timeout=timeout)
|
||||
return _finished is not None and await process.event_wait(_finished, timeout=timeout)
|
||||
|
||||
|
||||
def _sigint_handler(sig, loop):
|
||||
@ -120,8 +120,8 @@ async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = N
|
||||
return True
|
||||
|
||||
|
||||
async def _spin(until: callable = None, msg=None, timeout: float | None = None) -> bool:
|
||||
if os.isatty(1):
|
||||
async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool:
|
||||
if not quiet and os.isatty(1):
|
||||
return await _spin_tty(until, msg, timeout)
|
||||
else:
|
||||
return await _spin_pipe(until, msg, timeout)
|
||||
@ -184,7 +184,7 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
|
||||
RNS.Transport.request_path(destination_hash)
|
||||
log.info(f"Requesting path...")
|
||||
if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...",
|
||||
timeout=timeout):
|
||||
timeout=timeout, quiet=quietness > 0):
|
||||
raise RemoteExecutionError("Path not found")
|
||||
|
||||
if _destination is None:
|
||||
@ -205,7 +205,7 @@ async def _initiate_link(configdir, identitypath=None, verbosity=0, quietness=0,
|
||||
|
||||
log.info(f"Establishing link...")
|
||||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...",
|
||||
timeout=timeout):
|
||||
timeout=timeout, quiet=quietness > 0):
|
||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||
|
||||
log.debug("Have link")
|
||||
@ -240,7 +240,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||
quietness=quietness,
|
||||
noid=noid,
|
||||
destination=destination,
|
||||
timeout=timeout,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
|
||||
@ -253,7 +253,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||
channel.add_message_handler(_client_message_handler)
|
||||
|
||||
# Next step after linking and identifying: send version
|
||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5):
|
||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0):
|
||||
# print("Error bringing up link")
|
||||
# return 253
|
||||
|
||||
@ -374,7 +374,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness:
|
||||
except:
|
||||
pass
|
||||
|
||||
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1)
|
||||
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0)
|
||||
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
|
||||
pipe_stdin=not os.isatty(0),
|
||||
pipe_stdout=not os.isatty(1),
|
||||
|
@ -420,6 +420,7 @@ def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool,
|
||||
class CallbackSubprocess:
|
||||
# time between checks of child process
|
||||
PROCESS_POLL_TIME: float = 0.1
|
||||
PROCESS_PIPE_TIME: int = 60
|
||||
|
||||
def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
|
||||
stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
|
||||
@ -456,20 +457,26 @@ class CallbackSubprocess:
|
||||
self._stderr_is_pipe = stderr_is_pipe
|
||||
|
||||
def _ensure_pipes_closed(self):
|
||||
self._log.debug("Ensuring pipes are closed")
|
||||
stdin = self._child_stdin
|
||||
stdout = self._child_stdout
|
||||
stderr = self._child_stderr
|
||||
fds = list(filter(lambda x: x is not None, list({stdin, stdout, stderr})))
|
||||
for fd in fds:
|
||||
self._log.debug(f"Closing fd {fd}")
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(self._child_stdin)
|
||||
with contextlib.suppress(OSError):
|
||||
tty_unset_reader_callbacks(fd)
|
||||
self._child_stdin = None
|
||||
self._child_stdout = None
|
||||
self._child_stderr = None
|
||||
fds = set(filter(lambda x: x is not None, list({stdin, stdout, stderr})))
|
||||
self._log.debug(f"Queuing close of pipes for ended process (fds: {fds})")
|
||||
|
||||
def ensure_pipes_closed_inner():
|
||||
self._log.debug(f"Ensuring pipes are closed (fds: {fds})")
|
||||
for fd in fds:
|
||||
self._log.debug(f"Closing fd {fd}")
|
||||
with contextlib.suppress(OSError):
|
||||
tty_unset_reader_callbacks(fd)
|
||||
with contextlib.suppress(OSError):
|
||||
os.close(fd)
|
||||
|
||||
self._child_stdin = None
|
||||
self._child_stdout = None
|
||||
self._child_stderr = None
|
||||
|
||||
self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner)
|
||||
|
||||
def terminate(self, kill_delay: float = 1.0):
|
||||
"""
|
||||
@ -597,6 +604,7 @@ class CallbackSubprocess:
|
||||
self._child_stdout, \
|
||||
self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
|
||||
self._stderr_is_pipe)
|
||||
self._log.debug("Started pid %d, fds: %d, %d, %d", self.pid, self._child_stdin, self._child_stdout, self._child_stderr)
|
||||
|
||||
def poll():
|
||||
# self.log.debug("poll")
|
||||
|
@ -147,10 +147,11 @@ class ListenerSession:
|
||||
self.send(protocol.ErrorMessage(error, True))
|
||||
self.state = LSState.LSSTATE_ERROR
|
||||
self._terminate_process()
|
||||
self._call(self._prune, max(self.outlet.rtt * 3, 5))
|
||||
self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5))
|
||||
|
||||
def _prune(self):
|
||||
self.state = LSState.LSSTATE_TEARDOWN
|
||||
self._log.debug("Pruning session")
|
||||
with contextlib.suppress(ValueError):
|
||||
self.sessions.remove(self)
|
||||
with contextlib.suppress(Exception):
|
||||
|
@ -126,13 +126,13 @@ async def do_connected_test(listener_args: str, initiator_args: str, test: calla
|
||||
initiator_args = initiator_args.replace("dh", dh)
|
||||
listener_args = listener_args.replace("iih", iih)
|
||||
with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \
|
||||
tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {initiator_args}")) as initiator:
|
||||
tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -q -c \"{td}\" {initiator_args}")) as initiator:
|
||||
# listener startup
|
||||
listener.start()
|
||||
await asyncio.sleep(0.1)
|
||||
assert listener.process.running
|
||||
# wait for process to start up
|
||||
await asyncio.sleep(2)
|
||||
await asyncio.sleep(5)
|
||||
# read the output
|
||||
text = listener.read().decode("utf-8")
|
||||
assert text.index(dh) is not None
|
||||
@ -166,7 +166,7 @@ async def test_rnsh_get_echo_through():
|
||||
while initiator.return_code is None and time.time() - start_time < 3:
|
||||
await asyncio.sleep(0.1)
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert text[len(text)-len(cwd):] == cwd
|
||||
assert text == cwd
|
||||
|
||||
await do_connected_test("-n -C -- /bin/pwd", "dh", test)
|
||||
|
||||
@ -182,7 +182,7 @@ async def test_rnsh_no_ident():
|
||||
while initiator.return_code is None and time.time() - start_time < 3:
|
||||
await asyncio.sleep(0.1)
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert text[len(text)-len(cwd):] == cwd
|
||||
assert text == cwd
|
||||
|
||||
await do_connected_test("-n -C -- /bin/pwd", "-N dh", test)
|
||||
|
||||
@ -214,7 +214,7 @@ async def test_rnsh_valid_ident():
|
||||
while initiator.return_code is None and time.time() - start_time < 3:
|
||||
await asyncio.sleep(0.1)
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert text[len(text)-len(cwd):] == cwd
|
||||
assert (text == cwd)
|
||||
|
||||
await do_connected_test("-a iih -C -- /bin/pwd", "dh", test)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user