diff --git a/activation/cmd/main.go b/activation/cmd/main.go index 8e7f604da..47fedbed6 100644 --- a/activation/cmd/main.go +++ b/activation/cmd/main.go @@ -12,6 +12,7 @@ import ( "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/spf13/afero" "k8s.io/klog/v2" ) @@ -36,10 +37,7 @@ func main() { klog.Exitf("failed to create validator: %s", err) } - tlsConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator}) - if err != nil { - klog.Exitf("unable to create server config: %s", err) - } + creds := atlscredentials.New(nil, []atls.Validator{validator}) kubeadm, err := kubeadm.New() if err != nil { @@ -62,7 +60,7 @@ func main() { } }() - if err := server.Run(tlsConfig, bindPort); err != nil { + if err := server.Run(creds, bindPort); err != nil { klog.Exitf("failed to run server: %s", err) } } diff --git a/activation/server/server.go b/activation/server/server.go index d4b7b52a5..723b10af1 100644 --- a/activation/server/server.go +++ b/activation/server/server.go @@ -2,7 +2,6 @@ package server import ( "context" - "crypto/tls" "fmt" "net" "time" @@ -39,9 +38,9 @@ func New(fileHandler file.Handler, ca certificateAuthority, joinTokenGetter join } // Run starts the gRPC server on the given port, using the provided tlsConfig. -func (s *Server) Run(tlsConfig *tls.Config, port string) error { +func (s *Server) Run(creds credentials.TransportCredentials, port string) error { grpcServer := grpc.NewServer( - grpc.Creds(credentials.NewTLS(tlsConfig)), + grpc.Creds(creds), grpc.UnaryInterceptor(logGRPC), ) diff --git a/cli/internal/proto/client.go b/cli/internal/proto/client.go index 4833c50dd..a101655c5 100644 --- a/cli/internal/proto/client.go +++ b/cli/internal/proto/client.go @@ -8,10 +8,10 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" kms "github.com/edgelesssys/constellation/kms/server/setup" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) // Client wraps a PubAPI client and the connection to it. @@ -26,12 +26,9 @@ type Client struct { // called on a client that already has a connection, the old // connection is closed. func (c *Client) Connect(endpoint string, validators []atls.Validator) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) - if err != nil { - return err - } + creds := atlscredentials.New(nil, validators) - conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds)) if err != nil { return err } diff --git a/cli/internal/proto/recover.go b/cli/internal/proto/recover.go index a4dece4bc..362855ce8 100644 --- a/cli/internal/proto/recover.go +++ b/cli/internal/proto/recover.go @@ -5,9 +5,9 @@ import ( "errors" "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/state/keyservice/keyproto" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) // KeyClient wraps a KeyAPI client and the connection to it. @@ -22,12 +22,9 @@ type KeyClient struct { // called on a client that already has a connection, the old // connection is closed. func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) - if err != nil { - return err - } + creds := atlscredentials.New(nil, validators) - conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds)) if err != nil { return err } diff --git a/coordinator/cmd/coordinator/main.go b/coordinator/cmd/coordinator/main.go index b525e3a9d..d29685511 100644 --- a/coordinator/cmd/coordinator/main.go +++ b/coordinator/cmd/coordinator/main.go @@ -20,7 +20,6 @@ import ( "github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl" "github.com/edgelesssys/constellation/coordinator/logging" "github.com/edgelesssys/constellation/coordinator/util" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/wireguard" "github.com/edgelesssys/constellation/internal/attestation/azure" "github.com/edgelesssys/constellation/internal/attestation/gcp" @@ -28,6 +27,7 @@ import ( "github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/dialer" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "github.com/spf13/afero" "go.uber.org/zap" @@ -177,7 +177,7 @@ func main() { fileHandler := file.NewHandler(fs) netDialer := &net.Dialer{} - dialer := grpcutil.NewDialer(validator, netDialer) + dialer := dialer.New(nil, validator, netDialer) run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube, coreMetadata, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore, cloudLogger, fs) diff --git a/coordinator/cmd/coordinator/run.go b/coordinator/cmd/coordinator/run.go index ce0ff08e9..bd1432cf3 100644 --- a/coordinator/cmd/coordinator/run.go +++ b/coordinator/cmd/coordinator/run.go @@ -12,25 +12,24 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/store" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" - "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/dialer" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/spf13/afero" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) var version = "0.0.0" -func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler, +func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *dialer.Dialer, fileHandler file.Handler, kube core.Cluster, metadata core.ProviderMetadata, disk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, logger *zap.Logger, cloudLogger logging.CloudLogger, fs afero.Fs, ) { @@ -40,10 +39,7 @@ func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicI defer cloudLogger.Close() cloudLogger.Disclose("Coordinator started running...") - tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer, nil) - if err != nil { - logger.Fatal("failed to create server TLS config", zap.Error(err)) - } + creds := atlscredentials.New(issuer, nil) etcdStoreFactory := store.NewEtcdStoreFactory(etcdEndpoint, etcdTLS, logger) linuxUserManager := user.NewLinuxUserManager(fs) @@ -64,7 +60,7 @@ func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicI zapLoggergRPC := loggerPubAPI.Named("gRPC") grpcServer := grpc.NewServer( - grpc.Creds(credentials.NewTLS(tlsConfig)), + grpc.Creds(creds), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_ctxtags.StreamServerInterceptor(), grpc_zap.StreamServerInterceptor(zapLoggergRPC), @@ -117,14 +113,11 @@ func tryJoinClusterOnStartup(getPublicIPAddr func() (string, error), metadata co // We create an client unverified connection, since the node does not need to verify the Coordinator. // ActivateAdditionalNodes triggers the Coordinator to call ActivateAsNode. This rpc lets the Coordinator verify the node. - tlsClientConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) - if err != nil { - return fmt.Errorf("failed to create client TLS config: %w", err) - } + creds := atlscredentials.New(nil, nil) // try to notify a coordinator to activate this node for _, coordinatorEndpoint := range coordinatorEndpoints { - conn, err := grpc.Dial(coordinatorEndpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig))) + conn, err := grpc.Dial(coordinatorEndpoint, grpc.WithTransportCredentials(creds)) if err != nil { logger.Info("Dial failed:", zap.String("endpoint", coordinatorEndpoint), zap.Error(err)) continue diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index ff0f165b9..74817c61a 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -15,14 +15,15 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/store" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/dialer" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" kms "github.com/edgelesssys/constellation/kms/server/setup" "github.com/spf13/afero" "github.com/stretchr/testify/assert" @@ -31,7 +32,6 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zaptest" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) func TestMain(m *testing.M) { @@ -221,14 +221,13 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testd getPublicAddr := func() (string, error) { return "192.0.2.1", nil } - dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer) + dialer := dialer.New(nil, &core.MockValidator{}, netDialer) vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer} papi := pubapi.New(logger, &logging.NopLogger{}, cor, dialer, vapiServer, getPublicAddr, nil) - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) - require.NoError(err) - server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(&core.MockIssuer{}, nil) + server := grpc.NewServer(grpc.Creds(creds)) pubproto.RegisterAPIServer(server, papi) listener := netDialer.GetListener(endpoint) @@ -264,16 +263,13 @@ func activateCoordinator(require *require.Assertions, dialer netDialer, coordina } func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) - if err != nil { - return nil, err - } + creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}}) return grpc.DialContext(ctx, target, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return dialer.DialContext(ctx, "tcp", addr) }), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithTransportCredentials(creds), ) } diff --git a/coordinator/core/core_test.go b/coordinator/core/core_test.go index b4693b276..048ec6553 100644 --- a/coordinator/core/core_test.go +++ b/coordinator/core/core_test.go @@ -11,12 +11,12 @@ import ( "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/store" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/dialer" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" kms "github.com/edgelesssys/constellation/kms/server/setup" "github.com/spf13/afero" "github.com/stretchr/testify/assert" @@ -220,7 +220,7 @@ func TestInitialize(t *testing.T) { // prepare store to emulate initialized KMS require.NoError(core.data().PutKMSData(kms.KMSInformation{StorageUri: kms.NoStoreURI, KmsUri: kms.ClusterKMSURI})) require.NoError(core.data().PutMasterSecret([]byte("master-secret"))) - dialer := grpcutil.NewDialer(&MockValidator{}, testdialer.NewBufconnDialer()) + dialer := dialer.New(nil, &MockValidator{}, testdialer.NewBufconnDialer()) nodeActivated, err := core.Initialize(context.Background(), dialer, &stubPubAPI{}) if tc.wantErr { diff --git a/coordinator/core/legacy_test.go b/coordinator/core/legacy_test.go index 3d6f374ff..4af07f146 100644 --- a/coordinator/core/legacy_test.go +++ b/coordinator/core/legacy_test.go @@ -13,20 +13,19 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" - "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/dialer" kms "github.com/edgelesssys/constellation/kms/server/setup" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/test/bufconn" ) @@ -121,7 +120,7 @@ func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error) return nil, nil, err } - dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer) + dialer := dialer.New(nil, NewMockValidator(), bufDialer) vpn := &stubVPN{} kubeFake := &ClusterFake{} metadataFake := &ProviderMetadataFake{} @@ -171,12 +170,9 @@ func (b *bufconnDialer) addListener(endpoint string, listener *bufconn.Listener) } func spawnNode(endpoint string, testNodeCore *pubapi.API, bufDialer *bufconnDialer) (*grpc.Server, error) { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}, nil) - if err != nil { - return nil, err - } + creds := atlscredentials.New(&MockIssuer{}, nil) - grpcServer := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + grpcServer := grpc.NewServer(grpc.Creds(creds)) pubproto.RegisterAPIServer(grpcServer, testNodeCore) const bufferSize = 8 * 1024 diff --git a/coordinator/core/reinitialize_test.go b/coordinator/core/reinitialize_test.go index e438d8dce..6a187fd5c 100644 --- a/coordinator/core/reinitialize_test.go +++ b/coordinator/core/reinitialize_test.go @@ -9,11 +9,11 @@ import ( "github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" - "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/file" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/dialer" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" kms "github.com/edgelesssys/constellation/kms/server/setup" "github.com/spf13/afero" "github.com/stretchr/testify/assert" @@ -21,7 +21,6 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zaptest" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/protobuf/proto" ) @@ -74,7 +73,7 @@ func TestReinitializeAsNode(t *testing.T) { coordinators := []cloudtypes.Instance{{PrivateIPs: []string{"192.0.2.1"}, Role: role.Coordinator}} netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(&MockValidator{}, netDialer) + dialer := dialer.New(nil, &MockValidator{}, netDialer) server := newPubAPIServer() api := &pubAPIServerStub{responses: tc.getInitialVPNPeersResponses} pubproto.RegisterAPIServer(server, api) @@ -147,7 +146,7 @@ func TestReinitializeAsCoordinator(t *testing.T) { coordinators := []cloudtypes.Instance{{PrivateIPs: []string{"192.0.2.1"}, Role: role.Coordinator}} netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(&MockValidator{}, netDialer) + dialer := dialer.New(nil, &MockValidator{}, netDialer) server := newPubAPIServer() api := &pubAPIServerStub{responses: tc.getInitialVPNPeersResponses} pubproto.RegisterAPIServer(server, api) @@ -235,7 +234,7 @@ func TestGetInitialVPNPeers(t *testing.T) { zapLogger, err := zap.NewDevelopment() require.NoError(err) netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(&MockValidator{}, netDialer) + dialer := dialer.New(nil, &MockValidator{}, netDialer) server := newPubAPIServer() api := &pubAPIServerStub{ responses: []struct { @@ -259,11 +258,9 @@ func TestGetInitialVPNPeers(t *testing.T) { } func newPubAPIServer() *grpc.Server { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}, nil) - if err != nil { - panic(err) - } - return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(&MockIssuer{}, nil) + + return grpc.NewServer(grpc.Creds(creds)) } type pubAPIServerStub struct { diff --git a/coordinator/pubapi/coord_test.go b/coordinator/pubapi/coord_test.go index 1c91b3473..cd633ba6a 100644 --- a/coordinator/pubapi/coord_test.go +++ b/coordinator/pubapi/coord_test.go @@ -16,12 +16,12 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" - "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/deploy/user" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/dialer" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/edgelesssys/constellation/internal/oid" kms "github.com/edgelesssys/constellation/kms/server/setup" "github.com/edgelesssys/constellation/state/keyservice/keyproto" @@ -30,7 +30,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" grpcpeer "google.golang.org/grpc/peer" ) @@ -150,7 +149,7 @@ func TestActivateAsCoordinator(t *testing.T) { } netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil @@ -302,7 +301,7 @@ func TestActivateAdditionalNodes(t *testing.T) { core := &fakeCore{state: tc.state} netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil @@ -432,11 +431,8 @@ func (n *stubPeer) GetPeerVPNPublicKey(ctx context.Context, in *pubproto.GetPeer } func (n *stubPeer) newServer() *grpc.Server { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(fakeIssuer{}, nil) - if err != nil { - panic(err) - } - server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(fakeIssuer{}, nil) + server := grpc.NewServer(grpc.Creds(creds)) pubproto.RegisterAPIServer(server, n) return server } @@ -537,9 +533,8 @@ func TestRequestStateDiskKey(t *testing.T) { require.NoError(err) defer listener.Close() - tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer, nil) - require.NoError(err) - s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(issuer, nil) + s := grpc.NewServer(grpc.Creds(creds)) keyproto.RegisterAPIServer(s, stateDiskServer) defer s.GracefulStop() go s.Serve(listener) @@ -559,7 +554,7 @@ func TestRequestStateDiskKey(t *testing.T) { getDataKeyErr: tc.getDataKeyErr, } - api := New(zaptest.NewLogger(t), &logging.NopLogger{}, core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext) + api := New(zaptest.NewLogger(t), &logging.NopLogger{}, core, dialer.New(nil, dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext) _, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{}) if tc.wantErr { diff --git a/coordinator/pubapi/multicoord_test.go b/coordinator/pubapi/multicoord_test.go index 23c1ec7a0..492891406 100644 --- a/coordinator/pubapi/multicoord_test.go +++ b/coordinator/pubapi/multicoord_test.go @@ -11,8 +11,8 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" + "github.com/edgelesssys/constellation/internal/grpc/dialer" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" @@ -95,7 +95,7 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) { clusterID: []byte("clusterID"), } netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil @@ -166,7 +166,7 @@ func TestTriggerCoordinatorUpdate(t *testing.T) { state: tc.state, peers: tc.peers, } - dialer := grpcutil.NewDialer(fakeValidator{}, nil) + dialer := dialer.New(nil, fakeValidator{}, nil) api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil) @@ -240,7 +240,7 @@ func TestActivateAdditionalCoordinators(t *testing.T) { clusterID: []byte("clusterID"), } netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil @@ -297,7 +297,7 @@ func TestGetPeerVPNPublicKey(t *testing.T) { vpnPubKey: tc.coordinator.peer.VPNPubKey, getvpnPubKeyErr: tc.getVPNPubKeyErr, } - dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer()) + dialer := dialer.New(nil, fakeValidator{}, testdialer.NewBufconnDialer()) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil diff --git a/coordinator/pubapi/node_test.go b/coordinator/pubapi/node_test.go index c7499e8e6..e32739cd8 100644 --- a/coordinator/pubapi/node_test.go +++ b/coordinator/pubapi/node_test.go @@ -14,18 +14,18 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" - "github.com/edgelesssys/constellation/coordinator/util/grpcutil" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/deploy/user" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/dialer" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" ) @@ -152,7 +152,7 @@ func TestActivateAsNode(t *testing.T) { linuxUserManager := user.NewLinuxUserManagerFake(fs) cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr, linuxUserManager: linuxUserManager} netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) api := New(logger, &logging.NopLogger{}, cor, dialer, nil, nil, nil) defer api.Close() @@ -163,9 +163,8 @@ func TestActivateAsNode(t *testing.T) { go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) defer vserver.GracefulStop() - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) - require.NoError(err) - pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(&core.MockIssuer{}, nil) + pubserver := grpc.NewServer(grpc.Creds(creds)) pubproto.RegisterAPIServer(pubserver, api) go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort))) defer pubserver.GracefulStop() @@ -260,7 +259,7 @@ func TestTriggerNodeUpdate(t *testing.T) { logger := zaptest.NewLogger(t) core := &fakeCore{state: tc.state} netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil) @@ -336,7 +335,7 @@ func TestJoinCluster(t *testing.T) { logger := zaptest.NewLogger(t) core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr} netDialer := testdialer.NewBufconnDialer() - dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) + dialer := dialer.New(nil, fakeValidator{}, netDialer) api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil) @@ -433,16 +432,13 @@ func activateNode(require *require.Assertions, dialer netDialer, messageSequence } func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) - if err != nil { - return nil, err - } + creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}}) return grpc.DialContext(ctx, target, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return dialer.DialContext(ctx, "tcp", addr) }), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithTransportCredentials(creds), ) } diff --git a/debugd/debugd/deploy/download_test.go b/debugd/debugd/deploy/download_test.go index c0d08df01..29e04df30 100644 --- a/debugd/debugd/deploy/download_test.go +++ b/debugd/debugd/deploy/download_test.go @@ -9,10 +9,10 @@ import ( "testing" "time" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/debugd" pb "github.com/edgelesssys/constellation/debugd/service" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/protobuf/proto" diff --git a/debugd/debugd/server/server_test.go b/debugd/debugd/server/server_test.go index 4582aca2d..b6ed0c973 100644 --- a/debugd/debugd/server/server_test.go +++ b/debugd/debugd/server/server_test.go @@ -8,11 +8,11 @@ import ( "net" "testing" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/debugd/deploy" pb "github.com/edgelesssys/constellation/debugd/service" "github.com/edgelesssys/constellation/internal/deploy/ssh" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" diff --git a/internal/grpc/atlscredentials/atlscredentials.go b/internal/grpc/atlscredentials/atlscredentials.go new file mode 100644 index 000000000..280d28b68 --- /dev/null +++ b/internal/grpc/atlscredentials/atlscredentials.go @@ -0,0 +1,53 @@ +package atlscredentials + +import ( + "context" + "errors" + "net" + + "github.com/edgelesssys/constellation/internal/atls" + "google.golang.org/grpc/credentials" +) + +type Credentials struct { + issuer atls.Issuer + validators []atls.Validator +} + +func New(issuer atls.Issuer, validators []atls.Validator) *Credentials { + return &Credentials{ + issuer: issuer, + validators: validators, + } +} + +func (c *Credentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + clientCfg, err := atls.CreateAttestationClientTLSConfig(c.issuer, c.validators) + if err != nil { + return nil, nil, err + } + + return credentials.NewTLS(clientCfg).ClientHandshake(ctx, authority, rawConn) +} + +func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, c.validators) + if err != nil { + return nil, nil, err + } + + return credentials.NewTLS(serverCfg).ServerHandshake(rawConn) +} + +func (c *Credentials) Info() credentials.ProtocolInfo { + return credentials.NewTLS(nil).Info() +} + +func (c *Credentials) Clone() credentials.TransportCredentials { + cloned := *c + return &cloned +} + +func (c *Credentials) OverrideServerName(s string) error { + return errors.New("cannot override server name") +} diff --git a/internal/grpc/atlscredentials/atlscredentials_test.go b/internal/grpc/atlscredentials/atlscredentials_test.go new file mode 100644 index 000000000..38824121f --- /dev/null +++ b/internal/grpc/atlscredentials/atlscredentials_test.go @@ -0,0 +1,120 @@ +package atlscredentials + +import ( + "bytes" + "context" + "encoding/asn1" + "encoding/json" + "errors" + "net" + "testing" + + "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" + "github.com/edgelesssys/constellation/internal/atls" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestATLSCredentials(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + oid := fakeOID{1, 3, 9900, 1} + + // + // Create servers + // + + serverCreds := New(fakeIssuer{fakeOID: oid}, nil) + + const serverCount = 15 + var listeners []*bufconn.Listener + for i := 0; i < serverCount; i++ { + api := &fakeAPI{} + server := grpc.NewServer(grpc.Creds(serverCreds)) + pubproto.RegisterAPIServer(server, api) + + listener := bufconn.Listen(1024) + listeners = append(listeners, listener) + + defer server.GracefulStop() + go server.Serve(listener) + } + + // + // Dial concurrently + // + + clientCreds := New(nil, []atls.Validator{fakeValidator{fakeOID: oid}}) + + errChan := make(chan error, serverCount) + for _, listener := range listeners { + lis := listener + go func() { + var err error + defer func() { errChan <- err }() + conn, err := grpc.DialContext(context.Background(), "", grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return lis.Dial() + }), grpc.WithTransportCredentials(clientCreds)) + require.NoError(err) + defer conn.Close() + + client := pubproto.NewAPIClient(conn) + _, err = client.GetState(context.Background(), &pubproto.GetStateRequest{}) + }() + } + + for i := 0; i < serverCount; i++ { + assert.NoError(<-errChan) + } +} + +type fakeIssuer struct { + fakeOID +} + +func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { + return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) +} + +type fakeValidator struct { + fakeOID + err error +} + +func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { + var doc fakeDoc + if err := json.Unmarshal(attDoc, &doc); err != nil { + return nil, err + } + if !bytes.Equal(doc.Nonce, nonce) { + return nil, errors.New("invalid nonce") + } + return doc.UserData, v.err +} + +type fakeOID asn1.ObjectIdentifier + +func (o fakeOID) OID() asn1.ObjectIdentifier { + return asn1.ObjectIdentifier(o) +} + +type fakeDoc struct { + UserData []byte + Nonce []byte +} + +type fakeAPI struct { + pubproto.UnimplementedAPIServer +} + +func (f *fakeAPI) GetState(ctx context.Context, in *pubproto.GetStateRequest) (*pubproto.GetStateResponse, error) { + return &pubproto.GetStateResponse{State: 1}, nil +} diff --git a/coordinator/util/grpcutil/dialer.go b/internal/grpc/dialer/dialer.go similarity index 76% rename from coordinator/util/grpcutil/dialer.go rename to internal/grpc/dialer/dialer.go index feb4db625..621733f37 100644 --- a/coordinator/util/grpcutil/dialer.go +++ b/internal/grpc/dialer/dialer.go @@ -1,24 +1,26 @@ -package grpcutil +package dialer import ( "context" "net" "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) // Dialer can open grpc client connections with different levels of ATLS encryption / verification. type Dialer struct { + issuer atls.Issuer validator atls.Validator netDialer NetDialer } -// NewDialer creates a new Dialer. -func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer { +// New creates a new Dialer. +func New(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer) *Dialer { return &Dialer{ + issuer: issuer, validator: validator, netDialer: netDialer, } @@ -26,14 +28,11 @@ func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer { // Dial creates a new grpc client connection to the given target using the atls validator. func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{d.validator}) - if err != nil { - return nil, err - } + credentials := atlscredentials.New(d.issuer, []atls.Validator{d.validator}) return grpc.DialContext(ctx, target, d.grpcWithDialer(), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithTransportCredentials(credentials), ) } @@ -48,14 +47,11 @@ func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientC // DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation. func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) - if err != nil { - return nil, err - } + credentials := atlscredentials.New(nil, nil) return grpc.DialContext(ctx, target, d.grpcWithDialer(), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithTransportCredentials(credentials), ) } diff --git a/coordinator/util/grpcutil/dialer_test.go b/internal/grpc/dialer/dialer_test.go similarity index 86% rename from coordinator/util/grpcutil/dialer_test.go rename to internal/grpc/dialer/dialer_test.go index 772ca7768..aa6dfab4a 100644 --- a/coordinator/util/grpcutil/dialer_test.go +++ b/internal/grpc/dialer/dialer_test.go @@ -1,16 +1,15 @@ -package grpcutil +package dialer import ( "context" "testing" "github.com/edgelesssys/constellation/coordinator/core" - "github.com/edgelesssys/constellation/coordinator/util/testdialer" - "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/test/grpc_testing" ) @@ -64,7 +63,7 @@ func TestDial(t *testing.T) { require := require.New(t) netDialer := testdialer.NewBufconnDialer() - dialer := NewDialer(&core.MockValidator{}, netDialer) + dialer := New(nil, &core.MockValidator{}, netDialer) server := newServer(tc.tls) api := &testAPI{} grpc_testing.RegisterTestServiceServer(server, api) @@ -88,11 +87,8 @@ func TestDial(t *testing.T) { func newServer(tls bool) *grpc.Server { if tls { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) - if err != nil { - panic(err) - } - return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(&core.MockIssuer{}, nil) + return grpc.NewServer(grpc.Creds(creds)) } return grpc.NewServer() } diff --git a/coordinator/util/testdialer/bufconndialer.go b/internal/grpc/testdialer/bufconndialer.go similarity index 100% rename from coordinator/util/testdialer/bufconndialer.go rename to internal/grpc/testdialer/bufconndialer.go diff --git a/internal/statuswaiter/statuswaiter.go b/internal/statuswaiter/statuswaiter.go index ffb0a2374..d07f94b64 100644 --- a/internal/statuswaiter/statuswaiter.go +++ b/internal/statuswaiter/statuswaiter.go @@ -9,9 +9,9 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "google.golang.org/grpc" grpccodes "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" grpcstatus "google.golang.org/grpc/status" ) @@ -113,13 +113,10 @@ func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...s // newAttestedConnGenerator creates a function returning a default attested grpc connection. func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) - if err != nil { - return nil, err - } + creds := atlscredentials.New(nil, validators) return grpc.DialContext( - ctx, target, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + ctx, target, grpc.WithTransportCredentials(creds), ) } } diff --git a/state/keyservice/keyservice.go b/state/keyservice/keyservice.go index a59d3abc4..6dbcbb7a4 100644 --- a/state/keyservice/keyservice.go +++ b/state/keyservice/keyservice.go @@ -2,7 +2,6 @@ package keyservice import ( "context" - "crypto/tls" "errors" "log" "net" @@ -12,7 +11,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/config" "github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" - "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/state/keyservice/keyproto" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -63,11 +62,8 @@ func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) { return nil, errors.New("received no disk UUID") } - tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer, nil) - if err != nil { - return nil, err - } - server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(a.issuer, nil) + server := grpc.NewServer(grpc.Creds(creds)) keyproto.RegisterAPIServer(server, a) listener, err := net.Listen("tcp", listenAddr) if err != nil { @@ -95,11 +91,7 @@ func (a *KeyAPI) ResetKey() { func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error { // we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator // if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start - tlsClientConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) - if err != nil { - return err - } - + creds := atlscredentials.New(nil, nil) // set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker firstReq := make(chan struct{}, 1) firstReq <- struct{}{} @@ -115,14 +107,14 @@ func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error { case <-a.keyReceived: return nil case <-ticker.C: - a.requestKey(uuid, tlsClientConfig, opts...) + a.requestKey(uuid, creds, opts...) case <-firstReq: - a.requestKey(uuid, tlsClientConfig, opts...) + a.requestKey(uuid, creds, opts...) } } } -func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...grpc.DialOption) { +func (a *KeyAPI) requestKey(uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) { // list available Coordinators endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata) @@ -131,7 +123,7 @@ func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...gr // any errors encountered here will be ignored, and the calls retried after a timeout for _, endpoint := range endpoints { ctx, cancel := context.WithTimeout(context.Background(), a.timeout) - conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig)))...) + conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials))...) if err == nil { client := pubproto.NewAPIClient(conn) _, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid}) diff --git a/state/keyservice/keyservice_test.go b/state/keyservice/keyservice_test.go index ddd206a6e..95336ba05 100644 --- a/state/keyservice/keyservice_test.go +++ b/state/keyservice/keyservice_test.go @@ -11,12 +11,11 @@ import ( "github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" - "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/test/bufconn" ) @@ -76,9 +75,8 @@ func TestRequestKeyLoop(t *testing.T) { listener := bufconn.Listen(1) defer listener.Close() - tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer(), nil) - require.NoError(err) - s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + creds := atlscredentials.New(core.NewMockIssuer(), nil) + s := grpc.NewServer(grpc.Creds(creds)) pubproto.RegisterAPIServer(s, tc.server) if !tc.dontStartServer { @@ -97,7 +95,7 @@ func TestRequestKeyLoop(t *testing.T) { keyReceived <- struct{}{} }() - err = keyWaiter.requestKeyLoop( + err := keyWaiter.requestKeyLoop( "1234", grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { return listener.DialContext(ctx) diff --git a/state/test/integration_test.go b/state/test/integration_test.go index ff232ac2f..4876ac487 100644 --- a/state/test/integration_test.go +++ b/state/test/integration_test.go @@ -12,7 +12,7 @@ import ( "time" "github.com/edgelesssys/constellation/coordinator/core" - "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/state/keyservice" "github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/mapper" @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) const ( @@ -91,9 +90,8 @@ func TestKeyAPI(t *testing.T) { // wait 2 seconds before sending the key time.Sleep(2 * time.Second) - clientCfg, err := atls.CreateAttestationClientTLSConfig(nil, nil) - require.NoError(err) - conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(credentials.NewTLS(clientCfg))) + creds := atlscredentials.New(nil, nil) + conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(creds)) require.NoError(err) defer conn.Close() diff --git a/test/coordinator_integration_test.go b/test/coordinator_integration_test.go index 323095784..a6fdea5f5 100644 --- a/test/coordinator_integration_test.go +++ b/test/coordinator_integration_test.go @@ -4,7 +4,6 @@ package integration import ( "context" - "crypto/tls" "errors" "fmt" "io" @@ -27,6 +26,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/storewrapper" "github.com/edgelesssys/constellation/internal/atls" + "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" kms "github.com/edgelesssys/constellation/kms/server/setup" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -249,12 +249,9 @@ func TestMain(t *testing.T) { // helper methods func startCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) - if err != nil { - return err - } + creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}}) - conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(creds)) if err != nil { return err } @@ -299,12 +296,9 @@ func createTempDir() error { } func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, oldCoordinatorAddr string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) - if err != nil { - return err - } + creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}}) - conn, err := grpc.DialContext(ctx, net.JoinHostPort(oldCoordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.DialContext(ctx, net.JoinHostPort(oldCoordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(creds)) if err != nil { return err } @@ -322,12 +316,9 @@ func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, old } func addNewNodesToCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) - if err != nil { - return err - } + creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}}) - conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(creds)) if err != nil { return err } @@ -533,11 +524,11 @@ func createNewNode(ctx context.Context, cli *client.Client) (*newNodeData, error return &newNodeData{resp, containerData.NetworkSettings.IPAddress}, nil } -func awaitPeerResponse(ctx context.Context, ip string, tlsConfig *tls.Config) error { +func awaitPeerResponse(ctx context.Context, ip string, credentials credentials.TransportCredentials) error { // Block, so the connection gets established/fails immediately ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - conn, err := grpc.DialContext(ctx, net.JoinHostPort(ip, publicgRPCPort), grpc.WithBlock(), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + conn, err := grpc.DialContext(ctx, net.JoinHostPort(ip, publicgRPCPort), grpc.WithBlock(), grpc.WithTransportCredentials(credentials)) if err != nil { return err } @@ -545,13 +536,10 @@ func awaitPeerResponse(ctx context.Context, ip string, tlsConfig *tls.Config) er } func blockUntilUp(ctx context.Context, peerIPs []string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) - if err != nil { - return err - } + creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}}) for _, ip := range peerIPs { // Block, so the connection gets established/fails immediately - if err := awaitPeerResponse(ctx, ip, tlsConfig); err != nil { + if err := awaitPeerResponse(ctx, ip, creds); err != nil { return err } }