diff --git a/internal/statuswaiter/statuswaiter.go b/internal/statuswaiter/statuswaiter.go index d07f94b64..1002c4d57 100644 --- a/internal/statuswaiter/statuswaiter.go +++ b/internal/statuswaiter/statuswaiter.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "strings" "time" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" @@ -55,7 +56,7 @@ func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.S // Check once before waiting resp, err := w.probe(ctx, endpoint) - if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable { + if err != nil && (grpcstatus.Code(err) != grpccodes.Unavailable || isGRPCHandshakeError(err)) { return err } if resp != nil && containsState(state.State(resp.State), status...) { @@ -67,7 +68,7 @@ func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.S select { case <-ticker.C: resp, err := w.probe(ctx, endpoint) - if grpcstatus.Code(err) == grpccodes.Unavailable { + if grpcstatus.Code(err) == grpccodes.Unavailable && !isGRPCHandshakeError(err) { // The server isn't reachable yet. continue } @@ -136,3 +137,15 @@ func containsState(s state.State, states ...state.State) bool { } return false } + +func isGRPCHandshakeError(err error) bool { + statusErr, ok := grpcstatus.FromError(err) + if !ok { + return false + } + if statusErr.Code() != grpccodes.Unavailable { + return false + } + // ideally we would check the error type directly, but grpc only provides a string + return strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`) +} diff --git a/internal/statuswaiter/statuswaiter_test.go b/internal/statuswaiter/statuswaiter_test.go index 44d3b692d..dea4ae5bf 100644 --- a/internal/statuswaiter/statuswaiter_test.go +++ b/internal/statuswaiter/statuswaiter_test.go @@ -3,15 +3,20 @@ package statuswaiter import ( "context" "errors" + "net" "testing" "time" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/oid" "github.com/stretchr/testify/assert" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" ) func TestInitializeValidators(t *testing.T) { @@ -36,6 +41,7 @@ func TestInitializeValidators(t *testing.T) { func TestWaitForAndWaitForAll(t *testing.T) { var noErr error someErr := errors.New("failed") + handshakeErr := status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed"`) testCases := map[string]struct { waiter Waiter @@ -90,6 +96,16 @@ func TestWaitForAndWaitForAll(t *testing.T) { waitForState: []state.State{state.IsNode}, wantErr: true, }, + "fail TLS handshake": { + waiter: Waiter{ + initialized: true, + interval: time.Millisecond, + newConn: stubNewConnFunc(handshakeErr), + newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), + }, + waitForState: []state.State{state.IsNode}, + wantErr: true, + }, } t.Run("WaitFor", func(t *testing.T) { @@ -216,3 +232,76 @@ func TestContainsState(t *testing.T) { }) } } + +func TestIsHandshakeError(t *testing.T) { + testCases := map[string]struct { + err error + wantedResult bool + }{ + "TLS handshake error": { + err: getGRPCHandshakeError(), + wantedResult: true, + }, + "Unavailable error": { + err: status.Error(codes.Unavailable, "connection error"), + wantedResult: false, + }, + "TLS handshake error with wrong code": { + err: status.Error(codes.Aborted, `connection error: desc = "transport: authentication handshake failed`), + wantedResult: false, + }, + "Non gRPC error": { + err: errors.New("error"), + wantedResult: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + res := isGRPCHandshakeError(tc.err) + assert.Equal(tc.wantedResult, res) + }) + } +} + +func getGRPCHandshakeError() error { + serverCreds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil) + api := &fakeAPI{} + server := grpc.NewServer(grpc.Creds(serverCreds)) + pubproto.RegisterAPIServer(server, api) + + listener := bufconn.Listen(1024) + defer server.GracefulStop() + go server.Serve(listener) + + clientCreds := atlscredentials.New(nil, []atls.Validator{failingValidator{oid.Dummy{}}}) + conn, err := grpc.DialContext(context.Background(), "", grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return listener.Dial() + }), grpc.WithTransportCredentials(clientCreds)) + if err != nil { + panic(err) + } + defer conn.Close() + + client := pubproto.NewAPIClient(conn) + _, err = client.GetState(context.Background(), &pubproto.GetStateRequest{}) + return err +} + +type failingValidator struct { + oid.Getter +} + +func (v failingValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { + return nil, errors.New("error") +} + +type fakeAPI struct { + pubproto.UnimplementedAPIServer +} + +func (f *fakeAPI) GetState(ctx context.Context, in *pubproto.GetStateRequest) (*pubproto.GetStateResponse, error) { + return &pubproto.GetStateResponse{State: 1}, nil +}