Dynamic grpc client credentials (#204)

* Add an aTLS wrapper for grpc credentials

* Move grpc dialers to internal and use aTLS grpc credentials

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-06-13 11:40:27 +02:00 committed by GitHub
parent 6e9428a234
commit 1e19e64fbc
25 changed files with 291 additions and 189 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/spf13/afero" "github.com/spf13/afero"
"k8s.io/klog/v2" "k8s.io/klog/v2"
) )
@ -36,10 +37,7 @@ func main() {
klog.Exitf("failed to create validator: %s", err) klog.Exitf("failed to create validator: %s", err)
} }
tlsConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator}) creds := atlscredentials.New(nil, []atls.Validator{validator})
if err != nil {
klog.Exitf("unable to create server config: %s", err)
}
kubeadm, err := kubeadm.New() kubeadm, err := kubeadm.New()
if err != nil { if err != nil {
@ -62,7 +60,7 @@ func main() {
} }
}() }()
if err := server.Run(tlsConfig, bindPort); err != nil { if err := server.Run(creds, bindPort); err != nil {
klog.Exitf("failed to run server: %s", err) klog.Exitf("failed to run server: %s", err)
} }
} }

View File

@ -2,7 +2,6 @@ package server
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"time" "time"
@ -39,9 +38,9 @@ func New(fileHandler file.Handler, ca certificateAuthority, joinTokenGetter join
} }
// Run starts the gRPC server on the given port, using the provided tlsConfig. // Run starts the gRPC server on the given port, using the provided tlsConfig.
func (s *Server) Run(tlsConfig *tls.Config, port string) error { func (s *Server) Run(creds credentials.TransportCredentials, port string) error {
grpcServer := grpc.NewServer( grpcServer := grpc.NewServer(
grpc.Creds(credentials.NewTLS(tlsConfig)), grpc.Creds(creds),
grpc.UnaryInterceptor(logGRPC), grpc.UnaryInterceptor(logGRPC),
) )

View File

@ -8,10 +8,10 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
// Client wraps a PubAPI client and the connection to it. // Client wraps a PubAPI client and the connection to it.
@ -26,12 +26,9 @@ type Client struct {
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *Client) Connect(endpoint string, validators []atls.Validator) error { func (c *Client) Connect(endpoint string, validators []atls.Validator) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) creds := atlscredentials.New(nil, validators)
if err != nil {
return err
}
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,9 +5,9 @@ import (
"errors" "errors"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
// KeyClient wraps a KeyAPI client and the connection to it. // KeyClient wraps a KeyAPI client and the connection to it.
@ -22,12 +22,9 @@ type KeyClient struct {
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error { func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) creds := atlscredentials.New(nil, validators)
if err != nil {
return err
}
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
return err return err
} }

View File

@ -20,7 +20,6 @@ import (
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl" "github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl"
"github.com/edgelesssys/constellation/coordinator/logging" "github.com/edgelesssys/constellation/coordinator/logging"
"github.com/edgelesssys/constellation/coordinator/util" "github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/wireguard" "github.com/edgelesssys/constellation/coordinator/wireguard"
"github.com/edgelesssys/constellation/internal/attestation/azure" "github.com/edgelesssys/constellation/internal/attestation/azure"
"github.com/edgelesssys/constellation/internal/attestation/gcp" "github.com/edgelesssys/constellation/internal/attestation/gcp"
@ -28,6 +27,7 @@ import (
"github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/attestation/simulator"
"github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
"github.com/spf13/afero" "github.com/spf13/afero"
"go.uber.org/zap" "go.uber.org/zap"
@ -177,7 +177,7 @@ func main() {
fileHandler := file.NewHandler(fs) fileHandler := file.NewHandler(fs)
netDialer := &net.Dialer{} netDialer := &net.Dialer{}
dialer := grpcutil.NewDialer(validator, netDialer) dialer := dialer.New(nil, validator, netDialer)
run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube, run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
coreMetadata, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, coreMetadata, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP,
bindPort, zapLoggerCore, cloudLogger, fs) bindPort, zapLoggerCore, cloudLogger, fs)

View File

@ -12,25 +12,24 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/spf13/afero" "github.com/spf13/afero"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
var version = "0.0.0" var version = "0.0.0"
func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler, func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *dialer.Dialer, fileHandler file.Handler,
kube core.Cluster, metadata core.ProviderMetadata, disk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, logger *zap.Logger, kube core.Cluster, metadata core.ProviderMetadata, disk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, logger *zap.Logger,
cloudLogger logging.CloudLogger, fs afero.Fs, cloudLogger logging.CloudLogger, fs afero.Fs,
) { ) {
@ -40,10 +39,7 @@ func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicI
defer cloudLogger.Close() defer cloudLogger.Close()
cloudLogger.Disclose("Coordinator started running...") cloudLogger.Disclose("Coordinator started running...")
tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer, nil) creds := atlscredentials.New(issuer, nil)
if err != nil {
logger.Fatal("failed to create server TLS config", zap.Error(err))
}
etcdStoreFactory := store.NewEtcdStoreFactory(etcdEndpoint, etcdTLS, logger) etcdStoreFactory := store.NewEtcdStoreFactory(etcdEndpoint, etcdTLS, logger)
linuxUserManager := user.NewLinuxUserManager(fs) linuxUserManager := user.NewLinuxUserManager(fs)
@ -64,7 +60,7 @@ func run(issuer core.QuoteIssuer, vpn core.VPN, tpm vtpm.TPMOpenFunc, getPublicI
zapLoggergRPC := loggerPubAPI.Named("gRPC") zapLoggergRPC := loggerPubAPI.Named("gRPC")
grpcServer := grpc.NewServer( grpcServer := grpc.NewServer(
grpc.Creds(credentials.NewTLS(tlsConfig)), grpc.Creds(creds),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
grpc_ctxtags.StreamServerInterceptor(), grpc_ctxtags.StreamServerInterceptor(),
grpc_zap.StreamServerInterceptor(zapLoggergRPC), grpc_zap.StreamServerInterceptor(zapLoggergRPC),
@ -117,14 +113,11 @@ func tryJoinClusterOnStartup(getPublicIPAddr func() (string, error), metadata co
// We create an client unverified connection, since the node does not need to verify the Coordinator. // We create an client unverified connection, since the node does not need to verify the Coordinator.
// ActivateAdditionalNodes triggers the Coordinator to call ActivateAsNode. This rpc lets the Coordinator verify the node. // ActivateAdditionalNodes triggers the Coordinator to call ActivateAsNode. This rpc lets the Coordinator verify the node.
tlsClientConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) creds := atlscredentials.New(nil, nil)
if err != nil {
return fmt.Errorf("failed to create client TLS config: %w", err)
}
// try to notify a coordinator to activate this node // try to notify a coordinator to activate this node
for _, coordinatorEndpoint := range coordinatorEndpoints { for _, coordinatorEndpoint := range coordinatorEndpoints {
conn, err := grpc.Dial(coordinatorEndpoint, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig))) conn, err := grpc.Dial(coordinatorEndpoint, grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
logger.Info("Dial failed:", zap.String("endpoint", coordinatorEndpoint), zap.Error(err)) logger.Info("Dial failed:", zap.String("endpoint", coordinatorEndpoint), zap.Error(err))
continue continue

View File

@ -15,14 +15,15 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/attestation/simulator"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -31,7 +32,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -221,14 +221,13 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testd
getPublicAddr := func() (string, error) { getPublicAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer) dialer := dialer.New(nil, &core.MockValidator{}, netDialer)
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer} vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer}
papi := pubapi.New(logger, &logging.NopLogger{}, cor, dialer, vapiServer, getPublicAddr, nil) papi := pubapi.New(logger, &logging.NopLogger{}, cor, dialer, vapiServer, getPublicAddr, nil)
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) creds := atlscredentials.New(&core.MockIssuer{}, nil)
require.NoError(err) server := grpc.NewServer(grpc.Creds(creds))
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(server, papi) pubproto.RegisterAPIServer(server, papi)
listener := netDialer.GetListener(endpoint) listener := netDialer.GetListener(endpoint)
@ -264,16 +263,13 @@ func activateCoordinator(require *require.Assertions, dialer netDialer, coordina
} }
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target, return grpc.DialContext(ctx, target,
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr) return dialer.DialContext(ctx, "tcp", addr)
}), }),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithTransportCredentials(creds),
) )
} }

View File

@ -11,12 +11,12 @@ import (
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/attestation/simulator"
"github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -220,7 +220,7 @@ func TestInitialize(t *testing.T) {
// prepare store to emulate initialized KMS // prepare store to emulate initialized KMS
require.NoError(core.data().PutKMSData(kms.KMSInformation{StorageUri: kms.NoStoreURI, KmsUri: kms.ClusterKMSURI})) require.NoError(core.data().PutKMSData(kms.KMSInformation{StorageUri: kms.NoStoreURI, KmsUri: kms.ClusterKMSURI}))
require.NoError(core.data().PutMasterSecret([]byte("master-secret"))) require.NoError(core.data().PutMasterSecret([]byte("master-secret")))
dialer := grpcutil.NewDialer(&MockValidator{}, testdialer.NewBufconnDialer()) dialer := dialer.New(nil, &MockValidator{}, testdialer.NewBufconnDialer())
nodeActivated, err := core.Initialize(context.Background(), dialer, &stubPubAPI{}) nodeActivated, err := core.Initialize(context.Background(), dialer, &stubPubAPI{})
if tc.wantErr { if tc.wantErr {

View File

@ -13,20 +13,19 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation/simulator" "github.com/edgelesssys/constellation/internal/attestation/simulator"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
) )
@ -121,7 +120,7 @@ func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error)
return nil, nil, err return nil, nil, err
} }
dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer) dialer := dialer.New(nil, NewMockValidator(), bufDialer)
vpn := &stubVPN{} vpn := &stubVPN{}
kubeFake := &ClusterFake{} kubeFake := &ClusterFake{}
metadataFake := &ProviderMetadataFake{} metadataFake := &ProviderMetadataFake{}
@ -171,12 +170,9 @@ func (b *bufconnDialer) addListener(endpoint string, listener *bufconn.Listener)
} }
func spawnNode(endpoint string, testNodeCore *pubapi.API, bufDialer *bufconnDialer) (*grpc.Server, error) { func spawnNode(endpoint string, testNodeCore *pubapi.API, bufDialer *bufconnDialer) (*grpc.Server, error) {
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}, nil) creds := atlscredentials.New(&MockIssuer{}, nil)
if err != nil {
return nil, err
}
grpcServer := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) grpcServer := grpc.NewServer(grpc.Creds(creds))
pubproto.RegisterAPIServer(grpcServer, testNodeCore) pubproto.RegisterAPIServer(grpcServer, testNodeCore)
const bufferSize = 8 * 1024 const bufferSize = 8 * 1024

View File

@ -9,11 +9,11 @@ import (
"github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/peer"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -21,7 +21,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -74,7 +73,7 @@ func TestReinitializeAsNode(t *testing.T) {
coordinators := []cloudtypes.Instance{{PrivateIPs: []string{"192.0.2.1"}, Role: role.Coordinator}} coordinators := []cloudtypes.Instance{{PrivateIPs: []string{"192.0.2.1"}, Role: role.Coordinator}}
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(&MockValidator{}, netDialer) dialer := dialer.New(nil, &MockValidator{}, netDialer)
server := newPubAPIServer() server := newPubAPIServer()
api := &pubAPIServerStub{responses: tc.getInitialVPNPeersResponses} api := &pubAPIServerStub{responses: tc.getInitialVPNPeersResponses}
pubproto.RegisterAPIServer(server, api) pubproto.RegisterAPIServer(server, api)
@ -147,7 +146,7 @@ func TestReinitializeAsCoordinator(t *testing.T) {
coordinators := []cloudtypes.Instance{{PrivateIPs: []string{"192.0.2.1"}, Role: role.Coordinator}} coordinators := []cloudtypes.Instance{{PrivateIPs: []string{"192.0.2.1"}, Role: role.Coordinator}}
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(&MockValidator{}, netDialer) dialer := dialer.New(nil, &MockValidator{}, netDialer)
server := newPubAPIServer() server := newPubAPIServer()
api := &pubAPIServerStub{responses: tc.getInitialVPNPeersResponses} api := &pubAPIServerStub{responses: tc.getInitialVPNPeersResponses}
pubproto.RegisterAPIServer(server, api) pubproto.RegisterAPIServer(server, api)
@ -235,7 +234,7 @@ func TestGetInitialVPNPeers(t *testing.T) {
zapLogger, err := zap.NewDevelopment() zapLogger, err := zap.NewDevelopment()
require.NoError(err) require.NoError(err)
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(&MockValidator{}, netDialer) dialer := dialer.New(nil, &MockValidator{}, netDialer)
server := newPubAPIServer() server := newPubAPIServer()
api := &pubAPIServerStub{ api := &pubAPIServerStub{
responses: []struct { responses: []struct {
@ -259,11 +258,9 @@ func TestGetInitialVPNPeers(t *testing.T) {
} }
func newPubAPIServer() *grpc.Server { func newPubAPIServer() *grpc.Server {
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}, nil) creds := atlscredentials.New(&MockIssuer{}, nil)
if err != nil {
panic(err) return grpc.NewServer(grpc.Creds(creds))
}
return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
} }
type pubAPIServerStub struct { type pubAPIServerStub struct {

View File

@ -16,12 +16,12 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/oid" "github.com/edgelesssys/constellation/internal/oid"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
@ -30,7 +30,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
grpcpeer "google.golang.org/grpc/peer" grpcpeer "google.golang.org/grpc/peer"
) )
@ -150,7 +149,7 @@ func TestActivateAsCoordinator(t *testing.T) {
} }
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
@ -302,7 +301,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
core := &fakeCore{state: tc.state} core := &fakeCore{state: tc.state}
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
@ -432,11 +431,8 @@ func (n *stubPeer) GetPeerVPNPublicKey(ctx context.Context, in *pubproto.GetPeer
} }
func (n *stubPeer) newServer() *grpc.Server { func (n *stubPeer) newServer() *grpc.Server {
tlsConfig, err := atls.CreateAttestationServerTLSConfig(fakeIssuer{}, nil) creds := atlscredentials.New(fakeIssuer{}, nil)
if err != nil { server := grpc.NewServer(grpc.Creds(creds))
panic(err)
}
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(server, n) pubproto.RegisterAPIServer(server, n)
return server return server
} }
@ -537,9 +533,8 @@ func TestRequestStateDiskKey(t *testing.T) {
require.NoError(err) require.NoError(err)
defer listener.Close() defer listener.Close()
tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer, nil) creds := atlscredentials.New(issuer, nil)
require.NoError(err) s := grpc.NewServer(grpc.Creds(creds))
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
keyproto.RegisterAPIServer(s, stateDiskServer) keyproto.RegisterAPIServer(s, stateDiskServer)
defer s.GracefulStop() defer s.GracefulStop()
go s.Serve(listener) go s.Serve(listener)
@ -559,7 +554,7 @@ func TestRequestStateDiskKey(t *testing.T) {
getDataKeyErr: tc.getDataKeyErr, getDataKeyErr: tc.getDataKeyErr,
} }
api := New(zaptest.NewLogger(t), &logging.NopLogger{}, core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext) api := New(zaptest.NewLogger(t), &logging.NopLogger{}, core, dialer.New(nil, dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext)
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{}) _, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
if tc.wantErr { if tc.wantErr {

View File

@ -11,8 +11,8 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
@ -95,7 +95,7 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) {
clusterID: []byte("clusterID"), clusterID: []byte("clusterID"),
} }
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
@ -166,7 +166,7 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
state: tc.state, state: tc.state,
peers: tc.peers, peers: tc.peers,
} }
dialer := grpcutil.NewDialer(fakeValidator{}, nil) dialer := dialer.New(nil, fakeValidator{}, nil)
api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil) api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil)
@ -240,7 +240,7 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
clusterID: []byte("clusterID"), clusterID: []byte("clusterID"),
} }
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
@ -297,7 +297,7 @@ func TestGetPeerVPNPublicKey(t *testing.T) {
vpnPubKey: tc.coordinator.peer.VPNPubKey, vpnPubKey: tc.coordinator.peer.VPNPubKey,
getvpnPubKeyErr: tc.getVPNPubKeyErr, getvpnPubKeyErr: tc.getVPNPubKeyErr,
} }
dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer()) dialer := dialer.New(nil, fakeValidator{}, testdialer.NewBufconnDialer())
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil

View File

@ -14,18 +14,18 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/deploy/user" "github.com/edgelesssys/constellation/internal/deploy/user"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
) )
@ -152,7 +152,7 @@ func TestActivateAsNode(t *testing.T) {
linuxUserManager := user.NewLinuxUserManagerFake(fs) linuxUserManager := user.NewLinuxUserManagerFake(fs)
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr, linuxUserManager: linuxUserManager} cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr, linuxUserManager: linuxUserManager}
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
api := New(logger, &logging.NopLogger{}, cor, dialer, nil, nil, nil) api := New(logger, &logging.NopLogger{}, cor, dialer, nil, nil, nil)
defer api.Close() defer api.Close()
@ -163,9 +163,8 @@ func TestActivateAsNode(t *testing.T) {
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
defer vserver.GracefulStop() defer vserver.GracefulStop()
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) creds := atlscredentials.New(&core.MockIssuer{}, nil)
require.NoError(err) pubserver := grpc.NewServer(grpc.Creds(creds))
pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(pubserver, api) pubproto.RegisterAPIServer(pubserver, api)
go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort))) go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
defer pubserver.GracefulStop() defer pubserver.GracefulStop()
@ -260,7 +259,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
logger := zaptest.NewLogger(t) logger := zaptest.NewLogger(t)
core := &fakeCore{state: tc.state} core := &fakeCore{state: tc.state}
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil) api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil)
@ -336,7 +335,7 @@ func TestJoinCluster(t *testing.T) {
logger := zaptest.NewLogger(t) logger := zaptest.NewLogger(t)
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr} core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) dialer := dialer.New(nil, fakeValidator{}, netDialer)
api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil) api := New(logger, &logging.NopLogger{}, core, dialer, nil, nil, nil)
@ -433,16 +432,13 @@ func activateNode(require *require.Assertions, dialer netDialer, messageSequence
} }
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target, return grpc.DialContext(ctx, target,
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr) return dialer.DialContext(ctx, "tcp", addr)
}), }),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithTransportCredentials(creds),
) )
} }

View File

@ -9,10 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/coordinator"
"github.com/edgelesssys/constellation/debugd/debugd" "github.com/edgelesssys/constellation/debugd/debugd"
pb "github.com/edgelesssys/constellation/debugd/service" pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"

View File

@ -8,11 +8,11 @@ import (
"net" "net"
"testing" "testing"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/coordinator"
"github.com/edgelesssys/constellation/debugd/debugd/deploy" "github.com/edgelesssys/constellation/debugd/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service" pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"

View File

@ -0,0 +1,53 @@
package atlscredentials
import (
"context"
"errors"
"net"
"github.com/edgelesssys/constellation/internal/atls"
"google.golang.org/grpc/credentials"
)
type Credentials struct {
issuer atls.Issuer
validators []atls.Validator
}
func New(issuer atls.Issuer, validators []atls.Validator) *Credentials {
return &Credentials{
issuer: issuer,
validators: validators,
}
}
func (c *Credentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
clientCfg, err := atls.CreateAttestationClientTLSConfig(c.issuer, c.validators)
if err != nil {
return nil, nil, err
}
return credentials.NewTLS(clientCfg).ClientHandshake(ctx, authority, rawConn)
}
func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, c.validators)
if err != nil {
return nil, nil, err
}
return credentials.NewTLS(serverCfg).ServerHandshake(rawConn)
}
func (c *Credentials) Info() credentials.ProtocolInfo {
return credentials.NewTLS(nil).Info()
}
func (c *Credentials) Clone() credentials.TransportCredentials {
cloned := *c
return &cloned
}
func (c *Credentials) OverrideServerName(s string) error {
return errors.New("cannot override server name")
}

View File

@ -0,0 +1,120 @@
package atlscredentials
import (
"bytes"
"context"
"encoding/asn1"
"encoding/json"
"errors"
"net"
"testing"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestATLSCredentials(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
oid := fakeOID{1, 3, 9900, 1}
//
// Create servers
//
serverCreds := New(fakeIssuer{fakeOID: oid}, nil)
const serverCount = 15
var listeners []*bufconn.Listener
for i := 0; i < serverCount; i++ {
api := &fakeAPI{}
server := grpc.NewServer(grpc.Creds(serverCreds))
pubproto.RegisterAPIServer(server, api)
listener := bufconn.Listen(1024)
listeners = append(listeners, listener)
defer server.GracefulStop()
go server.Serve(listener)
}
//
// Dial concurrently
//
clientCreds := New(nil, []atls.Validator{fakeValidator{fakeOID: oid}})
errChan := make(chan error, serverCount)
for _, listener := range listeners {
lis := listener
go func() {
var err error
defer func() { errChan <- err }()
conn, err := grpc.DialContext(context.Background(), "", grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return lis.Dial()
}), grpc.WithTransportCredentials(clientCreds))
require.NoError(err)
defer conn.Close()
client := pubproto.NewAPIClient(conn)
_, err = client.GetState(context.Background(), &pubproto.GetStateRequest{})
}()
}
for i := 0; i < serverCount; i++ {
assert.NoError(<-errChan)
}
}
type fakeIssuer struct {
fakeOID
}
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
}
type fakeValidator struct {
fakeOID
err error
}
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var doc fakeDoc
if err := json.Unmarshal(attDoc, &doc); err != nil {
return nil, err
}
if !bytes.Equal(doc.Nonce, nonce) {
return nil, errors.New("invalid nonce")
}
return doc.UserData, v.err
}
type fakeOID asn1.ObjectIdentifier
func (o fakeOID) OID() asn1.ObjectIdentifier {
return asn1.ObjectIdentifier(o)
}
type fakeDoc struct {
UserData []byte
Nonce []byte
}
type fakeAPI struct {
pubproto.UnimplementedAPIServer
}
func (f *fakeAPI) GetState(ctx context.Context, in *pubproto.GetStateRequest) (*pubproto.GetStateResponse, error) {
return &pubproto.GetStateResponse{State: 1}, nil
}

View File

@ -1,24 +1,26 @@
package grpcutil package dialer
import ( import (
"context" "context"
"net" "net"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
) )
// Dialer can open grpc client connections with different levels of ATLS encryption / verification. // Dialer can open grpc client connections with different levels of ATLS encryption / verification.
type Dialer struct { type Dialer struct {
issuer atls.Issuer
validator atls.Validator validator atls.Validator
netDialer NetDialer netDialer NetDialer
} }
// NewDialer creates a new Dialer. // New creates a new Dialer.
func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer { func New(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer) *Dialer {
return &Dialer{ return &Dialer{
issuer: issuer,
validator: validator, validator: validator,
netDialer: netDialer, netDialer: netDialer,
} }
@ -26,14 +28,11 @@ func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer {
// Dial creates a new grpc client connection to the given target using the atls validator. // Dial creates a new grpc client connection to the given target using the atls validator.
func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) { func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{d.validator}) credentials := atlscredentials.New(d.issuer, []atls.Validator{d.validator})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target, return grpc.DialContext(ctx, target,
d.grpcWithDialer(), d.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithTransportCredentials(credentials),
) )
} }
@ -48,14 +47,11 @@ func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientC
// DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation. // DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation.
func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) { func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) credentials := atlscredentials.New(nil, nil)
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target, return grpc.DialContext(ctx, target,
d.grpcWithDialer(), d.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithTransportCredentials(credentials),
) )
} }

View File

@ -1,16 +1,15 @@
package grpcutil package dialer
import ( import (
"context" "context"
"testing" "testing"
"github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/test/grpc_testing"
) )
@ -64,7 +63,7 @@ func TestDial(t *testing.T) {
require := require.New(t) require := require.New(t)
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := NewDialer(&core.MockValidator{}, netDialer) dialer := New(nil, &core.MockValidator{}, netDialer)
server := newServer(tc.tls) server := newServer(tc.tls)
api := &testAPI{} api := &testAPI{}
grpc_testing.RegisterTestServiceServer(server, api) grpc_testing.RegisterTestServiceServer(server, api)
@ -88,11 +87,8 @@ func TestDial(t *testing.T) {
func newServer(tls bool) *grpc.Server { func newServer(tls bool) *grpc.Server {
if tls { if tls {
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) creds := atlscredentials.New(&core.MockIssuer{}, nil)
if err != nil { return grpc.NewServer(grpc.Creds(creds))
panic(err)
}
return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
} }
return grpc.NewServer() return grpc.NewServer()
} }

View File

@ -9,9 +9,9 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"google.golang.org/grpc" "google.golang.org/grpc"
grpccodes "google.golang.org/grpc/codes" grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
grpcstatus "google.golang.org/grpc/status" grpcstatus "google.golang.org/grpc/status"
) )
@ -113,13 +113,10 @@ func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...s
// newAttestedConnGenerator creates a function returning a default attested grpc connection. // newAttestedConnGenerator creates a function returning a default attested grpc connection.
func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) creds := atlscredentials.New(nil, validators)
if err != nil {
return nil, err
}
return grpc.DialContext( return grpc.DialContext(
ctx, target, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), ctx, target, grpc.WithTransportCredentials(creds),
) )
} }
} }

View File

@ -2,7 +2,6 @@ package keyservice
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"log" "log"
"net" "net"
@ -12,7 +11,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/config" "github.com/edgelesssys/constellation/coordinator/config"
"github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -63,11 +62,8 @@ func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) {
return nil, errors.New("received no disk UUID") return nil, errors.New("received no disk UUID")
} }
tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer, nil) creds := atlscredentials.New(a.issuer, nil)
if err != nil { server := grpc.NewServer(grpc.Creds(creds))
return nil, err
}
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
keyproto.RegisterAPIServer(server, a) keyproto.RegisterAPIServer(server, a)
listener, err := net.Listen("tcp", listenAddr) listener, err := net.Listen("tcp", listenAddr)
if err != nil { if err != nil {
@ -95,11 +91,7 @@ func (a *KeyAPI) ResetKey() {
func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error { func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error {
// we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator // we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start // if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
tlsClientConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) creds := atlscredentials.New(nil, nil)
if err != nil {
return err
}
// set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker // set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker
firstReq := make(chan struct{}, 1) firstReq := make(chan struct{}, 1)
firstReq <- struct{}{} firstReq <- struct{}{}
@ -115,14 +107,14 @@ func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error {
case <-a.keyReceived: case <-a.keyReceived:
return nil return nil
case <-ticker.C: case <-ticker.C:
a.requestKey(uuid, tlsClientConfig, opts...) a.requestKey(uuid, creds, opts...)
case <-firstReq: case <-firstReq:
a.requestKey(uuid, tlsClientConfig, opts...) a.requestKey(uuid, creds, opts...)
} }
} }
} }
func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...grpc.DialOption) { func (a *KeyAPI) requestKey(uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) {
// list available Coordinators // list available Coordinators
endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata) endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata)
@ -131,7 +123,7 @@ func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...gr
// any errors encountered here will be ignored, and the calls retried after a timeout // any errors encountered here will be ignored, and the calls retried after a timeout
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout) ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig)))...) conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials))...)
if err == nil { if err == nil {
client := pubproto.NewAPIClient(conn) client := pubproto.NewAPIClient(conn)
_, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid}) _, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid})

View File

@ -11,12 +11,11 @@ import (
"github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
) )
@ -76,9 +75,8 @@ func TestRequestKeyLoop(t *testing.T) {
listener := bufconn.Listen(1) listener := bufconn.Listen(1)
defer listener.Close() defer listener.Close()
tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer(), nil) creds := atlscredentials.New(core.NewMockIssuer(), nil)
require.NoError(err) s := grpc.NewServer(grpc.Creds(creds))
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(s, tc.server) pubproto.RegisterAPIServer(s, tc.server)
if !tc.dontStartServer { if !tc.dontStartServer {
@ -97,7 +95,7 @@ func TestRequestKeyLoop(t *testing.T) {
keyReceived <- struct{}{} keyReceived <- struct{}{}
}() }()
err = keyWaiter.requestKeyLoop( err := keyWaiter.requestKeyLoop(
"1234", "1234",
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.DialContext(ctx) return listener.DialContext(ctx)

View File

@ -12,7 +12,7 @@ import (
"time" "time"
"github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/state/keyservice" "github.com/edgelesssys/constellation/state/keyservice"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/edgelesssys/constellation/state/mapper" "github.com/edgelesssys/constellation/state/mapper"
@ -20,7 +20,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
const ( const (
@ -91,9 +90,8 @@ func TestKeyAPI(t *testing.T) {
// wait 2 seconds before sending the key // wait 2 seconds before sending the key
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
clientCfg, err := atls.CreateAttestationClientTLSConfig(nil, nil) creds := atlscredentials.New(nil, nil)
require.NoError(err) conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(creds))
conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(credentials.NewTLS(clientCfg)))
require.NoError(err) require.NoError(err)
defer conn.Close() defer conn.Close()

View File

@ -4,7 +4,6 @@ package integration
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -27,6 +26,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/storewrapper" "github.com/edgelesssys/constellation/coordinator/storewrapper"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
kms "github.com/edgelesssys/constellation/kms/server/setup" kms "github.com/edgelesssys/constellation/kms/server/setup"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -249,12 +249,9 @@ func TestMain(t *testing.T) {
// helper methods // helper methods
func startCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error { func startCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}})
if err != nil {
return err
}
conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
return err return err
} }
@ -299,12 +296,9 @@ func createTempDir() error {
} }
func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, oldCoordinatorAddr string) error { func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, oldCoordinatorAddr string) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}})
if err != nil {
return err
}
conn, err := grpc.DialContext(ctx, net.JoinHostPort(oldCoordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.DialContext(ctx, net.JoinHostPort(oldCoordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
return err return err
} }
@ -322,12 +316,9 @@ func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, old
} }
func addNewNodesToCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error { func addNewNodesToCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}})
if err != nil {
return err
}
conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.DialContext(ctx, net.JoinHostPort(coordinatorAddr, publicgRPCPort), grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {
return err return err
} }
@ -533,11 +524,11 @@ func createNewNode(ctx context.Context, cli *client.Client) (*newNodeData, error
return &newNodeData{resp, containerData.NetworkSettings.IPAddress}, nil return &newNodeData{resp, containerData.NetworkSettings.IPAddress}, nil
} }
func awaitPeerResponse(ctx context.Context, ip string, tlsConfig *tls.Config) error { func awaitPeerResponse(ctx context.Context, ip string, credentials credentials.TransportCredentials) error {
// Block, so the connection gets established/fails immediately // Block, so the connection gets established/fails immediately
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
conn, err := grpc.DialContext(ctx, net.JoinHostPort(ip, publicgRPCPort), grpc.WithBlock(), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) conn, err := grpc.DialContext(ctx, net.JoinHostPort(ip, publicgRPCPort), grpc.WithBlock(), grpc.WithTransportCredentials(credentials))
if err != nil { if err != nil {
return err return err
} }
@ -545,13 +536,10 @@ func awaitPeerResponse(ctx context.Context, ip string, tlsConfig *tls.Config) er
} }
func blockUntilUp(ctx context.Context, peerIPs []string) error { func blockUntilUp(ctx context.Context, peerIPs []string) error {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) creds := atlscredentials.New(nil, []atls.Validator{&core.MockValidator{}})
if err != nil {
return err
}
for _, ip := range peerIPs { for _, ip := range peerIPs {
// Block, so the connection gets established/fails immediately // Block, so the connection gets established/fails immediately
if err := awaitPeerResponse(ctx, ip, tlsConfig); err != nil { if err := awaitPeerResponse(ctx, ip, creds); err != nil {
return err return err
} }
} }