Fix endless wait if handshake fails

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-06-17 16:51:50 +02:00 committed by Daniel Weiße
parent e6b1156849
commit 3b92b52611
2 changed files with 104 additions and 2 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"io" "io"
"strings"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
@ -55,7 +56,7 @@ func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.S
// Check once before waiting // Check once before waiting
resp, err := w.probe(ctx, endpoint) resp, err := w.probe(ctx, endpoint)
if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable { if err != nil && (grpcstatus.Code(err) != grpccodes.Unavailable || isGRPCHandshakeError(err)) {
return err return err
} }
if resp != nil && containsState(state.State(resp.State), status...) { if resp != nil && containsState(state.State(resp.State), status...) {
@ -67,7 +68,7 @@ func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.S
select { select {
case <-ticker.C: case <-ticker.C:
resp, err := w.probe(ctx, endpoint) resp, err := w.probe(ctx, endpoint)
if grpcstatus.Code(err) == grpccodes.Unavailable { if grpcstatus.Code(err) == grpccodes.Unavailable && !isGRPCHandshakeError(err) {
// The server isn't reachable yet. // The server isn't reachable yet.
continue continue
} }
@ -136,3 +137,15 @@ func containsState(s state.State, states ...state.State) bool {
} }
return false return false
} }
func isGRPCHandshakeError(err error) bool {
statusErr, ok := grpcstatus.FromError(err)
if !ok {
return false
}
if statusErr.Code() != grpccodes.Unavailable {
return false
}
// ideally we would check the error type directly, but grpc only provides a string
return strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`)
}

View File

@ -3,15 +3,20 @@ package statuswaiter
import ( import (
"context" "context"
"errors" "errors"
"net"
"testing" "testing"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/oid" "github.com/edgelesssys/constellation/internal/oid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
) )
func TestInitializeValidators(t *testing.T) { func TestInitializeValidators(t *testing.T) {
@ -36,6 +41,7 @@ func TestInitializeValidators(t *testing.T) {
func TestWaitForAndWaitForAll(t *testing.T) { func TestWaitForAndWaitForAll(t *testing.T) {
var noErr error var noErr error
someErr := errors.New("failed") someErr := errors.New("failed")
handshakeErr := status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed"`)
testCases := map[string]struct { testCases := map[string]struct {
waiter Waiter waiter Waiter
@ -90,6 +96,16 @@ func TestWaitForAndWaitForAll(t *testing.T) {
waitForState: []state.State{state.IsNode}, waitForState: []state.State{state.IsNode},
wantErr: true, wantErr: true,
}, },
"fail TLS handshake": {
waiter: Waiter{
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(handshakeErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
},
waitForState: []state.State{state.IsNode},
wantErr: true,
},
} }
t.Run("WaitFor", func(t *testing.T) { t.Run("WaitFor", func(t *testing.T) {
@ -216,3 +232,76 @@ func TestContainsState(t *testing.T) {
}) })
} }
} }
func TestIsHandshakeError(t *testing.T) {
testCases := map[string]struct {
err error
wantedResult bool
}{
"TLS handshake error": {
err: getGRPCHandshakeError(),
wantedResult: true,
},
"Unavailable error": {
err: status.Error(codes.Unavailable, "connection error"),
wantedResult: false,
},
"TLS handshake error with wrong code": {
err: status.Error(codes.Aborted, `connection error: desc = "transport: authentication handshake failed`),
wantedResult: false,
},
"Non gRPC error": {
err: errors.New("error"),
wantedResult: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
res := isGRPCHandshakeError(tc.err)
assert.Equal(tc.wantedResult, res)
})
}
}
func getGRPCHandshakeError() error {
serverCreds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil)
api := &fakeAPI{}
server := grpc.NewServer(grpc.Creds(serverCreds))
pubproto.RegisterAPIServer(server, api)
listener := bufconn.Listen(1024)
defer server.GracefulStop()
go server.Serve(listener)
clientCreds := atlscredentials.New(nil, []atls.Validator{failingValidator{oid.Dummy{}}})
conn, err := grpc.DialContext(context.Background(), "", grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return listener.Dial()
}), grpc.WithTransportCredentials(clientCreds))
if err != nil {
panic(err)
}
defer conn.Close()
client := pubproto.NewAPIClient(conn)
_, err = client.GetState(context.Background(), &pubproto.GetStateRequest{})
return err
}
type failingValidator struct {
oid.Getter
}
func (v failingValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
return nil, errors.New("error")
}
type fakeAPI struct {
pubproto.UnimplementedAPIServer
}
func (f *fakeAPI) GetState(ctx context.Context, in *pubproto.GetStateRequest) (*pubproto.GetStateResponse, error) {
return &pubproto.GetStateResponse{State: 1}, nil
}