mirror of
https://github.com/markqvist/rnsh.git
synced 2024-10-01 01:15:37 -04:00
Add tests for and fix #14
This commit is contained in:
parent
a07ce53bf9
commit
458a2391df
@ -391,7 +391,7 @@ async def _initiate(configdir: str, identitypath: str, verbosity: int, quietness
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if not _link or _link.status != RNS.Link.ACTIVE:
|
||||
if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
|
||||
_finished.set()
|
||||
return 255
|
||||
|
||||
|
@ -101,8 +101,12 @@ class ListenerSession:
|
||||
self.return_code: int | None = None
|
||||
self.return_code_sent = False
|
||||
self.process: process.CallbackSubprocess | None = None
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
if self.allow_all:
|
||||
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
else:
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
self.sessions.append(self)
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
|
||||
def _terminated(self, return_code: int):
|
||||
self.return_code = return_code
|
||||
@ -176,14 +180,13 @@ class ListenerSession:
|
||||
return
|
||||
|
||||
self._log.info(f"initiator_identified {identity} on link {outlet}")
|
||||
if self.state != LSState.LSSTATE_WAIT_IDENT:
|
||||
if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]:
|
||||
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
|
||||
|
||||
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
|
||||
self.terminate("Identity is not allowed.")
|
||||
|
||||
self.remote_identity = identity
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
|
||||
@classmethod
|
||||
@ -277,7 +280,8 @@ class ListenerSession:
|
||||
try:
|
||||
self.process = process.CallbackSubprocess(argv=self.cmdline,
|
||||
env={"TERM": self.term or os.environ.get("TERM", None),
|
||||
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
|
||||
"RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash)
|
||||
if self.remote_identity and self.remote_identity.hash else "")},
|
||||
loop=self.loop,
|
||||
stdout_callback=stdout,
|
||||
stderr_callback=stderr,
|
||||
@ -306,6 +310,9 @@ class ListenerSession:
|
||||
self.process.close_stdin()
|
||||
|
||||
def _handle_message(self, message: protocol.Message):
|
||||
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
||||
self._protocol_error("Identification required")
|
||||
return
|
||||
if self.state == LSState.LSSTATE_WAIT_VERS:
|
||||
if not isinstance(message, protocol.VersionInfoMessage):
|
||||
self._protocol_error(self.state.name)
|
||||
|
@ -84,6 +84,7 @@ class SubprocessReader(contextlib.AbstractContextManager):
|
||||
self._log.debug(f"cleanup()")
|
||||
if self.process and self.process.running:
|
||||
self.process.terminate(kill_delay=0.1)
|
||||
time.sleep(0.5)
|
||||
|
||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||
__traceback: types.TracebackType) -> bool:
|
||||
|
@ -33,11 +33,26 @@ def test_program_initiate_no_args():
|
||||
args = rnsh.args.Args(shlex.split("rnsh one"))
|
||||
assert not args.listen
|
||||
assert args.destination == "one"
|
||||
assert not args.no_id
|
||||
assert args.command_line == []
|
||||
except docopt.DocoptExit:
|
||||
docopt_threw = True
|
||||
assert not docopt_threw
|
||||
|
||||
|
||||
def test_program_initiate_no_auth():
|
||||
docopt_threw = False
|
||||
try:
|
||||
args = rnsh.args.Args(shlex.split("rnsh -N one"))
|
||||
assert not args.listen
|
||||
assert args.destination == "one"
|
||||
assert args.no_id
|
||||
assert args.command_line == []
|
||||
except docopt.DocoptExit:
|
||||
docopt_threw = True
|
||||
assert not docopt_threw
|
||||
|
||||
|
||||
def test_program_initiate_dash_args():
|
||||
docopt_threw = False
|
||||
try:
|
||||
|
@ -123,14 +123,17 @@ async def do_connected_test(listener_args: str, initiator_args: str, test: calla
|
||||
assert len(ih) == 32
|
||||
assert len(dh) == 32
|
||||
assert len(iih) == 32
|
||||
assert "dh" in initiator_args
|
||||
initiator_args = initiator_args.replace("dh", dh)
|
||||
listener_args = listener_args.replace("iih", iih)
|
||||
with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \
|
||||
tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {dh} {initiator_args}")) as initiator:
|
||||
tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" {initiator_args}")) as initiator:
|
||||
# listener startup
|
||||
listener.start()
|
||||
await asyncio.sleep(0.1)
|
||||
assert listener.process.running
|
||||
# wait for process to start up
|
||||
await asyncio.sleep(5)
|
||||
await asyncio.sleep(2)
|
||||
# read the output
|
||||
text = listener.read().decode("utf-8")
|
||||
assert text.index(dh) is not None
|
||||
@ -166,7 +169,55 @@ async def test_rnsh_get_echo_through():
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert text[len(text)-len(cwd):] == cwd
|
||||
|
||||
await do_connected_test("-n -C -- /bin/pwd", "", test)
|
||||
await do_connected_test("-n -C -- /bin/pwd", "dh", test)
|
||||
|
||||
|
||||
@pytest.mark.skip_ci
|
||||
@pytest.mark.asyncio
|
||||
async def test_rnsh_no_ident():
|
||||
cwd = os.getcwd()
|
||||
|
||||
async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader,
|
||||
initiator: tests.helpers.SubprocessReader):
|
||||
start_time = time.time()
|
||||
while initiator.return_code is None and time.time() - start_time < 3:
|
||||
await asyncio.sleep(0.1)
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert text[len(text)-len(cwd):] == cwd
|
||||
|
||||
await do_connected_test("-n -C -- /bin/pwd", "-N dh", test)
|
||||
|
||||
|
||||
@pytest.mark.skip_ci
|
||||
@pytest.mark.asyncio
|
||||
async def test_rnsh_invalid_ident():
|
||||
cwd = os.getcwd()
|
||||
|
||||
async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader,
|
||||
initiator: tests.helpers.SubprocessReader):
|
||||
start_time = time.time()
|
||||
while initiator.return_code is None and time.time() - start_time < 3:
|
||||
await asyncio.sleep(0.1)
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert "not allowed" in text
|
||||
|
||||
await do_connected_test("-a 12345678901234567890123456789012 -C -- /bin/pwd", "dh", test)
|
||||
|
||||
|
||||
@pytest.mark.skip_ci
|
||||
@pytest.mark.asyncio
|
||||
async def test_rnsh_valid_ident():
|
||||
cwd = os.getcwd()
|
||||
|
||||
async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader,
|
||||
initiator: tests.helpers.SubprocessReader):
|
||||
start_time = time.time()
|
||||
while initiator.return_code is None and time.time() - start_time < 3:
|
||||
await asyncio.sleep(0.1)
|
||||
text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "")
|
||||
assert text[len(text)-len(cwd):] == cwd
|
||||
|
||||
await do_connected_test("-a iih -C -- /bin/pwd", "dh", test)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user