/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package atlscredentials import ( "bytes" "context" "encoding/asn1" "encoding/json" "errors" "net" "testing" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/test/bufconn" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m, goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1")) } func TestATLSCredentials(t *testing.T) { assert := assert.New(t) require := require.New(t) oid := fakeOID{1, 3, 9900, 1} // // Create servers // serverCreds := New(fakeIssuer{fakeOID: oid}, nil) const serverCount = 15 var listeners []*bufconn.Listener for i := 0; i < serverCount; i++ { api := &fakeAPI{} server := grpc.NewServer(grpc.Creds(serverCreds)) initproto.RegisterAPIServer(server, api) listener := bufconn.Listen(1024) listeners = append(listeners, listener) defer server.GracefulStop() go server.Serve(listener) } // // Dial concurrently // clientCreds := New(nil, []atls.Validator{fakeValidator{fakeOID: oid}}) errChan := make(chan error, serverCount) for _, listener := range listeners { lis := listener go func() { var err error defer func() { errChan <- err }() conn, err := grpc.DialContext(context.Background(), "", grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return lis.Dial() }), grpc.WithTransportCredentials(clientCreds)) require.NoError(err) defer conn.Close() client := initproto.NewAPIClient(conn) _, err = client.Init(context.Background(), &initproto.InitRequest{}) }() } for i := 0; i < serverCount; i++ { assert.NoError(<-errChan) } } type fakeIssuer struct { fakeOID } func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) { return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) } type fakeValidator struct { fakeOID err error } func (v fakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) { var doc fakeDoc if err := json.Unmarshal(attDoc, &doc); err != nil { return nil, err } if !bytes.Equal(doc.Nonce, nonce) { return nil, errors.New("invalid nonce") } return doc.UserData, v.err } type fakeOID asn1.ObjectIdentifier func (o fakeOID) OID() asn1.ObjectIdentifier { return asn1.ObjectIdentifier(o) } type fakeDoc struct { UserData []byte Nonce []byte } type fakeAPI struct { initproto.UnimplementedAPIServer } func (f *fakeAPI) Init(_ *initproto.InitRequest, stream initproto.API_InitServer) error { _ = stream.Send(&initproto.InitResponse{}) return nil }