mirror of
https://github.com/markqvist/rnsh.git
synced 2025-06-20 20:14:15 -04:00
Add tests for and fix #14
This commit is contained in:
parent
a07ce53bf9
commit
458a2391df
5 changed files with 82 additions and 8 deletions
|
@ -101,8 +101,12 @@ class ListenerSession:
|
|||
self.return_code: int | None = None
|
||||
self.return_code_sent = False
|
||||
self.process: process.CallbackSubprocess | None = None
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
if self.allow_all:
|
||||
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
else:
|
||||
self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
self.sessions.append(self)
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
|
||||
def _terminated(self, return_code: int):
|
||||
self.return_code = return_code
|
||||
|
@ -176,14 +180,13 @@ class ListenerSession:
|
|||
return
|
||||
|
||||
self._log.info(f"initiator_identified {identity} on link {outlet}")
|
||||
if self.state != LSState.LSSTATE_WAIT_IDENT:
|
||||
if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]:
|
||||
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
|
||||
|
||||
if not self.allow_all and identity.hash not in self.allowed_identity_hashes:
|
||||
self.terminate("Identity is not allowed.")
|
||||
|
||||
self.remote_identity = identity
|
||||
self.outlet.set_packet_received_callback(self._packet_received)
|
||||
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
|
||||
@classmethod
|
||||
|
@ -277,7 +280,8 @@ class ListenerSession:
|
|||
try:
|
||||
self.process = process.CallbackSubprocess(argv=self.cmdline,
|
||||
env={"TERM": self.term or os.environ.get("TERM", None),
|
||||
"RNS_REMOTE_IDENTITY": RNS.prettyhexrep(self.remote_identity.hash) or ""},
|
||||
"RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash)
|
||||
if self.remote_identity and self.remote_identity.hash else "")},
|
||||
loop=self.loop,
|
||||
stdout_callback=stdout,
|
||||
stderr_callback=stderr,
|
||||
|
@ -306,6 +310,9 @@ class ListenerSession:
|
|||
self.process.close_stdin()
|
||||
|
||||
def _handle_message(self, message: protocol.Message):
|
||||
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
||||
self._protocol_error("Identification required")
|
||||
return
|
||||
if self.state == LSState.LSSTATE_WAIT_VERS:
|
||||
if not isinstance(message, protocol.VersionInfoMessage):
|
||||
self._protocol_error(self.state.name)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue