mirror of
https://github.com/markqvist/rnsh.git
synced 2025-01-07 05:07:57 -05:00
Performance and stability improvements.
This commit is contained in:
parent
7d711cde6d
commit
5e755acad4
15
README.md
15
README.md
@ -66,10 +66,10 @@ rnsh a5f72aefc2cb3cdba648f73f77c4e887
|
|||||||
Usage:
|
Usage:
|
||||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||||
[-v...] [-q...] [-b] (-n | -a <identity_hash> [-a <identity_hash>]...)
|
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||||
[--] <program> [<arg>...]
|
[--] <program> [<arg> ...]
|
||||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||||
rnsh -h
|
rnsh -h
|
||||||
rnsh --version
|
rnsh --version
|
||||||
|
|
||||||
@ -79,13 +79,20 @@ Options:
|
|||||||
-s NAME --service NAME Listen on/connect to specific service name if not default
|
-s NAME --service NAME Listen on/connect to specific service name if not default
|
||||||
-p --print-identity Print identity information and exit
|
-p --print-identity Print identity information and exit
|
||||||
-l --listen Listen (server) mode
|
-l --listen Listen (server) mode
|
||||||
-b --no-announce Do not announce service
|
-b --announce PERIOD Announce on startup and every PERIOD seconds
|
||||||
|
Specify 0 for PERIOD to announce on startup only.
|
||||||
-a HASH --allowed HASH Specify identities allowed to connect
|
-a HASH --allowed HASH Specify identities allowed to connect
|
||||||
-n --no-auth Disable authentication
|
-n --no-auth Disable authentication
|
||||||
-N --no-id Disable identify on connect
|
-N --no-id Disable identify on connect
|
||||||
-m --mirror Client returns with code of remote process
|
-m --mirror Client returns with code of remote process
|
||||||
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
||||||
-v --verbose Increase verbosity
|
-v --verbose Increase verbosity
|
||||||
|
DEFAULT LEVEL
|
||||||
|
CRITICAL
|
||||||
|
Initiator -> ERROR
|
||||||
|
WARNING
|
||||||
|
Listener -> INFO
|
||||||
|
DEBUG
|
||||||
-q --quiet Increase quietness
|
-q --quiet Increase quietness
|
||||||
--version Show version
|
--version Show version
|
||||||
-h --help Show this help
|
-h --help Show this help
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "rnsh"
|
name = "rnsh"
|
||||||
version = "0.0.2"
|
version = "0.0.3"
|
||||||
description = "Shell over Reticulum"
|
description = "Shell over Reticulum"
|
||||||
authors = ["acehoss <acehoss@acehoss.net>"]
|
authors = ["acehoss <acehoss@acehoss.net>"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
129
rnsh/process.py
129
rnsh/process.py
@ -22,6 +22,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import copy
|
||||||
import errno
|
import errno
|
||||||
import fcntl
|
import fcntl
|
||||||
import functools
|
import functools
|
||||||
@ -77,24 +78,27 @@ def tty_read(fd: int) -> bytes:
|
|||||||
if fd_is_closed(fd):
|
if fd_is_closed(fd):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
run = True
|
try:
|
||||||
result = bytearray()
|
run = True
|
||||||
while run and not fd_is_closed(fd):
|
result = bytearray()
|
||||||
ready, _, _ = select.select([fd], [], [], 0)
|
while run and not fd_is_closed(fd):
|
||||||
if len(ready) == 0:
|
ready, _, _ = select.select([fd], [], [], 0)
|
||||||
break
|
if len(ready) == 0:
|
||||||
for f in ready:
|
break
|
||||||
try:
|
for f in ready:
|
||||||
data = os.read(f, 512)
|
try:
|
||||||
except OSError as e:
|
data = os.read(f, 4096)
|
||||||
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
|
except OSError as e:
|
||||||
raise
|
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
|
||||||
else:
|
raise
|
||||||
if not data: # EOF
|
else:
|
||||||
run = False
|
if not data: # EOF
|
||||||
if data is not None and len(data) > 0:
|
run = False
|
||||||
result.extend(data)
|
if data is not None and len(data) > 0:
|
||||||
return result
|
result.extend(data)
|
||||||
|
return result
|
||||||
|
except Exception as ex:
|
||||||
|
module_logger.error("tty_read error: {ex}")
|
||||||
|
|
||||||
|
|
||||||
def fd_is_closed(fd: int) -> bool:
|
def fd_is_closed(fd: int) -> bool:
|
||||||
@ -169,7 +173,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
|||||||
ATTR_IDX_LFLAG = 4
|
ATTR_IDX_LFLAG = 4
|
||||||
ATTR_IDX_CC = 5
|
ATTR_IDX_CC = 5
|
||||||
|
|
||||||
def __init__(self, fd: int):
|
def __init__(self, fd: int, suppress_logs=False):
|
||||||
"""
|
"""
|
||||||
Saves termios attributes for a tty for later restoration.
|
Saves termios attributes for a tty for later restoration.
|
||||||
|
|
||||||
@ -183,29 +187,37 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
|||||||
|
|
||||||
:param fd: file descriptor of tty
|
:param fd: file descriptor of tty
|
||||||
"""
|
"""
|
||||||
|
self._log = module_logger.getChild(self.__class__.__name__)
|
||||||
self._fd = fd
|
self._fd = fd
|
||||||
self._tattr = None
|
self._tattr = None
|
||||||
with contextlib.suppress(termios.error):
|
self._suppress_logs = suppress_logs
|
||||||
termios.tcgetattr(self._fd)
|
self._tattr = self.current_attr()
|
||||||
|
if not self._tattr and not self._suppress_logs:
|
||||||
|
self._log.warning(f"Could not get attrs for fd {fd}")
|
||||||
|
|
||||||
def raw(self):
|
def raw(self):
|
||||||
"""
|
"""
|
||||||
Set raw mode on tty
|
Set raw mode on tty
|
||||||
"""
|
"""
|
||||||
if not self._fd:
|
if self._fd is None:
|
||||||
return
|
return
|
||||||
|
with contextlib.suppress(termios.error):
|
||||||
|
tty.setraw(self._fd, termios.TCSANOW)
|
||||||
|
|
||||||
tty.setraw(self._fd, termios.TCSADRAIN)
|
def original_attr(self) -> [any]:
|
||||||
|
return copy.deepcopy(self._tattr)
|
||||||
|
|
||||||
def current_attr(self) -> [any]:
|
def current_attr(self) -> [any]:
|
||||||
"""
|
"""
|
||||||
Get the current termios attributes for the wrapped fd.
|
Get the current termios attributes for the wrapped fd.
|
||||||
:return: attribute array
|
:return: attribute array
|
||||||
"""
|
"""
|
||||||
if not self._fd:
|
if self._fd is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return termios.tcgetattr(self._fd)
|
with contextlib.suppress(termios.error):
|
||||||
|
return copy.deepcopy(termios.tcgetattr(self._fd))
|
||||||
|
return None
|
||||||
|
|
||||||
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
|
def set_attr(self, attr: [any], when: int = termios.TCSANOW):
|
||||||
"""
|
"""
|
||||||
@ -213,7 +225,7 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
|||||||
:param attr: attribute list to set
|
:param attr: attribute list to set
|
||||||
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
|
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
|
||||||
"""
|
"""
|
||||||
if not attr or not self._fd:
|
if not attr or self._fd is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
with contextlib.suppress(termios.error):
|
with contextlib.suppress(termios.error):
|
||||||
@ -228,7 +240,64 @@ class TTYRestorer(contextlib.AbstractContextManager):
|
|||||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||||
__traceback: types.TracebackType) -> bool:
|
__traceback: types.TracebackType) -> bool:
|
||||||
self.restore()
|
self.restore()
|
||||||
return __exc_type is not None and issubclass(__exc_type, termios.error)
|
return False #__exc_type is not None and issubclass(__exc_type, termios.error)
|
||||||
|
|
||||||
|
|
||||||
|
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
|
||||||
|
if not loop:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
#TODO: this is hacky
|
||||||
|
async def wait():
|
||||||
|
while not evt.is_set():
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return loop.create_task(wait())
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateException(Exception):
|
||||||
|
def __init__(self, inner_exceptions: [Exception]):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_exceptions = inner_exceptions
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions))
|
||||||
|
|
||||||
|
async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any):
|
||||||
|
tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts))
|
||||||
|
# try:
|
||||||
|
finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks),
|
||||||
|
timeout=timeout,
|
||||||
|
return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
if len(unfinished) > 0:
|
||||||
|
for task in unfinished:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.wait(unfinished)
|
||||||
|
|
||||||
|
# exceptions = []
|
||||||
|
#
|
||||||
|
# for f in finished:
|
||||||
|
# ex = f.exception()
|
||||||
|
# if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError):
|
||||||
|
# exceptions.append(ex)
|
||||||
|
#
|
||||||
|
# if len(exceptions) > 0:
|
||||||
|
# raise AggregateException(exceptions)
|
||||||
|
|
||||||
|
return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None)
|
||||||
|
# finally:
|
||||||
|
# unfinished = []
|
||||||
|
# for task in map(lambda t: t[1], tasks):
|
||||||
|
# if task.done():
|
||||||
|
# if not task.cancelled():
|
||||||
|
# task.exception()
|
||||||
|
# else:
|
||||||
|
# task.cancel()
|
||||||
|
# unfinished.append(task)
|
||||||
|
# if len(unfinished) > 0:
|
||||||
|
# await asyncio.wait(unfinished)
|
||||||
|
|
||||||
|
|
||||||
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
||||||
@ -238,9 +307,7 @@ async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
|||||||
:param timeout: maximum number of seconds to wait.
|
:param timeout: maximum number of seconds to wait.
|
||||||
:return: True if event was set, False if timeout expired
|
:return: True if event was set, False if timeout expired
|
||||||
"""
|
"""
|
||||||
# suppress TimeoutError because we'll return False in case of timeout
|
await event_wait_any([evt], timeout=timeout)
|
||||||
with contextlib.suppress(asyncio.TimeoutError):
|
|
||||||
await asyncio.wait_for(evt.wait(), timeout)
|
|
||||||
return evt.is_set()
|
return evt.is_set()
|
||||||
|
|
||||||
|
|
||||||
@ -408,6 +475,8 @@ class CallbackSubprocess:
|
|||||||
# self.log.debug("poll")
|
# self.log.debug("poll")
|
||||||
try:
|
try:
|
||||||
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
||||||
|
if self._return_code is not None:
|
||||||
|
self._return_code = self._return_code & 0xff
|
||||||
if self._return_code is not None and not process_exists(self._pid):
|
if self._return_code is not None and not process_exists(self._pid):
|
||||||
self._log.debug(f"polled return code {self._return_code}")
|
self._log.debug(f"polled return code {self._return_code}")
|
||||||
self._terminated_cb(self._return_code)
|
self._terminated_cb(self._return_code)
|
||||||
|
173
rnsh/rnsh.py
173
rnsh/rnsh.py
@ -34,6 +34,7 @@ import sys
|
|||||||
import termios
|
import termios
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import tty
|
||||||
from typing import Callable, TypeVar
|
from typing import Callable, TypeVar
|
||||||
import RNS
|
import RNS
|
||||||
import rnsh.exception as exception
|
import rnsh.exception as exception
|
||||||
@ -58,20 +59,20 @@ _allow_all = False
|
|||||||
_allowed_identity_hashes = []
|
_allowed_identity_hashes = []
|
||||||
_cmd: [str] = None
|
_cmd: [str] = None
|
||||||
DATA_AVAIL_MSG = "data available"
|
DATA_AVAIL_MSG = "data available"
|
||||||
_finished: asyncio.Event | None = None
|
_finished: asyncio.Event = None
|
||||||
_retry_timer: retry.RetryThread | None = None
|
_retry_timer: retry.RetryThread | None = None
|
||||||
_destination: RNS.Destination | None = None
|
_destination: RNS.Destination | None = None
|
||||||
_loop: asyncio.AbstractEventLoop | None = None
|
_loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
|
|
||||||
async def _check_finished(timeout: float = 0):
|
async def _check_finished(timeout: float = 0):
|
||||||
await process.event_wait(_finished, timeout=timeout)
|
return await process.event_wait(_finished, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
def _sigint_handler(sig, frame):
|
def _sigint_handler(sig, frame):
|
||||||
global _finished
|
global _finished
|
||||||
log = _get_logger("_sigint_handler")
|
log = _get_logger("_sigint_handler")
|
||||||
log.debug("SIGINT")
|
log.debug(signal.Signals(sig).name)
|
||||||
if _finished is not None:
|
if _finished is not None:
|
||||||
_finished.set()
|
_finished.set()
|
||||||
else:
|
else:
|
||||||
@ -107,8 +108,6 @@ def _print_identity(configdir, identitypath, service_name, include_destination:
|
|||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
# hack_goes_here
|
|
||||||
|
|
||||||
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
async def _listen(configdir, command, identitypath=None, service_name="default", verbosity=0, quietness=0,
|
||||||
allowed=None, disable_auth=None, announce_period=900):
|
allowed=None, disable_auth=None, announce_period=900):
|
||||||
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination
|
||||||
@ -116,8 +115,8 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
_cmd = command
|
_cmd = command
|
||||||
|
|
||||||
targetloglevel = RNS.LOG_INFO + verbosity - quietness
|
targetloglevel = RNS.LOG_INFO + verbosity - quietness
|
||||||
rnslogging.RnsHandler.set_global_log_level(__logging.INFO)
|
|
||||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||||
|
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
|
||||||
_prepare_identity(identitypath)
|
_prepare_identity(identitypath)
|
||||||
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, service_name)
|
||||||
|
|
||||||
@ -151,6 +150,7 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
_destination.register_request_handler(
|
_destination.register_request_handler(
|
||||||
path="data",
|
path="data",
|
||||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||||
|
# response_generator=_listen_request,
|
||||||
allow=RNS.Destination.ALLOW_LIST,
|
allow=RNS.Destination.ALLOW_LIST,
|
||||||
allowed_list=_allowed_identity_hashes
|
allowed_list=_allowed_identity_hashes
|
||||||
)
|
)
|
||||||
@ -158,10 +158,12 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
_destination.register_request_handler(
|
_destination.register_request_handler(
|
||||||
path="data",
|
path="data",
|
||||||
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
response_generator=hacks.request_request_id_hack(_listen_request, asyncio.get_running_loop()),
|
||||||
|
# response_generator=_listen_request,
|
||||||
allow=RNS.Destination.ALLOW_ALL,
|
allow=RNS.Destination.ALLOW_ALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
await _check_finished()
|
if await _check_finished():
|
||||||
|
return
|
||||||
|
|
||||||
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))
|
||||||
|
|
||||||
@ -171,23 +173,23 @@ async def _listen(configdir, command, identitypath=None, service_name="default",
|
|||||||
last = time.time()
|
last = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while not await _check_finished(1.0):
|
||||||
if announce_period and 0 < announce_period < time.time() - last:
|
if announce_period and 0 < announce_period < time.time() - last:
|
||||||
last = time.time()
|
last = time.time()
|
||||||
_destination.announce()
|
_destination.announce()
|
||||||
await _check_finished(1.0)
|
finally:
|
||||||
except KeyboardInterrupt:
|
|
||||||
log.warning("Shutting down")
|
log.warning("Shutting down")
|
||||||
for link in list(_destination.links):
|
for link in list(_destination.links):
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||||
proc = Session.get_for_tag(link.link_id)
|
proc = Session.get_for_tag(link.link_id)
|
||||||
if proc is not None and proc.process.running:
|
if proc is not None and proc.process.running:
|
||||||
proc.process.terminate()
|
proc.process.terminate()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(0)
|
||||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||||
for link in links_still_active:
|
for link in links_still_active:
|
||||||
if link.status != RNS.Link.CLOSED:
|
if link.status != RNS.Link.CLOSED:
|
||||||
link.teardown()
|
link.teardown()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
_PROTOCOL_MAGIC = 0xdeadbeef
|
_PROTOCOL_MAGIC = 0xdeadbeef
|
||||||
@ -334,21 +336,22 @@ class Session:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_request(stdin_fd: int | None) -> [any]:
|
def default_request(stdin_fd: int | None) -> [any]:
|
||||||
|
global _tr
|
||||||
global _PROTOCOL_VERSION_0
|
global _PROTOCOL_VERSION_0
|
||||||
request: list[any] = [
|
request: list[any] = [
|
||||||
_PROTOCOL_VERSION_0, # 0 Protocol Version
|
_PROTOCOL_VERSION_0, # 0 Protocol Version
|
||||||
None, # 1 Stdin
|
None, # 1 Stdin
|
||||||
None, # 2 TERM variable
|
None, # 2 TERM variable
|
||||||
None, # 3 termios attributes or something
|
None, # 3 termios attributes or something
|
||||||
None, # 4 terminal rows
|
None, # 4 terminal rows
|
||||||
None, # 5 terminal cols
|
None, # 5 terminal cols
|
||||||
None, # 6 terminal horizontal pixels
|
None, # 6 terminal horizontal pixels
|
||||||
None, # 7 terminal vertical pixels
|
None, # 7 terminal vertical pixels
|
||||||
].copy()
|
].copy()
|
||||||
|
|
||||||
if stdin_fd is not None:
|
if stdin_fd is not None:
|
||||||
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
request[Session.REQUEST_IDX_TERM] = os.environ.get("TERM", None)
|
||||||
request[Session.REQUEST_IDX_TIOS] = termios.tcgetattr(stdin_fd)
|
request[Session.REQUEST_IDX_TIOS] = _tr.original_attr() if _tr else termios.tcgetattr(stdin_fd)
|
||||||
request[Session.REQUEST_IDX_ROWS], \
|
request[Session.REQUEST_IDX_ROWS], \
|
||||||
request[Session.REQUEST_IDX_COLS], \
|
request[Session.REQUEST_IDX_COLS], \
|
||||||
request[Session.REQUEST_IDX_HPIX], \
|
request[Session.REQUEST_IDX_HPIX], \
|
||||||
@ -372,6 +375,7 @@ class Session:
|
|||||||
if term_state != self._term_state:
|
if term_state != self._term_state:
|
||||||
self._term_state = term_state
|
self._term_state = term_state
|
||||||
self._update_winsz()
|
self._update_winsz()
|
||||||
|
# self.process.tcsetattr(termios.TCSANOW, self._term_state[0])
|
||||||
if stdin is not None and len(stdin) > 0:
|
if stdin is not None and len(stdin) > 0:
|
||||||
stdin = base64.b64decode(stdin)
|
stdin = base64.b64decode(stdin)
|
||||||
self.process.write(stdin)
|
self.process.write(stdin)
|
||||||
@ -397,11 +401,11 @@ class Session:
|
|||||||
global _PROTOCOL_VERSION_0
|
global _PROTOCOL_VERSION_0
|
||||||
response: list[any] = [
|
response: list[any] = [
|
||||||
_PROTOCOL_VERSION_0, # 0: Protocol version
|
_PROTOCOL_VERSION_0, # 0: Protocol version
|
||||||
False, # 1: Process running
|
False, # 1: Process running
|
||||||
None, # 2: Return value
|
None, # 2: Return value
|
||||||
0, # 3: Number of outstanding bytes
|
0, # 3: Number of outstanding bytes
|
||||||
None, # 4: Stdout/Stderr
|
None, # 4: Stdout/Stderr
|
||||||
None, # 5: Timestamp
|
None, # 5: Timestamp
|
||||||
].copy()
|
].copy()
|
||||||
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
response[Session.RESPONSE_IDX_TMSTAMP] = time.time()
|
||||||
return response
|
return response
|
||||||
@ -539,7 +543,8 @@ def _initiator_identified(link, identity):
|
|||||||
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
def _listen_request(path, data, request_id, link_id, remote_identity, requested_at):
|
||||||
global _destination, _retry_timer, _loop
|
global _destination, _retry_timer, _loop
|
||||||
log = _get_logger("_listen_request")
|
log = _get_logger("_listen_request")
|
||||||
log.debug(f"listen_execute {path} {request_id} {link_id} {remote_identity}, {requested_at}")
|
log.debug(
|
||||||
|
f"listen_execute {path} {RNS.prettyhexrep(request_id)} {RNS.prettyhexrep(link_id)} {remote_identity}, {requested_at}")
|
||||||
if not hasattr(data, "__len__") or len(data) < 1:
|
if not hasattr(data, "__len__") or len(data) < 1:
|
||||||
raise Exception("Request data invalid")
|
raise Exception("Request data invalid")
|
||||||
_retry_timer.complete(link_id)
|
_retry_timer.complete(link_id)
|
||||||
@ -553,7 +558,8 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||||||
|
|
||||||
if not remote_version == _PROTOCOL_VERSION_0:
|
if not remote_version == _PROTOCOL_VERSION_0:
|
||||||
response = Session.default_response()
|
response = Session.default_response()
|
||||||
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode("Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
response[Session.RESPONSE_IDX_STDOUT] = base64.b64encode(
|
||||||
|
"Listener<->initiator version mismatch\r\n".encode("utf-8"))
|
||||||
response[Session.RESPONSE_IDX_RETCODE] = 255
|
response[Session.RESPONSE_IDX_RETCODE] = 255
|
||||||
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
response[Session.RESPONSE_IDX_RDYBYTE] = 0
|
||||||
return response
|
return response
|
||||||
@ -565,12 +571,12 @@ def _listen_request(path, data, request_id, link_id, remote_identity, requested_
|
|||||||
if session is None:
|
if session is None:
|
||||||
log.debug(f"Process not found for link {link}")
|
log.debug(f"Process not found for link {link}")
|
||||||
session = _listen_start_proc(link=link,
|
session = _listen_start_proc(link=link,
|
||||||
term=term,
|
term=term,
|
||||||
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
remote_identity=RNS.hexrep(remote_identity.hash).replace(":", ""),
|
||||||
loop=_loop)
|
loop=_loop)
|
||||||
|
|
||||||
# leave significant headroom for metadata and encoding
|
# leave significant headroom for metadata and encoding
|
||||||
result = session.process_request(data, link.MDU * 4 // 3)
|
result = session.process_request(data, link.MDU * 1 // 2)
|
||||||
return result
|
return result
|
||||||
# return ProcessState.default_response()
|
# return ProcessState.default_response()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -589,7 +595,8 @@ async def _spin(until: callable = None, timeout: float | None = None) -> bool:
|
|||||||
timeout += time.time()
|
timeout += time.time()
|
||||||
|
|
||||||
while (timeout is None or time.time() < timeout) and not until():
|
while (timeout is None or time.time() < timeout) and not until():
|
||||||
await _check_finished(0.01)
|
if await _check_finished(0.001):
|
||||||
|
raise asyncio.CancelledError()
|
||||||
if timeout is not None and time.time() > timeout:
|
if timeout is not None and time.time() > timeout:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
@ -638,8 +645,8 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
|
|
||||||
if _reticulum is None:
|
if _reticulum is None:
|
||||||
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
|
targetloglevel = RNS.LOG_ERROR + verbosity - quietness
|
||||||
rnslogging.RnsHandler.set_global_log_level(__logging.ERROR)
|
|
||||||
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
_reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
|
||||||
|
rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
|
||||||
|
|
||||||
if _identity is None:
|
if _identity is None:
|
||||||
_prepare_identity(identitypath)
|
_prepare_identity(identitypath)
|
||||||
@ -661,6 +668,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
)
|
)
|
||||||
|
|
||||||
if _link is None or _link.status == RNS.Link.PENDING:
|
if _link is None or _link.status == RNS.Link.PENDING:
|
||||||
|
log.debug("No link")
|
||||||
_link = RNS.Link(_destination)
|
_link = RNS.Link(_destination)
|
||||||
_link.did_identify = False
|
_link.did_identify = False
|
||||||
|
|
||||||
@ -668,6 +676,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, timeout=timeout):
|
||||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||||
|
|
||||||
|
log.debug("Have link")
|
||||||
if not noid and not _link.did_identify:
|
if not noid and not _link.did_identify:
|
||||||
_link.identify(_identity)
|
_link.identify(_identity)
|
||||||
_link.did_identify = True
|
_link.did_identify = True
|
||||||
@ -680,6 +689,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
# TODO: Tune
|
# TODO: Tune
|
||||||
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
timeout = timeout + _link.rtt * 4 + _remote_exec_grace
|
||||||
|
|
||||||
|
log.debug("Sending request")
|
||||||
request_receipt = _link.request(
|
request_receipt = _link.request(
|
||||||
path="data",
|
path="data",
|
||||||
data=request,
|
data=request,
|
||||||
@ -687,6 +697,7 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
)
|
)
|
||||||
timeout += 0.5
|
timeout += 0.5
|
||||||
|
|
||||||
|
log.debug("Waiting for delivery")
|
||||||
await _spin(
|
await _spin(
|
||||||
until=lambda: _link.status == RNS.Link.CLOSED or (
|
until=lambda: _link.status == RNS.Link.CLOSED or (
|
||||||
request_receipt.status != RNS.RequestReceipt.FAILED and
|
request_receipt.status != RNS.RequestReceipt.FAILED and
|
||||||
@ -732,15 +743,14 @@ async def _execute(configdir, identitypath=None, verbosity=0, quietness=0, noid=
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RemoteExecutionError(f"Received invalid response") from e
|
raise RemoteExecutionError(f"Received invalid response") from e
|
||||||
|
|
||||||
_tr.raw()
|
|
||||||
if stdout is not None:
|
if stdout is not None:
|
||||||
# log.debug(f"stdout: {stdout}")
|
_tr.raw()
|
||||||
|
log.debug(f"stdout: {stdout}")
|
||||||
os.write(sys.stdout.fileno(), stdout)
|
os.write(sys.stdout.fileno(), stdout)
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
sys.stdout.flush()
|
got_bytes = len(stdout) if stdout is not None else 0
|
||||||
sys.stderr.flush()
|
log.debug(f"{got_bytes} chars received, {ready_bytes} bytes ready on server, return code {return_code}")
|
||||||
|
|
||||||
log.debug(f"{ready_bytes} bytes ready on server, return code {return_code}")
|
|
||||||
|
|
||||||
if ready_bytes > 0:
|
if ready_bytes > 0:
|
||||||
_new_data.set()
|
_new_data.set()
|
||||||
@ -761,11 +771,6 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
|
|
||||||
data_buffer = bytearray()
|
data_buffer = bytearray()
|
||||||
|
|
||||||
def sigint_handler():
|
|
||||||
log.debug("KeyboardInterrupt")
|
|
||||||
data_buffer.extend("\x03".encode("utf-8"))
|
|
||||||
_new_data.set()
|
|
||||||
|
|
||||||
def sigwinch_handler():
|
def sigwinch_handler():
|
||||||
# log.debug("WindowChanged")
|
# log.debug("WindowChanged")
|
||||||
if _new_data is not None:
|
if _new_data is not None:
|
||||||
@ -773,7 +778,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
|
|
||||||
def stdin():
|
def stdin():
|
||||||
data = process.tty_read(sys.stdin.fileno())
|
data = process.tty_read(sys.stdin.fileno())
|
||||||
# log.debug(f"stdin {data}")
|
log.debug(f"stdin {data}")
|
||||||
if data is not None:
|
if data is not None:
|
||||||
data_buffer.extend(data)
|
data_buffer.extend(data)
|
||||||
_new_data.set()
|
_new_data.set()
|
||||||
@ -782,10 +787,12 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
|
|
||||||
await _check_finished()
|
await _check_finished()
|
||||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||||
|
|
||||||
# leave a lot of overhead
|
# leave a lot of overhead
|
||||||
mdu = 64
|
mdu = 64
|
||||||
|
rtt = 5
|
||||||
first_loop = True
|
first_loop = True
|
||||||
while True:
|
while not await _check_finished():
|
||||||
try:
|
try:
|
||||||
log.debug("top of client loop")
|
log.debug("top of client loop")
|
||||||
stdin = data_buffer[:mdu]
|
stdin = data_buffer[:mdu]
|
||||||
@ -806,22 +813,27 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
|||||||
|
|
||||||
if first_loop:
|
if first_loop:
|
||||||
first_loop = False
|
first_loop = False
|
||||||
mdu = _link.MDU * 4 // 3
|
mdu = _link.MDU // 2
|
||||||
loop.remove_signal_handler(signal.SIGINT)
|
|
||||||
loop.add_signal_handler(signal.SIGINT, sigint_handler)
|
|
||||||
_new_data.set()
|
_new_data.set()
|
||||||
|
|
||||||
|
if _link:
|
||||||
|
rtt = _link.rtt
|
||||||
|
|
||||||
if return_code is not None:
|
if return_code is not None:
|
||||||
log.debug(f"received return code {return_code}, exiting")
|
log.debug(f"received return code {return_code}, exiting")
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit, KeyboardInterrupt):
|
||||||
_link.teardown()
|
_link.teardown()
|
||||||
|
|
||||||
return return_code
|
return return_code
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
if _link and _link.status != RNS.Link.CLOSED:
|
||||||
|
_link.teardown()
|
||||||
|
return 0
|
||||||
except RemoteExecutionError as e:
|
except RemoteExecutionError as e:
|
||||||
print(e.msg)
|
print(e.msg)
|
||||||
return 255
|
return 255
|
||||||
|
|
||||||
await process.event_wait(_new_data, 5)
|
await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120))
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
@ -835,6 +847,11 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]):
|
|||||||
return arr, []
|
return arr, []
|
||||||
|
|
||||||
|
|
||||||
|
def _loop_set_signal(sig, loop):
|
||||||
|
loop.remove_signal_handler(sig)
|
||||||
|
loop.add_signal_handler(sig, functools.partial(_sigint_handler, sig, None))
|
||||||
|
|
||||||
|
|
||||||
async def _rnsh_cli_main():
|
async def _rnsh_cli_main():
|
||||||
global _tr, _finished, _loop
|
global _tr, _finished, _loop
|
||||||
import docopt
|
import docopt
|
||||||
@ -842,16 +859,16 @@ async def _rnsh_cli_main():
|
|||||||
_loop = asyncio.get_running_loop()
|
_loop = asyncio.get_running_loop()
|
||||||
rnslogging.set_main_loop(_loop)
|
rnslogging.set_main_loop(_loop)
|
||||||
_finished = asyncio.Event()
|
_finished = asyncio.Event()
|
||||||
_loop.remove_signal_handler(signal.SIGINT)
|
_loop_set_signal(signal.SIGINT, _loop)
|
||||||
_loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, None))
|
_loop_set_signal(signal.SIGTERM, _loop)
|
||||||
usage = '''
|
usage = '''
|
||||||
Usage:
|
Usage:
|
||||||
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
rnsh [--config <configdir>] [-i <identityfile>] [-s <service_name>] [-l] -p
|
||||||
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
rnsh -l [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||||
[-v...] [-q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
[-v... | -q...] [-b <period>] (-n | -a <identity_hash> [-a <identity_hash>] ...)
|
||||||
[--] <program> [<arg> ...]
|
[--] <program> [<arg> ...]
|
||||||
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
rnsh [--config <configfile>] [-i <identityfile>] [-s <service_name>]
|
||||||
[-v...] [-q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
[-v... | -q...] [-N] [-m] [-w <timeout>] <destination_hash>
|
||||||
rnsh -h
|
rnsh -h
|
||||||
rnsh --version
|
rnsh --version
|
||||||
|
|
||||||
@ -869,6 +886,12 @@ Options:
|
|||||||
-m --mirror Client returns with code of remote process
|
-m --mirror Client returns with code of remote process
|
||||||
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
-w TIME --timeout TIME Specify client connect and request timeout in seconds
|
||||||
-v --verbose Increase verbosity
|
-v --verbose Increase verbosity
|
||||||
|
DEFAULT LEVEL
|
||||||
|
CRITICAL
|
||||||
|
Initiator -> ERROR
|
||||||
|
WARNING
|
||||||
|
Listener -> INFO
|
||||||
|
DEBUG
|
||||||
-q --quiet Increase quietness
|
-q --quiet Increase quietness
|
||||||
--version Show version
|
--version Show version
|
||||||
-h --help Show this help
|
-h --help Show this help
|
||||||
@ -930,36 +953,40 @@ Options:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args_destination is not None and args_service_name is not None:
|
if args_destination is not None and args_service_name is not None:
|
||||||
try:
|
return_code = await _initiate(
|
||||||
return_code = await _initiate(
|
configdir=args_config,
|
||||||
configdir=args_config,
|
identitypath=args_identity,
|
||||||
identitypath=args_identity,
|
verbosity=args_verbose,
|
||||||
verbosity=args_verbose,
|
quietness=args_quiet,
|
||||||
quietness=args_quiet,
|
noid=args_no_id,
|
||||||
noid=args_no_id,
|
destination=args_destination,
|
||||||
destination=args_destination,
|
service_name=args_service_name,
|
||||||
service_name=args_service_name,
|
timeout=args_timeout,
|
||||||
timeout=args_timeout,
|
)
|
||||||
)
|
return return_code if args_mirror else 0
|
||||||
return return_code if args_mirror else 0
|
|
||||||
finally:
|
|
||||||
_tr.restore()
|
|
||||||
else:
|
else:
|
||||||
print("")
|
print("")
|
||||||
print(args)
|
print(args)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
|
||||||
|
def _noop():
|
||||||
|
pass
|
||||||
|
|
||||||
|
# RNS.exit = _noop
|
||||||
|
|
||||||
|
|
||||||
def rnsh_cli():
|
def rnsh_cli():
|
||||||
global _tr, _retry_timer
|
global _tr, _retry_timer
|
||||||
with process.TTYRestorer(sys.stdin.fileno()) as _tr:
|
with process.TTYRestorer(sys.stdin.fileno()) as _tr, retry.RetryThread() as _retry_timer:
|
||||||
with retry.RetryThread() as _retry_timer:
|
return_code = asyncio.run(_rnsh_cli_main())
|
||||||
return_code = asyncio.run(_rnsh_cli_main())
|
|
||||||
|
|
||||||
with exception.permit(SystemExit):
|
with exception.permit(SystemExit):
|
||||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||||
|
|
||||||
sys.exit(return_code or 255)
|
# RNS.Reticulum.exit_handler()
|
||||||
|
# time.sleep(0.5)
|
||||||
|
sys.exit(return_code if return_code is not None else 255)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -60,10 +60,29 @@ class RnsHandler(Handler):
|
|||||||
return RNS.LOG_DEBUG
|
return RNS.LOG_DEBUG
|
||||||
return RNS.LOG_DEBUG
|
return RNS.LOG_DEBUG
|
||||||
|
|
||||||
|
def get_logging_loglevel(rnsloglevel: int) -> int:
|
||||||
|
if rnsloglevel == RNS.LOG_CRITICAL:
|
||||||
|
return logging.CRITICAL
|
||||||
|
if rnsloglevel == RNS.LOG_ERROR:
|
||||||
|
return logging.ERROR
|
||||||
|
if rnsloglevel == RNS.LOG_WARNING:
|
||||||
|
return logging.WARNING
|
||||||
|
if rnsloglevel == RNS.LOG_NOTICE:
|
||||||
|
return logging.INFO
|
||||||
|
if rnsloglevel == RNS.LOG_INFO:
|
||||||
|
return logging.INFO
|
||||||
|
if rnsloglevel >= RNS.LOG_VERBOSE:
|
||||||
|
return RNS.LOG_DEBUG
|
||||||
|
return RNS.LOG_DEBUG
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_global_log_level(cls, log_level: int):
|
def set_log_level_with_rns_level(cls, rns_log_level: int):
|
||||||
logging.getLogger().setLevel(log_level)
|
logging.getLogger().setLevel(RnsHandler.get_logging_loglevel(rns_log_level))
|
||||||
RNS.loglevel = cls.get_rns_loglevel(log_level)
|
RNS.loglevel = rns_log_level
|
||||||
|
|
||||||
|
def set_log_level_with_logging_level(cls, logging_log_level: int):
|
||||||
|
logging.getLogger().setLevel(logging_log_level)
|
||||||
|
RNS.loglevel = cls.get_rns_loglevel(logging_log_level)
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
"""
|
"""
|
||||||
@ -107,13 +126,16 @@ _rns_log_orig = RNS.log
|
|||||||
|
|
||||||
|
|
||||||
def _rns_log(msg, level=3, _override_destination=False):
|
def _rns_log(msg, level=3, _override_destination=False):
|
||||||
|
if RNS.loglevel < level:
|
||||||
|
return
|
||||||
|
|
||||||
if not RNS.compact_log_fmt:
|
if not RNS.compact_log_fmt:
|
||||||
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
|
msg = (" " * (7 - len(RNS.loglevelname(level)))) + msg
|
||||||
|
|
||||||
def _rns_log_inner():
|
def _rns_log_inner():
|
||||||
nonlocal msg, level, _override_destination
|
nonlocal msg, level, _override_destination
|
||||||
try:
|
try:
|
||||||
with process.TTYRestorer(sys.stdin.fileno()) as tr:
|
with process.TTYRestorer(sys.stdin.fileno(), suppress_logs=True) as tr:
|
||||||
attr = tr.current_attr()
|
attr = tr.current_attr()
|
||||||
if attr:
|
if attr:
|
||||||
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
|
attr[process.TTYRestorer.ATTR_IDX_OFLAG] = attr[process.TTYRestorer.ATTR_IDX_OFLAG] | \
|
||||||
@ -123,12 +145,13 @@ def _rns_log(msg, level=3, _override_destination=False):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
_rns_log_orig(msg, level, _override_destination)
|
_rns_log_orig(msg, level, _override_destination)
|
||||||
|
|
||||||
|
# TODO: figure out if forcing this to the main thread actually helps.
|
||||||
try:
|
try:
|
||||||
if _loop and _loop.is_running():
|
if _loop and _loop.is_running():
|
||||||
_loop.call_soon_threadsafe(_rns_log_inner)
|
_loop.call_soon_threadsafe(_rns_log_inner)
|
||||||
else:
|
else:
|
||||||
_rns_log_inner()
|
_rns_log_inner()
|
||||||
except RuntimeError:
|
except:
|
||||||
_rns_log_inner()
|
_rns_log_inner()
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import types
|
import types
|
||||||
import typing
|
import typing
|
||||||
|
import multiprocessing.pool
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@ -26,6 +27,7 @@ class State(contextlib.AbstractContextManager):
|
|||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
stdout_callback=self._stdout_cb,
|
stdout_callback=self._stdout_cb,
|
||||||
terminated_callback=self._terminated_cb)
|
terminated_callback=self._terminated_cb)
|
||||||
|
|
||||||
def _stdout_cb(self, data):
|
def _stdout_cb(self, data):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._stdout.extend(data)
|
self._stdout.extend(data)
|
||||||
@ -74,3 +76,65 @@ async def test_echo():
|
|||||||
decoded = data.decode("utf-8")
|
decoded = data.decode("utf-8")
|
||||||
assert decoded == message.replace("\n", "\r\n") * 2
|
assert decoded == message.replace("\n", "\r\n") * 2
|
||||||
assert not state.process.running
|
assert not state.process.running
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_ci
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_echo_live():
|
||||||
|
"""
|
||||||
|
Check for immediate echo
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
with State(argv=["/bin/cat"],
|
||||||
|
loop=loop) as state:
|
||||||
|
state.start()
|
||||||
|
assert state.process is not None
|
||||||
|
assert state.process.running
|
||||||
|
message = "t"
|
||||||
|
state.process.write(message.encode("utf-8"))
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
data = state.read()
|
||||||
|
state.process.write(rnsh.process.CTRL_C)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
assert len(data) > 0
|
||||||
|
decoded = data.decode("utf-8")
|
||||||
|
assert decoded == message
|
||||||
|
assert not state.process.running
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_wait_any():
|
||||||
|
delay = 0.1
|
||||||
|
with multiprocessing.pool.ThreadPool() as pool:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
evt1 = asyncio.Event()
|
||||||
|
evt2 = asyncio.Event()
|
||||||
|
|
||||||
|
def assert_between(min, max, val):
|
||||||
|
assert min <= val <= max
|
||||||
|
|
||||||
|
# test 1: both timeout
|
||||||
|
ts = time.time()
|
||||||
|
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay*2)
|
||||||
|
assert_between(delay*2, delay*2.1, time.time() - ts)
|
||||||
|
assert finished is None
|
||||||
|
assert not evt1.is_set()
|
||||||
|
assert not evt2.is_set()
|
||||||
|
|
||||||
|
#test 2: evt1 set, evt2 not set
|
||||||
|
hits = 0
|
||||||
|
|
||||||
|
def test2_bg():
|
||||||
|
nonlocal hits
|
||||||
|
hits += 1
|
||||||
|
time.sleep(delay)
|
||||||
|
evt1.set()
|
||||||
|
|
||||||
|
ts = time.time()
|
||||||
|
pool.apply_async(test2_bg)
|
||||||
|
finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay * 2)
|
||||||
|
assert_between(delay * 0.5, delay * 1.5, time.time() - ts)
|
||||||
|
assert hits == 1
|
||||||
|
assert evt1.is_set()
|
||||||
|
assert not evt2.is_set()
|
||||||
|
assert finished == evt1
|
||||||
|
Loading…
Reference in New Issue
Block a user