diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e1131814..5933f057d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Nodes add themselves to the cluster after `constellation init` is done +- Owner ID and Unique ID merged into a single value: Cluster ID + ### Deprecated ### Removed diff --git a/bootstrapper/cmd/bootstrapper/test.go b/bootstrapper/cmd/bootstrapper/test.go index 04225601f..cd13e4786 100644 --- a/bootstrapper/cmd/bootstrapper/test.go +++ b/bootstrapper/cmd/bootstrapper/test.go @@ -5,7 +5,6 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes" "github.com/edgelesssys/constellation/bootstrapper/role" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" "github.com/edgelesssys/constellation/internal/cloud/metadata" "github.com/edgelesssys/constellation/internal/logger" kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" @@ -15,7 +14,7 @@ import ( type clusterFake struct{} // InitCluster fakes bootstrapping a new cluster with the current node being the master, returning the arguments required to join the cluster. -func (c *clusterFake) InitCluster(context.Context, []string, string, string, attestationtypes.ID, kubernetes.KMSConfig, map[string]string, *logger.Logger, +func (c *clusterFake) InitCluster(context.Context, []string, string, string, []byte, kubernetes.KMSConfig, map[string]string, *logger.Logger, ) ([]byte, error) { return []byte{}, nil } diff --git a/bootstrapper/internal/initserver/initserver.go b/bootstrapper/internal/initserver/initserver.go index 51e0d62e9..dfb0718e5 100644 --- a/bootstrapper/internal/initserver/initserver.go +++ b/bootstrapper/internal/initserver/initserver.go @@ -11,10 +11,9 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes" "github.com/edgelesssys/constellation/bootstrapper/nodestate" "github.com/edgelesssys/constellation/bootstrapper/role" - "github.com/edgelesssys/constellation/bootstrapper/util" "github.com/edgelesssys/constellation/internal/atls" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" - "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/attestation" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/grpc/grpclog" @@ -79,12 +78,13 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(ctx))) log.Infof("Init called") - id, err := s.deriveAttestationID(req.MasterSecret) + // generate values for cluster attestation + measurementSalt, clusterID, err := deriveMeasurementValues(req.MasterSecret) if err != nil { - return nil, status.Errorf(codes.Internal, "%s", err) + return nil, status.Errorf(codes.Internal, "deriving measurement values: %s", err) } - nodeLockAcquired, err := s.nodeLock.TryLockOnce(id.Owner, id.Cluster) + nodeLockAcquired, err := s.nodeLock.TryLockOnce(clusterID) if err != nil { return nil, status.Errorf(codes.Internal, "locking node: %s", err) } @@ -103,9 +103,8 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro } state := nodestate.NodeState{ - Role: role.ControlPlane, - OwnerID: id.Owner, - ClusterID: id.Cluster, + Role: role.ControlPlane, + MeasurementSalt: measurementSalt, } if err := state.ToFile(s.fileHandler); err != nil { return nil, status.Errorf(codes.Internal, "persisting node state: %s", err) @@ -115,7 +114,7 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro req.AutoscalingNodeGroups, req.CloudServiceAccountUri, req.KubernetesVersion, - id, + measurementSalt, kubernetes.KMSConfig{ MasterSecret: req.MasterSecret, KMSURI: req.KmsUri, @@ -133,8 +132,7 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro log.Infof("Init succeeded") return &initproto.InitResponse{ Kubeconfig: kubeconfig, - OwnerId: id.Owner, - ClusterId: id.Cluster, + ClusterId: clusterID, }, nil } @@ -156,7 +154,7 @@ func (s *Server) setupDisk(masterSecret []byte) error { uuid = strings.ToLower(uuid) // TODO: Choose a way to salt the key derivation - diskKey, err := util.DeriveKey(masterSecret, []byte("Constellation"), []byte("key"+uuid), 32) + diskKey, err := crypto.DeriveKey(masterSecret, []byte("Constellation"), []byte(crypto.HKDFInfoPrefix+uuid), 32) if err != nil { return err } @@ -164,21 +162,6 @@ func (s *Server) setupDisk(masterSecret []byte) error { return s.disk.UpdatePassphrase(string(diskKey)) } -func (s *Server) deriveAttestationID(masterSecret []byte) (attestationtypes.ID, error) { - clusterID, err := util.GenerateRandomBytes(constants.RNGLengthDefault) - if err != nil { - return attestationtypes.ID{}, err - } - - // TODO: Choose a way to salt the key derivation - ownerID, err := util.DeriveKey(masterSecret, []byte("Constellation"), []byte("id"), constants.RNGLengthDefault) - if err != nil { - return attestationtypes.ID{}, err - } - - return attestationtypes.ID{Owner: ownerID, Cluster: clusterID}, nil -} - func sshProtoKeysToMap(keys []*initproto.SSHUserKey) map[string]string { keyMap := make(map[string]string) for _, key := range keys { @@ -187,6 +170,23 @@ func sshProtoKeysToMap(keys []*initproto.SSHUserKey) map[string]string { return keyMap } +func deriveMeasurementValues(masterSecret []byte) (salt, clusterID []byte, err error) { + salt, err = crypto.GenerateRandomBytes(crypto.RNGLengthDefault) + if err != nil { + return nil, nil, err + } + secret, err := attestation.DeriveMeasurementSecret(masterSecret) + if err != nil { + return nil, nil, err + } + clusterID, err = attestation.DeriveClusterID(salt, secret) + if err != nil { + return nil, nil, err + } + + return salt, clusterID, nil +} + // ClusterInitializer has the ability to initialize a cluster. type ClusterInitializer interface { // InitCluster initializes a new Kubernetes cluster. @@ -195,7 +195,7 @@ type ClusterInitializer interface { autoscalingNodeGroups []string, cloudServiceAccountURI string, k8sVersion string, - id attestationtypes.ID, + measurementSalt []byte, kmsConfig kubernetes.KMSConfig, sshUserKeys map[string]string, log *logger.Logger, @@ -223,7 +223,7 @@ type serveStopper interface { type locker interface { // TryLockOnce tries to lock the node. If the node is already locked, it // returns false. If the node is unlocked, it locks it and returns true. - TryLockOnce(ownerID, clusterID []byte) (bool, error) + TryLockOnce(clusterID []byte) (bool, error) } type cleaner interface { diff --git a/bootstrapper/internal/initserver/initserver_test.go b/bootstrapper/internal/initserver/initserver_test.go index c56952d78..9a667b5fe 100644 --- a/bootstrapper/internal/initserver/initserver_test.go +++ b/bootstrapper/internal/initserver/initserver_test.go @@ -10,7 +10,6 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/initproto" "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/logger" "github.com/spf13/afero" @@ -40,7 +39,7 @@ func TestNew(t *testing.T) { func TestInit(t *testing.T) { someErr := errors.New("failed") lockedLock := newFakeLock() - aqcuiredLock, lockErr := lockedLock.TryLockOnce(nil, nil) + aqcuiredLock, lockErr := lockedLock.TryLockOnce(nil) require.True(t, aqcuiredLock) require.Nil(t, lockErr) @@ -144,7 +143,7 @@ func TestInit(t *testing.T) { assert.NoError(err) assert.NotNil(kubeconfig) - assert.False(server.nodeLock.TryLockOnce(nil, nil)) // lock should be locked + assert.False(server.nodeLock.TryLockOnce(nil)) // lock should be locked }) } } @@ -219,7 +218,7 @@ type stubClusterInitializer struct { initClusterErr error } -func (i *stubClusterInitializer) InitCluster(context.Context, []string, string, string, attestationtypes.ID, kubernetes.KMSConfig, map[string]string, *logger.Logger, +func (i *stubClusterInitializer) InitCluster(context.Context, []string, string, string, []byte, kubernetes.KMSConfig, map[string]string, *logger.Logger, ) ([]byte, error) { return i.initClusterKubeconfig, i.initClusterErr } @@ -250,7 +249,7 @@ func newFakeLock() *fakeLock { } } -func (l *fakeLock) TryLockOnce(_, _ []byte) (bool, error) { +func (l *fakeLock) TryLockOnce(_ []byte) (bool, error) { return l.state.TryLock(), nil } diff --git a/bootstrapper/internal/joinclient/client.go b/bootstrapper/internal/joinclient/client.go index 144eac252..c93d59b79 100644 --- a/bootstrapper/internal/joinclient/client.go +++ b/bootstrapper/internal/joinclient/client.go @@ -14,6 +14,7 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/internal/kubelet" "github.com/edgelesssys/constellation/bootstrapper/nodestate" "github.com/edgelesssys/constellation/bootstrapper/role" + "github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/cloud/metadata" "github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/file" @@ -230,7 +231,12 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, } }() - nodeLockAcquired, err := c.nodeLock.TryLockOnce(ticket.OwnerId, ticket.ClusterId) + clusterID, err := attestation.DeriveClusterID(ticket.MeasurementSalt, ticket.MeasurementSecret) + if err != nil { + return err + } + + nodeLockAcquired, err := c.nodeLock.TryLockOnce(clusterID) if err != nil { c.log.With(zap.Error(err)).Errorf("Acquiring node lock failed") return fmt.Errorf("acquiring node lock: %w", err) @@ -259,9 +265,8 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, } state := nodestate.NodeState{ - Role: c.role, - OwnerID: ticket.OwnerId, - ClusterID: ticket.ClusterId, + Role: c.role, + MeasurementSalt: ticket.MeasurementSalt, } if err := state.ToFile(c.fileHandler); err != nil { return fmt.Errorf("persisting node state: %w", err) @@ -417,5 +422,5 @@ type cleaner interface { type locker interface { // TryLockOnce tries to lock the node. If the node is already locked, it // returns false. If the node is unlocked, it locks it and returns true. - TryLockOnce(ownerID, clusterID []byte) (bool, error) + TryLockOnce(clusterID []byte) (bool, error) } diff --git a/bootstrapper/internal/joinclient/client_test.go b/bootstrapper/internal/joinclient/client_test.go index 01df71224..71ad8d0d7 100644 --- a/bootstrapper/internal/joinclient/client_test.go +++ b/bootstrapper/internal/joinclient/client_test.go @@ -34,7 +34,7 @@ func TestMain(m *testing.M) { func TestClient(t *testing.T) { someErr := errors.New("failed") lockedLock := newFakeLock() - aqcuiredLock, lockErr := lockedLock.TryLockOnce(nil, nil) + aqcuiredLock, lockErr := lockedLock.TryLockOnce(nil) require.True(t, aqcuiredLock) require.Nil(t, lockErr) workerSelf := metadata.InstanceMetadata{Role: role.Worker, Name: "node-1"} @@ -246,9 +246,9 @@ func TestClient(t *testing.T) { assert.False(tc.clusterJoiner.joinClusterCalled) } if tc.wantLock { - assert.False(client.nodeLock.TryLockOnce(nil, nil)) // lock should be locked + assert.False(client.nodeLock.TryLockOnce(nil)) // lock should be locked } else { - assert.True(client.nodeLock.TryLockOnce(nil, nil)) + assert.True(client.nodeLock.TryLockOnce(nil)) } }) } @@ -430,6 +430,6 @@ func newFakeLock() *fakeLock { } } -func (l *fakeLock) TryLockOnce(_, _ []byte) (bool, error) { +func (l *fakeLock) TryLockOnce(_ []byte) (bool, error) { return l.state.TryLock(), nil } diff --git a/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice.go b/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice.go index 67d78602c..194100856 100644 --- a/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice.go +++ b/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice.go @@ -23,7 +23,7 @@ type joinServiceDaemonset struct { } // NewJoinServiceDaemonset returns a daemonset for the join service. -func NewJoinServiceDaemonset(csp string, measurementsJSON, idJSON string) *joinServiceDaemonset { +func NewJoinServiceDaemonset(csp, measurementsJSON string, measurementSalt []byte) *joinServiceDaemonset { return &joinServiceDaemonset{ ClusterRole: rbac.ClusterRole{ TypeMeta: meta.TypeMeta{ @@ -246,8 +246,10 @@ func NewJoinServiceDaemonset(csp string, measurementsJSON, idJSON string) *joinS Namespace: "kube-system", }, Data: map[string]string{ - "measurements": measurementsJSON, - "id": idJSON, + constants.MeasurementsFilename: measurementsJSON, + }, + BinaryData: map[string][]byte{ + constants.MeasurementSaltFilename: measurementSalt, }, }, } diff --git a/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice_test.go b/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice_test.go index 393f60b8b..26bfc3e0f 100644 --- a/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice_test.go +++ b/bootstrapper/internal/kubernetes/k8sapi/resources/joinservice_test.go @@ -8,7 +8,7 @@ import ( ) func TestNewJoinServiceDaemonset(t *testing.T) { - deployment := NewJoinServiceDaemonset("csp", "measurementsJSON", "idJSON") + deployment := NewJoinServiceDaemonset("csp", "measurementsJSON", []byte{0x0, 0x1, 0x2}) deploymentYAML, err := deployment.Marshal() require.NoError(t, err) diff --git a/bootstrapper/internal/kubernetes/k8sapi/resources/kms.go b/bootstrapper/internal/kubernetes/k8sapi/resources/kms.go index a9f20e477..808fa5e2f 100644 --- a/bootstrapper/internal/kubernetes/k8sapi/resources/kms.go +++ b/bootstrapper/internal/kubernetes/k8sapi/resources/kms.go @@ -15,8 +15,7 @@ import ( type kmsDeployment struct { ServiceAccount k8s.ServiceAccount - ServiceInternal k8s.Service - ServiceExternal k8s.Service + Service k8s.Service ClusterRole rbac.ClusterRole ClusterRoleBinding rbac.ClusterRoleBinding Deployment apps.Deployment @@ -37,7 +36,7 @@ func NewKMSDeployment(csp string, masterSecret []byte) *kmsDeployment { Namespace: "kube-system", }, }, - ServiceInternal: k8s.Service{ + Service: k8s.Service{ TypeMeta: meta.TypeMeta{ APIVersion: "v1", Kind: "Service", @@ -61,31 +60,6 @@ func NewKMSDeployment(csp string, masterSecret []byte) *kmsDeployment { }, }, }, - ServiceExternal: k8s.Service{ - TypeMeta: meta.TypeMeta{ - APIVersion: "v1", - Kind: "Service", - }, - ObjectMeta: meta.ObjectMeta{ - Name: "kms-external", - Namespace: "kube-system", - }, - Spec: k8s.ServiceSpec{ - Type: k8s.ServiceTypeNodePort, - Ports: []k8s.ServicePort{ - { - Name: "atls", - Protocol: k8s.ProtocolTCP, - Port: constants.KMSATLSPort, - TargetPort: intstr.FromInt(constants.KMSATLSPort), - NodePort: constants.KMSNodePort, - }, - }, - Selector: map[string]string{ - "k8s-app": "kms", - }, - }, - }, ClusterRole: rbac.ClusterRole{ TypeMeta: meta.TypeMeta{ APIVersion: "rbac.authorization.k8s.io/v1", @@ -229,9 +203,7 @@ func NewKMSDeployment(csp string, masterSecret []byte) *kmsDeployment { Name: "kms", Image: versions.KmsImage, Args: []string{ - fmt.Sprintf("--atls-port=%d", constants.KMSATLSPort), fmt.Sprintf("--port=%d", constants.KMSPort), - fmt.Sprintf("--cloud-provider=%s", csp), }, VolumeMounts: []k8s.VolumeMount{ { diff --git a/bootstrapper/internal/kubernetes/k8sapi/util.go b/bootstrapper/internal/kubernetes/k8sapi/util.go index 1ed1a7fb8..26f9fe409 100644 --- a/bootstrapper/internal/kubernetes/k8sapi/util.go +++ b/bootstrapper/internal/kubernetes/k8sapi/util.go @@ -19,7 +19,7 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/internal/kubelet" "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes/k8sapi/resources" - "github.com/edgelesssys/constellation/bootstrapper/util" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/versions" @@ -455,7 +455,7 @@ func (k *KubernetesUtil) createSignedKubeletCert(nodeName string, ips []net.IP) return err } - serialNumber, err := util.GenerateCertificateSerialNumber() + serialNumber, err := crypto.GenerateCertificateSerialNumber() if err != nil { return err } diff --git a/bootstrapper/internal/kubernetes/kubernetes.go b/bootstrapper/internal/kubernetes/kubernetes.go index 251fb1b79..d58875b3a 100644 --- a/bootstrapper/internal/kubernetes/kubernetes.go +++ b/bootstrapper/internal/kubernetes/kubernetes.go @@ -2,7 +2,6 @@ package kubernetes import ( "context" - "encoding/json" "errors" "fmt" "net" @@ -13,7 +12,6 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes/k8sapi/resources" "github.com/edgelesssys/constellation/bootstrapper/role" "github.com/edgelesssys/constellation/bootstrapper/util" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" "github.com/edgelesssys/constellation/internal/cloud/metadata" "github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/logger" @@ -81,7 +79,7 @@ type KMSConfig struct { // InitCluster initializes a new Kubernetes cluster and applies pod network provider. func (k *KubeWrapper) InitCluster( ctx context.Context, autoscalingNodeGroups []string, cloudServiceAccountURI, versionString string, - id attestationtypes.ID, kmsConfig KMSConfig, sshUsers map[string]string, log *logger.Logger, + measurementSalt []byte, kmsConfig KMSConfig, sshUsers map[string]string, log *logger.Logger, ) ([]byte, error) { k8sVersion, err := versions.NewValidK8sVersion(versionString) if err != nil { @@ -194,7 +192,7 @@ func (k *KubeWrapper) InitCluster( return nil, fmt.Errorf("setting up kms: %w", err) } - if err := k.setupJoinService(k.cloudProvider, k.initialMeasurementsJSON, id); err != nil { + if err := k.setupJoinService(k.cloudProvider, k.initialMeasurementsJSON, measurementSalt); err != nil { return nil, fmt.Errorf("setting up join service failed: %w", err) } @@ -226,7 +224,7 @@ func (k *KubeWrapper) InitCluster( } } - // Store the received k8sVersion in a ConfigMap, overwriting exisiting values (there shouldn't be any). + // Store the received k8sVersion in a ConfigMap, overwriting existing values (there shouldn't be any). // Joining nodes determine the kubernetes version they will install based on this ConfigMap. if err := k.setupK8sVersionConfigMap(ctx, k8sVersion); err != nil { return nil, fmt.Errorf("failed to setup k8s version ConfigMap: %v", err) @@ -306,13 +304,8 @@ func (k *KubeWrapper) GetKubeconfig() ([]byte, error) { return k.kubeconfigReader.ReadKubeconfig() } -func (k *KubeWrapper) setupJoinService(csp string, measurementsJSON []byte, id attestationtypes.ID) error { - idJSON, err := json.Marshal(id) - if err != nil { - return err - } - - joinConfiguration := resources.NewJoinServiceDaemonset(csp, string(measurementsJSON), string(idJSON)) +func (k *KubeWrapper) setupJoinService(csp string, measurementsJSON, measurementSalt []byte) error { + joinConfiguration := resources.NewJoinServiceDaemonset(csp, string(measurementsJSON), measurementSalt) return k.clusterUtil.SetupJoinService(k.client, joinConfiguration) } diff --git a/bootstrapper/internal/kubernetes/kubernetes_test.go b/bootstrapper/internal/kubernetes/kubernetes_test.go index af1135091..89c625093 100644 --- a/bootstrapper/internal/kubernetes/kubernetes_test.go +++ b/bootstrapper/internal/kubernetes/kubernetes_test.go @@ -10,7 +10,6 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes/k8sapi" "github.com/edgelesssys/constellation/bootstrapper/internal/kubernetes/k8sapi/resources" "github.com/edgelesssys/constellation/bootstrapper/role" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" "github.com/edgelesssys/constellation/internal/cloud/metadata" "github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/versions" @@ -297,7 +296,7 @@ func TestInitCluster(t *testing.T) { kubeconfigReader: tc.kubeconfigReader, getIPAddr: func() (string, error) { return privateIP, nil }, } - _, err := kube.InitCluster(context.Background(), autoscalingNodeGroups, serviceAccountURI, string(tc.k8sVersion), attestationtypes.ID{}, KMSConfig{MasterSecret: masterSecret}, nil, logger.NewTest(t)) + _, err := kube.InitCluster(context.Background(), autoscalingNodeGroups, serviceAccountURI, string(tc.k8sVersion), nil, KMSConfig{MasterSecret: masterSecret}, nil, logger.NewTest(t)) if tc.wantErr { assert.Error(err) diff --git a/bootstrapper/internal/nodelock/nodelock.go b/bootstrapper/internal/nodelock/nodelock.go index 9d0306474..0f2a6d59f 100644 --- a/bootstrapper/internal/nodelock/nodelock.go +++ b/bootstrapper/internal/nodelock/nodelock.go @@ -28,9 +28,10 @@ func New(tpm vtpm.TPMOpenFunc) *Lock { // TryLockOnce tries to lock the node. If the node is already locked, it // returns false. If the node is unlocked, it locks it and returns true. -func (l *Lock) TryLockOnce(ownerID, clusterID []byte) (bool, error) { +func (l *Lock) TryLockOnce(clusterID []byte) (bool, error) { if !l.mux.TryLock() { return false, nil } - return true, vtpm.MarkNodeAsBootstrapped(l.tpm, ownerID, clusterID) + + return true, vtpm.MarkNodeAsBootstrapped(l.tpm, clusterID) } diff --git a/bootstrapper/nodestate/nodestate.go b/bootstrapper/nodestate/nodestate.go index c9679c355..1cde08b9d 100644 --- a/bootstrapper/nodestate/nodestate.go +++ b/bootstrapper/nodestate/nodestate.go @@ -12,9 +12,8 @@ const nodeStatePath = "/run/state/constellation/node_state.json" // NodeState is the state of a constellation node that is required to recover from a reboot. // Can be persisted to disk and reloaded later. type NodeState struct { - Role role.Role - OwnerID []byte - ClusterID []byte + Role role.Role + MeasurementSalt []byte } // FromFile reads a NodeState from disk. diff --git a/bootstrapper/nodestate/nodestate_test.go b/bootstrapper/nodestate/nodestate_test.go index cddc55878..6703e69ad 100644 --- a/bootstrapper/nodestate/nodestate_test.go +++ b/bootstrapper/nodestate/nodestate_test.go @@ -23,11 +23,10 @@ func TestFromFile(t *testing.T) { wantErr bool }{ "nodestate exists": { - fileContents: `{ "Role": "ControlPlane", "OwnerID": "T3duZXJJRA==", "ClusterID": "Q2x1c3RlcklE" }`, + fileContents: `{ "Role": "ControlPlane", "MeasurementSalt": "U2FsdA==" }`, wantState: &NodeState{ - Role: role.ControlPlane, - OwnerID: []byte("OwnerID"), - ClusterID: []byte("ClusterID"), + Role: role.ControlPlane, + MeasurementSalt: []byte("Salt"), }, }, "nodestate file does not exist": { @@ -66,14 +65,12 @@ func TestToFile(t *testing.T) { }{ "writing works": { state: &NodeState{ - Role: role.ControlPlane, - OwnerID: []byte("OwnerID"), - ClusterID: []byte("ClusterID"), + Role: role.ControlPlane, + MeasurementSalt: []byte("Salt"), }, wantFile: `{ "Role": "ControlPlane", - "OwnerID": "T3duZXJJRA==", - "ClusterID": "Q2x1c3RlcklE" + "MeasurementSalt": "U2FsdA==" }`, }, "file exists already": { diff --git a/bootstrapper/util/util.go b/bootstrapper/util/util.go index 37b8ccb08..ea6e3e25d 100644 --- a/bootstrapper/util/util.go +++ b/bootstrapper/util/util.go @@ -1,42 +1,9 @@ package util import ( - "crypto/rand" - "crypto/sha256" - "io" - "math/big" "net" - - "golang.org/x/crypto/hkdf" ) -// DeriveKey derives a key from a secret. -// -// TODO: decide on a secure key derivation function. -func DeriveKey(secret, salt, info []byte, length uint) ([]byte, error) { - hkdf := hkdf.New(sha256.New, secret, salt, info) - key := make([]byte, length) - if _, err := io.ReadFull(hkdf, key); err != nil { - return nil, err - } - return key, nil -} - -// GenerateCertificateSerialNumber generates a random serial number for an X.509 certificate. -func GenerateCertificateSerialNumber() (*big.Int, error) { - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - return rand.Int(rand.Reader, serialNumberLimit) -} - -// GenerateRandomBytes reads length bytes from getrandom(2) if available, /dev/urandom otherwise. -func GenerateRandomBytes(length int) ([]byte, error) { - nonce := make([]byte, length) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - return nonce, nil -} - func GetIPAddr() (string, error) { conn, err := net.Dial("udp", "8.8.8.8:80") if err != nil { diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index d77b4e0e5..eb9715952 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -13,7 +13,6 @@ import ( "time" "github.com/edgelesssys/constellation/bootstrapper/initproto" - "github.com/edgelesssys/constellation/bootstrapper/util" "github.com/edgelesssys/constellation/cli/internal/azure" "github.com/edgelesssys/constellation/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/cli/internal/gcp" @@ -21,6 +20,7 @@ import ( "github.com/edgelesssys/constellation/internal/cloud/cloudtypes" "github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/grpc/dialer" @@ -178,7 +178,7 @@ func writeOutput(resp *initproto.InitResponse, ip string, wr io.Writer, fileHand clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0) - writeRow(tw, "Constellation cluster's owner identifier", ownerID) + // writeRow(tw, "Constellation cluster's owner identifier", ownerID) writeRow(tw, "Constellation cluster's unique identifier", clusterID) writeRow(tw, "Kubernetes configuration", constants.AdminConfFilename) tw.Flush() @@ -252,14 +252,14 @@ func readOrGenerateMasterSecret(writer io.Writer, fileHandler file.Handler, file if err != nil { return nil, err } - if len(decoded) < constants.MasterSecretLengthMin { + if len(decoded) < crypto.MasterSecretLengthMin { return nil, errors.New("provided master secret is smaller than the required minimum of 16 Bytes") } return decoded, nil } // No file given, generate a new secret, and save it to disk - masterSecret, err := util.GenerateRandomBytes(constants.MasterSecretLengthDefault) + masterSecret, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault) if err != nil { return nil, err } diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 6863fd557..c0eaf7e48 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -163,7 +163,7 @@ func TestInitialize(t *testing.T) { return } require.NoError(err) - assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) + // assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("clusterID"))) if tc.setAutoscaleFlag { assert.Len(tc.initServerAPI.activateAutoscalingNodeGroups, 1) @@ -198,7 +198,7 @@ func TestWriteOutput(t *testing.T) { err := writeOutput(resp, "ip", &out, fileHandler) assert.NoError(err) - assert.Contains(out.String(), ownerID) + // assert.Contains(out.String(), ownerID) assert.Contains(out.String(), clusterID) assert.Contains(out.String(), constants.AdminConfFilename) diff --git a/cli/internal/cmd/recover.go b/cli/internal/cmd/recover.go index ed0e531de..f6d144b95 100644 --- a/cli/internal/cmd/recover.go +++ b/cli/internal/cmd/recover.go @@ -7,11 +7,12 @@ import ( "regexp" "strings" - "github.com/edgelesssys/constellation/bootstrapper/util" "github.com/edgelesssys/constellation/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/cli/internal/proto" + "github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/state" "github.com/spf13/afero" @@ -78,7 +79,12 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recove return err } - if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey); err != nil { + measurementSecret, err := attestation.DeriveMeasurementSecret(flags.masterSecret) + if err != nil { + return err + } + + if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey, measurementSecret); err != nil { return err } @@ -150,5 +156,5 @@ func readMasterSecret(fileHandler file.Handler, filename string) ([]byte, error) // deriveStateDiskKey derives a state disk key from a master secret and a disk UUID. func deriveStateDiskKey(masterKey []byte, diskUUID string) ([]byte, error) { - return util.DeriveKey(masterKey, []byte("Constellation"), []byte("key"+diskUUID), constants.StateDiskKeyLength) + return crypto.DeriveKey(masterKey, []byte("Constellation"), []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength) } diff --git a/cli/internal/cmd/recover_test.go b/cli/internal/cmd/recover_test.go index 29df499a1..f6c1ea9df 100644 --- a/cli/internal/cmd/recover_test.go +++ b/cli/internal/cmd/recover_test.go @@ -64,7 +64,7 @@ func TestRecover(t *testing.T) { client: &stubRecoveryClient{}, endpointFlag: "192.0.2.1", diskUUIDFlag: "00000000-0000-0000-0000-000000000000", - wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32}, + wantKey: []byte{0x4d, 0x34, 0x19, 0x1a, 0xf9, 0x23, 0xb9, 0x61, 0x55, 0x9b, 0xb2, 0x6, 0x15, 0x1b, 0x5f, 0xe, 0x21, 0xc2, 0xe5, 0x18, 0x1c, 0xfa, 0x32, 0x79, 0xa4, 0x6b, 0x84, 0x86, 0x7e, 0xd7, 0xf6, 0x76}, }, "uppercase disk uuid works": { setupFs: func(require *require.Assertions) afero.Fs { @@ -76,7 +76,7 @@ func TestRecover(t *testing.T) { client: &stubRecoveryClient{}, endpointFlag: "192.0.2.1", diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF", - wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34}, + wantKey: []byte{0x7e, 0xc0, 0xa8, 0x84, 0xc4, 0x7, 0xda, 0x1, 0xed, 0xa9, 0xc8, 0x87, 0x77, 0xad, 0x86, 0x7c, 0x7d, 0x40, 0xa7, 0x28, 0x3d, 0xbd, 0x92, 0xea, 0xa1, 0x84, 0x67, 0x78, 0x58, 0x76, 0x13, 0x70}, }, "lowercase disk uuid results in same key": { setupFs: func(require *require.Assertions) afero.Fs { @@ -88,7 +88,7 @@ func TestRecover(t *testing.T) { client: &stubRecoveryClient{}, endpointFlag: "192.0.2.1", diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef", - wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34}, + wantKey: []byte{0x7e, 0xc0, 0xa8, 0x84, 0xc4, 0x7, 0xda, 0x1, 0xed, 0xa9, 0xc8, 0x87, 0x77, 0xad, 0x86, 0x7c, 0x7d, 0x40, 0xa7, 0x28, 0x3d, 0xbd, 0x92, 0xea, 0xa1, 0x84, 0x67, 0x78, 0x58, 0x76, 0x13, 0x70}, }, "missing flags": { setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() }, @@ -324,8 +324,8 @@ func TestDeriveStateDiskKey(t *testing.T) { }, diskUUID: "00000000-0000-0000-0000-000000000000", wantStateDiskKey: []byte{ - 0xa8, 0xb0, 0x86, 0x83, 0x6f, 0x0b, 0x26, 0x04, 0x86, 0x22, 0x27, 0xcc, 0xa1, 0x1c, 0xaf, 0x6c, - 0x30, 0x4d, 0x90, 0x89, 0x82, 0x68, 0x53, 0x7f, 0x4f, 0x46, 0x7a, 0x65, 0xa2, 0x5d, 0x5e, 0x43, + 0xc6, 0xe0, 0xae, 0xfc, 0xbe, 0x7b, 0x7e, 0x87, 0x7a, 0xdd, 0xb2, 0x87, 0xe0, 0xcd, 0x4c, 0xe4, + 0xde, 0xee, 0xb3, 0x57, 0xaa, 0x6c, 0xc9, 0x44, 0x90, 0xc4, 0x07, 0x72, 0x01, 0x7d, 0xd6, 0xb1, }, }, "all 0xff": { @@ -335,8 +335,8 @@ func TestDeriveStateDiskKey(t *testing.T) { }, diskUUID: "ffffffff-ffff-ffff-ffff-ffffffffffff", wantStateDiskKey: []byte{ - 0x24, 0x18, 0x84, 0x7f, 0xca, 0x86, 0x55, 0xb5, 0x45, 0xa6, 0xb3, 0xc4, 0x45, 0xbb, 0x08, 0x10, - 0x16, 0xb3, 0xde, 0x30, 0x30, 0x74, 0x0b, 0xd4, 0x1e, 0x22, 0x55, 0x45, 0x51, 0x91, 0xfb, 0xa9, + 0x00, 0x74, 0x4c, 0xb0, 0x92, 0x9d, 0x20, 0x08, 0xfa, 0x72, 0xac, 0xd2, 0xb6, 0xe4, 0xc6, 0x6f, + 0xa3, 0x53, 0x16, 0xb1, 0x9e, 0x77, 0x42, 0xe8, 0xd3, 0x66, 0xe8, 0x22, 0x33, 0xfc, 0x63, 0x4d, }, }, } diff --git a/cli/internal/cmd/recoveryclient.go b/cli/internal/cmd/recoveryclient.go index 9301fb7c5..097c4eda7 100644 --- a/cli/internal/cmd/recoveryclient.go +++ b/cli/internal/cmd/recoveryclient.go @@ -9,6 +9,6 @@ import ( type recoveryClient interface { Connect(endpoint string, validators []atls.Validator) error - PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error + PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error io.Closer } diff --git a/cli/internal/cmd/recoveryclient_test.go b/cli/internal/cmd/recoveryclient_test.go index a6d04c02c..8d6c46ec6 100644 --- a/cli/internal/cmd/recoveryclient_test.go +++ b/cli/internal/cmd/recoveryclient_test.go @@ -25,7 +25,7 @@ func (c *stubRecoveryClient) Close() error { return c.closeErr } -func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey []byte) error { +func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey, _ []byte) error { c.pushStateDiskKeyKey = stateDiskKey return c.pushStateDiskKeyErr } diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 6fd3b658a..cf01a1d93 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -8,11 +8,11 @@ import ( "io/fs" "net" - "github.com/edgelesssys/constellation/bootstrapper/util" "github.com/edgelesssys/constellation/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/grpc/dialer" "github.com/edgelesssys/constellation/verify/verifyproto" @@ -37,7 +37,7 @@ If arguments aren't specified, values are read from ` + "`" + constants.ClusterI RunE: runVerify, } cmd.Flags().String("owner-id", "", "verify using the owner identity derived from the master secret") - cmd.Flags().String("unique-id", "", "verify using the unique cluster identity") + cmd.Flags().String("cluster-id", "", "verify using the unique cluster identity") cmd.Flags().StringP("node-endpoint", "e", "", "endpoint of the node to verify, passed as HOST[:PORT]") return cmd } @@ -74,11 +74,11 @@ func verify( cmd.Print(validators.Warnings()) } - nonce, err := util.GenerateRandomBytes(32) + nonce, err := crypto.GenerateRandomBytes(32) if err != nil { return err } - userData, err := util.GenerateRandomBytes(32) + userData, err := crypto.GenerateRandomBytes(32) if err != nil { return err } @@ -108,9 +108,9 @@ func parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags if err != nil { return verifyFlags{}, fmt.Errorf("parsing owner-id argument: %w", err) } - clusterID, err := cmd.Flags().GetString("unique-id") + clusterID, err := cmd.Flags().GetString("cluster-id") if err != nil { - return verifyFlags{}, fmt.Errorf("parsing unique-id argument: %w", err) + return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err) } endpoint, err := cmd.Flags().GetString("node-endpoint") if err != nil { @@ -127,7 +127,7 @@ func parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags endpoint = details.Endpoint } if emptyIDs { - cmd.Printf("Using IDs from %q. Specify --owner-id and/or --unique-id to override this.\n", constants.ClusterIDsFileName) + cmd.Printf("Using IDs from %q. Specify --owner-id and/or --cluster-id to override this.\n", constants.ClusterIDsFileName) ownerID = details.OwnerID clusterID = details.ClusterID } @@ -138,7 +138,7 @@ func parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags // Validate if ownerID == "" && clusterID == "" { - return verifyFlags{}, errors.New("neither owner-id nor unique-id provided to verify the cluster") + return verifyFlags{}, errors.New("neither owner-id nor cluster-id provided to verify the cluster") } endpoint, err = validateEndpoint(endpoint, constants.BootstrapperPort) if err != nil { diff --git a/cli/internal/proto/recover.go b/cli/internal/proto/recover.go index 362855ce8..92795d8da 100644 --- a/cli/internal/proto/recover.go +++ b/cli/internal/proto/recover.go @@ -50,13 +50,14 @@ func (c *KeyClient) Close() error { // PushStateDiskKey pushes the state disk key to a constellation instance in recovery mode. // The state disk key must be derived from the UUID of the state disk and the master key. -func (c *KeyClient) PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error { +func (c *KeyClient) PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error { if c.keyapi == nil { return errors.New("client is not connected") } req := &keyproto.PushStateDiskKeyRequest{ - StateDiskKey: stateDiskKey, + StateDiskKey: stateDiskKey, + MeasurementSecret: measurementSecret, } _, err := c.keyapi.PushStateDiskKey(ctx, req) diff --git a/hack/pcr-reader/main.go b/hack/pcr-reader/main.go index c95e9066e..4bcbe9fb2 100644 --- a/hack/pcr-reader/main.go +++ b/hack/pcr-reader/main.go @@ -14,9 +14,9 @@ import ( "strconv" "time" - "github.com/edgelesssys/constellation/bootstrapper/util" "github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/verify/verifyproto" "github.com/spf13/afero" "google.golang.org/grpc" @@ -86,7 +86,7 @@ func getAttestation(ctx context.Context, addr string) ([]byte, error) { } defer conn.Close() - nonce, err := util.GenerateRandomBytes(32) + nonce, err := crypto.GenerateRandomBytes(32) if err != nil { return nil, err } diff --git a/internal/atls/atls.go b/internal/atls/atls.go index 56f4d8fce..0dff97a6c 100644 --- a/internal/atls/atls.go +++ b/internal/atls/atls.go @@ -17,8 +17,7 @@ import ( "math/big" "time" - "github.com/edgelesssys/constellation/bootstrapper/util" - "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/oid" ) @@ -45,7 +44,7 @@ func CreateAttestationServerTLSConfig(issuer Issuer, validators []Validator) (*t // If no validators are set, the server's attestation document will not be verified. // If issuer is nil, the client will be unable to perform mutual aTLS. func CreateAttestationClientTLSConfig(issuer Issuer, validators []Validator) (*tls.Config, error) { - clientNonce, err := util.GenerateRandomBytes(constants.RNGLengthDefault) + clientNonce, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault) if err != nil { return nil, err } @@ -87,7 +86,7 @@ func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tl // this function will be called once for every client return func(chi *tls.ClientHelloInfo) (*tls.Config, error) { // generate nonce for this connection - serverNonce, err := util.GenerateRandomBytes(constants.RNGLengthDefault) + serverNonce, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault) if err != nil { return nil, err } @@ -122,7 +121,7 @@ func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tl // getCertificate creates a client or server certificate for aTLS connections. // The certificate uses certificate extensions to embed an attestation document generated using nonce. func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) { - serialNumber, err := util.GenerateCertificateSerialNumber() + serialNumber, err := crypto.GenerateCertificateSerialNumber() if err != nil { return nil, err } diff --git a/internal/attestation/attestation.go b/internal/attestation/attestation.go new file mode 100644 index 000000000..e3d43d8ed --- /dev/null +++ b/internal/attestation/attestation.go @@ -0,0 +1,24 @@ +package attestation + +import ( + "github.com/edgelesssys/constellation/internal/crypto" +) + +const ( + // clusterIDContext is the value to use for info when deriving the cluster ID. + clusterIDContext = "clusterID" + // MeasurementSecretContext is the value to use for info + // when deriving the measurement secret from the master secret. + MeasurementSecretContext = "measurementSecret" +) + +// DeriveClusterID derives the cluster ID from a salt and secret value. +func DeriveClusterID(salt, secret []byte) ([]byte, error) { + return crypto.DeriveKey(secret, salt, []byte(crypto.HKDFInfoPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault) +} + +// DeriveMeasurementSecret derives the secret value needed to derive ClusterID. +func DeriveMeasurementSecret(masterSecret []byte) ([]byte, error) { + // TODO: replace hard coded salt + return crypto.DeriveKey(masterSecret, []byte("Constellation"), []byte(crypto.HKDFInfoPrefix+MeasurementSecretContext), crypto.DerivedKeyLengthDefault) +} diff --git a/internal/attestation/types/types.go b/internal/attestation/types/types.go deleted file mode 100644 index 8d874f168..000000000 --- a/internal/attestation/types/types.go +++ /dev/null @@ -1,7 +0,0 @@ -package attestationtypes - -// ID holds the identifiers of a node. -type ID struct { - Cluster []byte `json:"cluster"` - Owner []byte `json:"owner"` -} diff --git a/internal/attestation/vtpm/initialize.go b/internal/attestation/vtpm/initialize.go index 2050237c5..b15278db4 100644 --- a/internal/attestation/vtpm/initialize.go +++ b/internal/attestation/vtpm/initialize.go @@ -2,7 +2,6 @@ package vtpm import ( "errors" - "fmt" "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpmutil" @@ -18,22 +17,18 @@ const ( ) // MarkNodeAsBootstrapped marks a node as initialized by extending PCRs. -func MarkNodeAsBootstrapped(openTPM TPMOpenFunc, ownerID, clusterID []byte) error { +func MarkNodeAsBootstrapped(openTPM TPMOpenFunc, clusterID []byte) error { tpm, err := openTPM() if err != nil { return err } defer tpm.Close() - // ownerID is used to identify the Constellation as belonging to a specific master key - if err := tpm2.PCREvent(tpm, PCRIndexOwnerID, ownerID); err != nil { - return err - } // clusterID is used to uniquely identify this running instance of Constellation return tpm2.PCREvent(tpm, PCRIndexClusterID, clusterID) } -// IsNodeBootstrapped checks if a node is already bootestrapped by reading PCRs. +// IsNodeBootstrapped checks if a node is already bootstrapped by reading PCRs. func IsNodeBootstrapped(openTPM TPMOpenFunc) (bool, error) { tpm, err := openTPM() if err != nil { @@ -41,6 +36,20 @@ func IsNodeBootstrapped(openTPM TPMOpenFunc) (bool, error) { } defer tpm.Close() + idxClusterID := int(PCRIndexClusterID) + pcrs, err := tpm2.ReadPCRs(tpm, tpm2.PCRSelection{ + Hash: tpm2.AlgSHA256, + PCRs: []int{idxClusterID}, + }) + if err != nil { + return false, err + } + if len(pcrs[idxClusterID]) == 0 { + return false, errors.New("cluster ID PCR does not exist") + } + return pcrInitialized(pcrs[idxClusterID]), nil + + /* Old code that will be reenabled in the future idxOwner := int(PCRIndexOwnerID) idxCluster := int(PCRIndexClusterID) selection := tpm2.PCRSelection{ @@ -75,6 +84,7 @@ func IsNodeBootstrapped(openTPM TPMOpenFunc) (bool, error) { clusterState = "initialized" } return false, fmt.Errorf("PCRs %v and %v are not consistent: PCR[%v]=%v (%v), PCR[%v]=%v (%v)", idxOwner, idxCluster, idxOwner, pcrs[idxOwner], ownerState, idxCluster, pcrs[idxCluster], clusterState) + */ } // pcrInitialized checks if a PCR value is set to a non-zero value. diff --git a/internal/attestation/vtpm/initialize_test.go b/internal/attestation/vtpm/initialize_test.go index de289661a..68d21d51f 100644 --- a/internal/attestation/vtpm/initialize_test.go +++ b/internal/attestation/vtpm/initialize_test.go @@ -33,13 +33,12 @@ func TestMarkNodeAsInitialized(t *testing.T) { assert.NoError(MarkNodeAsBootstrapped(func() (io.ReadWriteCloser, error) { return &simTPMNOPCloser{tpm}, nil - }, []byte{0x0, 0x1, 0x2, 0x3}, []byte{0x4, 0x5, 0x6, 0x7})) + }, []byte{0x0, 0x1, 0x2, 0x3})) pcrsInitialized, err := client.ReadAllPCRs(tpm) require.NoError(err) for i := range pcrs { - assert.NotEqual(pcrs[i].Pcrs[uint32(PCRIndexOwnerID)], pcrsInitialized[i].Pcrs[uint32(PCRIndexOwnerID)]) assert.NotEqual(pcrs[i].Pcrs[uint32(PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(PCRIndexClusterID)]) } } @@ -47,30 +46,20 @@ func TestMarkNodeAsInitialized(t *testing.T) { func TestFailOpener(t *testing.T) { assert := assert.New(t) - assert.Error(MarkNodeAsBootstrapped(func() (io.ReadWriteCloser, error) { return nil, errors.New("failed") }, []byte{0x0, 0x1, 0x2, 0x3}, []byte{0x0, 0x1, 0x2, 0x3})) + assert.Error(MarkNodeAsBootstrapped(func() (io.ReadWriteCloser, error) { return nil, errors.New("failed") }, []byte{0x0, 0x1, 0x2, 0x3})) } func TestIsNodeInitialized(t *testing.T) { testCases := map[string]struct { - pcrValueOwnerID []byte pcrValueClusterID []byte wantInitialized bool wantErr bool }{ "uninitialized PCRs results in uninitialized node": {}, "initializing PCRs result in initialized node": { - pcrValueOwnerID: []byte{0x0, 0x1, 0x2, 0x3}, pcrValueClusterID: []byte{0x4, 0x5, 0x6, 0x7}, wantInitialized: true, }, - "initializing ownerID alone fails": { - pcrValueOwnerID: []byte{0x0, 0x1, 0x2, 0x3}, - wantErr: true, - }, - "initializing clusterID alone fails": { - pcrValueClusterID: []byte{0x4, 0x5, 0x6, 0x7}, - wantErr: true, - }, } for name, tc := range testCases { @@ -80,9 +69,6 @@ func TestIsNodeInitialized(t *testing.T) { tpm, err := simulator.OpenSimulatedTPM() require.NoError(err) defer tpm.Close() - if tc.pcrValueOwnerID != nil { - require.NoError(tpm2.PCREvent(tpm, PCRIndexOwnerID, tc.pcrValueOwnerID)) - } if tc.pcrValueClusterID != nil { require.NoError(tpm2.PCREvent(tpm, PCRIndexClusterID, tc.pcrValueClusterID)) } diff --git a/internal/attestation/vtpm/vtpm_test.go b/internal/attestation/vtpm/vtpm_test.go index 3652709a3..869c2c49b 100644 --- a/internal/attestation/vtpm/vtpm_test.go +++ b/internal/attestation/vtpm/vtpm_test.go @@ -14,5 +14,5 @@ func TestMain(m *testing.M) { func TestNOPTPM(t *testing.T) { assert := assert.New(t) - assert.NoError(MarkNodeAsBootstrapped(OpenNOPTPM, []byte{0x0, 0x1, 0x2, 0x3}, []byte{0x4, 0x5, 0x6, 0x7})) + assert.NoError(MarkNodeAsBootstrapped(OpenNOPTPM, []byte{0x0, 0x1, 0x2, 0x3})) } diff --git a/internal/cloud/metadata/metadata.go b/internal/cloud/metadata/metadata.go index a0dc1ec57..095359294 100644 --- a/internal/cloud/metadata/metadata.go +++ b/internal/cloud/metadata/metadata.go @@ -50,20 +50,20 @@ func InitServerEndpoints(ctx context.Context, lister InstanceLister) ([]string, return initServerEndpoints, nil } -// KMSEndpoints returns the list of endpoints for the KMS service, which are running on the control plane nodes. -func KMSEndpoints(ctx context.Context, lister InstanceLister) ([]string, error) { +// JoinServiceEndpoints returns the list of endpoints for the join service, which are running on the control plane nodes. +func JoinServiceEndpoints(ctx context.Context, lister InstanceLister) ([]string, error) { instances, err := lister.List(ctx) if err != nil { return nil, fmt.Errorf("retrieving instances list from cloud provider: %w", err) } - kmsEndpoints := []string{} + joinEndpoints := []string{} for _, instance := range instances { if instance.Role == role.ControlPlane { for _, ip := range instance.PrivateIPs { - kmsEndpoints = append(kmsEndpoints, net.JoinHostPort(ip, strconv.Itoa(constants.KMSNodePort))) + joinEndpoints = append(joinEndpoints, net.JoinHostPort(ip, strconv.Itoa(constants.JoinServiceNodePort))) } } } - return kmsEndpoints, nil + return joinEndpoints, nil } diff --git a/internal/constants/constants.go b/internal/constants/constants.go index dcfe4e37e..6667d9486 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -31,11 +31,7 @@ const ( VerifyServiceNodePortHTTP = 30080 VerifyServiceNodePortGRPC = 30081 // KMSPort is the port the KMS server listens on. - KMSPort = 9000 - // KMSATLSPort is the port the KMS aTLS server listens on. - KMSATLSPort = 9001 - // KMSNodePort is the aTLS port exposed as a NodePort. - KMSNodePort = 30091 + KMSPort = 9000 BootstrapperPort = 9000 EnclaveSSHPort = 2222 SSHPort = 22 @@ -67,25 +63,13 @@ const ( ServiceBasePath = "/var/config" // MeasurementsFilename is the filename of CC measurements. MeasurementsFilename = "measurements" - // IDFilename is the filename of Constellation's IDs. - IDFilename = "id" + // MeasurementSaltFilename is the filename of the salt used in creation of the clusterID. + MeasurementSaltFilename = "measurementSalt" + // MeasurementSecretFilename is the filename of the secret used in creation of the clusterID. + MeasurementSecretFilename = "measurementSecret" // K8sVersion is the filename of the mapped "k8s-version" configMap file. K8sVersion = "k8s-version" - // - // Cryptographic constants. - // - - StateDiskKeyLength = 32 - // DerivedKeyLengthDefault is the default length in bytes for KMS derived keys. - DerivedKeyLengthDefault = 32 - // MasterSecretLengthDefault is the default length in bytes for CLI generated master secrets. - MasterSecretLengthDefault = 32 - // MasterSecretLengthMin is the minimal length in bytes for user provided master secrets. - MasterSecretLengthMin = 16 - // RNGLengthDefault is the number of bytes used for generating nonces. - RNGLengthDefault = 32 - // // CLI. // diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go new file mode 100644 index 000000000..f2c3699f7 --- /dev/null +++ b/internal/crypto/crypto.go @@ -0,0 +1,52 @@ +// Package crypto provides functions to for cryptography and random numbers. +package crypto + +import ( + "crypto/rand" + "crypto/sha256" + "io" + "math/big" + + "golang.org/x/crypto/hkdf" +) + +const ( + StateDiskKeyLength = 32 + // DerivedKeyLengthDefault is the default length in bytes for KMS derived keys. + DerivedKeyLengthDefault = 32 + // MasterSecretLengthDefault is the default length in bytes for CLI generated master secrets. + MasterSecretLengthDefault = 32 + // MasterSecretLengthMin is the minimal length in bytes for user provided master secrets. + MasterSecretLengthMin = 16 + // RNGLengthDefault is the number of bytes used for generating nonces. + RNGLengthDefault = 32 + // HKDFInfoPrefix is the prefix used for the info parameter in HKDF. + HKDFInfoPrefix = "key-" +) + +// DeriveKey derives a key from a secret. +// +// TODO: decide on a secure key derivation function. +func DeriveKey(secret, salt, info []byte, length uint) ([]byte, error) { + hkdf := hkdf.New(sha256.New, secret, salt, info) + key := make([]byte, length) + if _, err := io.ReadFull(hkdf, key); err != nil { + return nil, err + } + return key, nil +} + +// GenerateCertificateSerialNumber generates a random serial number for an X.509 certificate. +func GenerateCertificateSerialNumber() (*big.Int, error) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + return rand.Int(rand.Reader, serialNumberLimit) +} + +// GenerateRandomBytes reads length bytes from getrandom(2) if available, /dev/urandom otherwise. +func GenerateRandomBytes(length int) ([]byte, error) { + nonce := make([]byte, length) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + return nonce, nil +} diff --git a/bootstrapper/util/util_test.go b/internal/crypto/crypto_test.go similarity index 98% rename from bootstrapper/util/util_test.go rename to internal/crypto/crypto_test.go index 1e56c1274..423d4963c 100644 --- a/bootstrapper/util/util_test.go +++ b/internal/crypto/crypto_test.go @@ -1,4 +1,4 @@ -package util +package crypto import ( "testing" diff --git a/internal/versions/versions.go b/internal/versions/versions.go index 8148d52d8..f0c0a88a4 100644 --- a/internal/versions/versions.go +++ b/internal/versions/versions.go @@ -26,9 +26,9 @@ func IsSupportedK8sVersion(version string) bool { const ( // Constellation images. // These images are built in a way that they support all versions currently listed in VersionConfigs. - JoinImage = "ghcr.io/edgelesssys/constellation/join-service:v1.3.2-0.20220722130504-526de20a" + JoinImage = "ghcr.io/edgelesssys/constellation/join-service:v1.3.2-0.20220719121753-1a6deb94" AccessManagerImage = "ghcr.io/edgelesssys/constellation/access-manager:v1.3.2-0.20220714151638-d295be31" - KmsImage = "ghcr.io/edgelesssys/constellation/kmsserver:v1.3.2-0.20220714151638-d295be31" + KmsImage = "ghcr.io/edgelesssys/constellation/kmsserver:v1.3.2-0.20220722135959-3e250b12" VerificationImage = "ghcr.io/edgelesssys/constellation/verification-service:v1.3.2-0.20220714151638-d295be31" GcpGuestImage = "ghcr.io/edgelesssys/gcp-guest-agent:latest" diff --git a/joinservice/cmd/main.go b/joinservice/cmd/main.go index 049b0ea13..15235eacc 100644 --- a/joinservice/cmd/main.go +++ b/joinservice/cmd/main.go @@ -63,12 +63,18 @@ func main() { } kms := kms.New(log.Named("kms"), *kmsEndpoint) + measurementSalt, err := handler.Read(filepath.Join(constants.ServiceBasePath, constants.MeasurementSaltFilename)) + if err != nil { + log.With(zap.Error(err)).Fatalf("Failed to read measurement salt") + } + server := server.New( - log.Named("server"), + measurementSalt, handler, kubernetesca.New(log.Named("certificateAuthority"), handler), kubeadm, kms, + log.Named("server"), ) watcher, err := watcher.New(log.Named("fileWatcher"), validator) diff --git a/joinservice/internal/kms/kms.go b/joinservice/internal/kms/kms.go index 501d67f9f..c6610e0e1 100644 --- a/joinservice/internal/kms/kms.go +++ b/joinservice/internal/kms/kms.go @@ -28,8 +28,8 @@ func New(log *logger.Logger, endpoint string) Client { } // GetDataKey returns a data encryption key for the given UUID. -func (c Client) GetDataKey(ctx context.Context, uuid string, length int) ([]byte, error) { - log := c.log.With(zap.String("diskUUID", uuid), zap.String("endpoint", c.endpoint)) +func (c Client) GetDataKey(ctx context.Context, keyID string, length int) ([]byte, error) { + log := c.log.With(zap.String("keyID", keyID), zap.String("endpoint", c.endpoint)) // TODO: update credentials if we enable aTLS on the KMS // For now this is fine since traffic is only routed through the Constellation cluster log.Infof("Connecting to KMS at %s", c.endpoint) @@ -43,7 +43,7 @@ func (c Client) GetDataKey(ctx context.Context, uuid string, length int) ([]byte res, err := c.grpc.GetDataKey( ctx, &kmsproto.GetDataKeyRequest{ - DataKeyId: uuid, + DataKeyId: keyID, Length: uint32(length), }, conn, diff --git a/joinservice/internal/kubernetesca/kubernetesca.go b/joinservice/internal/kubernetesca/kubernetesca.go index 471709623..b93a59fc4 100644 --- a/joinservice/internal/kubernetesca/kubernetesca.go +++ b/joinservice/internal/kubernetesca/kubernetesca.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/edgelesssys/constellation/bootstrapper/util" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/logger" kubeconstants "k8s.io/kubernetes/cmd/kubeadm/app/constants" @@ -87,7 +87,7 @@ func (c KubernetesCA) GetCertificate(csr []byte) (cert []byte, err error) { return nil, fmt.Errorf("certificate request must have common name prefix %q but is %q", kubeconstants.NodesUserPrefix, certRequest.Subject.CommonName) } - serialNumber, err := util.GenerateCertificateSerialNumber() + serialNumber, err := crypto.GenerateCertificateSerialNumber() if err != nil { return nil, err } diff --git a/joinservice/internal/server/server.go b/joinservice/internal/server/server.go index b2df93cdc..97862cc04 100644 --- a/joinservice/internal/server/server.go +++ b/joinservice/internal/server/server.go @@ -7,8 +7,9 @@ import ( "path/filepath" "time" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" + "github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/grpc/grpclog" "github.com/edgelesssys/constellation/internal/logger" @@ -25,6 +26,8 @@ import ( // Server implements the core logic of Constellation's node join service. type Server struct { + measurementSalt []byte + log *logger.Logger file file.Handler joinTokenGetter joinTokenGetter @@ -34,8 +37,12 @@ type Server struct { } // New initializes a new Server. -func New(log *logger.Logger, fileHandler file.Handler, ca certificateAuthority, joinTokenGetter joinTokenGetter, dataKeyGetter dataKeyGetter) *Server { +func New( + measurementSalt []byte, fileHandler file.Handler, ca certificateAuthority, + joinTokenGetter joinTokenGetter, dataKeyGetter dataKeyGetter, log *logger.Logger, +) *Server { return &Server{ + measurementSalt: measurementSalt, log: log, file: fileHandler, joinTokenGetter: joinTokenGetter, @@ -66,21 +73,22 @@ func (s *Server) Run(creds credentials.TransportCredentials, port string) error // A node will receive: // - stateful disk encryption key. // - Kubernetes join token. -// - cluster and owner ID to taint the node as initialized. +// - measurement salt and secret, to mark the node as initialized. // In addition, control plane nodes receive: // - a decryption key for CA certificates uploaded to the Kubernetes cluster. -func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTicketRequest) (resp *joinproto.IssueJoinTicketResponse, retErr error) { - s.log.Infof("IssueJoinTicket called") +func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTicketRequest) (*joinproto.IssueJoinTicketResponse, error) { log := s.log.With(zap.String("peerAddress", grpclog.PeerAddrFromContext(ctx))) - log.Infof("Loading IDs") - var id attestationtypes.ID - if err := s.file.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.IDFilename), &id); err != nil { - log.With(zap.Error(err)).Errorf("Unable to load IDs") - return nil, status.Errorf(codes.Internal, "unable to load IDs: %s", err) + log.Infof("IssueJoinTicket called") + + log.Infof("Requesting measurement secret") + measurementSecret, err := s.dataKeyGetter.GetDataKey(ctx, attestation.MeasurementSecretContext, crypto.DerivedKeyLengthDefault) + if err != nil { + log.With(zap.Error(err)).Errorf("Unable to get measurement secret") + return nil, status.Errorf(codes.Internal, "unable to get measurement secret: %s", err) } log.Infof("Requesting disk encryption key") - stateDiskKey, err := s.dataKeyGetter.GetDataKey(ctx, req.DiskUuid, constants.StateDiskKeyLength) + stateDiskKey, err := s.dataKeyGetter.GetDataKey(ctx, req.DiskUuid, crypto.StateDiskKeyLength) if err != nil { log.With(zap.Error(err)).Errorf("Unable to get key for stateful disk") return nil, status.Errorf(codes.Internal, "unable to get key for stateful disk: %s", err) @@ -94,7 +102,7 @@ func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTi } log.Infof("Querying K8sVersion ConfigMap") - k8sVersion, err := s.getK8sVersion(ctx) + k8sVersion, err := s.getK8sVersion() if err != nil { return nil, status.Errorf(codes.Internal, "unable to get k8s version: %s", err) } @@ -122,11 +130,11 @@ func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTi } } - s.log.Infof("IssueJoinTicket successful") + log.Infof("IssueJoinTicket successful") return &joinproto.IssueJoinTicketResponse{ StateDiskKey: stateDiskKey, - ClusterId: id.Cluster, - OwnerId: id.Owner, + MeasurementSalt: s.measurementSalt, + MeasurementSecret: measurementSecret, ApiServerEndpoint: kubeArgs.APIServerEndpoint, Token: kubeArgs.Token, DiscoveryTokenCaCertHash: kubeArgs.CACertHashes[0], @@ -136,8 +144,32 @@ func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTi }, nil } +func (s *Server) IssueRejoinTicket(ctx context.Context, req *joinproto.IssueRejoinTicketRequest) (*joinproto.IssueRejoinTicketResponse, error) { + log := s.log.With(zap.String("peerAddress", grpclog.PeerAddrFromContext(ctx))) + log.Infof("IssueRejoinTicket called") + + log.Infof("Requesting measurement secret") + measurementSecret, err := s.dataKeyGetter.GetDataKey(ctx, attestation.MeasurementSecretContext, crypto.DerivedKeyLengthDefault) + if err != nil { + log.With(zap.Error(err)).Errorf("Unable to get measurement secret") + return nil, status.Errorf(codes.Internal, "unable to get measurement secret: %s", err) + } + + log.Infof("Requesting disk encryption key") + stateDiskKey, err := s.dataKeyGetter.GetDataKey(ctx, req.DiskUuid, crypto.StateDiskKeyLength) + if err != nil { + log.With(zap.Error(err)).Errorf("Unable to get key for stateful disk") + return nil, status.Errorf(codes.Internal, "unable to get key for stateful disk: %s", err) + } + + return &joinproto.IssueRejoinTicketResponse{ + StateDiskKey: stateDiskKey, + MeasurementSecret: measurementSecret, + }, nil +} + // getK8sVersion reads the k8s version from a VolumeMount that is backed by the k8s-version ConfigMap. -func (s *Server) getK8sVersion(_ context.Context) (string, error) { +func (s *Server) getK8sVersion() (string, error) { fileContent, err := s.file.Read(filepath.Join(constants.ServiceBasePath, "k8s-version")) if err != nil { return "", fmt.Errorf("could not read k8s version file: %v", err) diff --git a/joinservice/internal/server/server_test.go b/joinservice/internal/server/server_test.go index 078c2650c..45d0e4b64 100644 --- a/joinservice/internal/server/server_test.go +++ b/joinservice/internal/server/server_test.go @@ -2,13 +2,12 @@ package server import ( "context" - "encoding/json" "errors" "path/filepath" "testing" "time" - attestationtypes "github.com/edgelesssys/constellation/internal/attestation/types" + "github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/logger" @@ -29,10 +28,9 @@ func TestIssueJoinTicket(t *testing.T) { someErr := errors.New("error") testKey := []byte{0x1, 0x2, 0x3} testCert := []byte{0x4, 0x5, 0x6} - testID := attestationtypes.ID{ - Owner: []byte{0x4, 0x5, 0x6}, - Cluster: []byte{0x7, 0x8, 0x9}, - } + measurementSecret := []byte{0x7, 0x8, 0x9} + uuid := "uuid" + testJoinToken := &kubeadmv1.BootstrapTokenDiscovery{ APIServerEndpoint: "192.0.2.1", CACertHashes: []string{"hash"}, @@ -45,47 +43,38 @@ func TestIssueJoinTicket(t *testing.T) { kubeadm stubTokenGetter kms stubKeyGetter ca stubCA - id []byte wantErr bool }{ "worker node": { kubeadm: stubTokenGetter{token: testJoinToken}, - kms: stubKeyGetter{dataKey: testKey}, - ca: stubCA{cert: testCert}, - id: mustMarshalID(testID), + kms: stubKeyGetter{dataKeys: map[string][]byte{ + uuid: testKey, + attestation.MeasurementSecretContext: measurementSecret, + }}, + ca: stubCA{cert: testCert}, }, "GetDataKey fails": { kubeadm: stubTokenGetter{token: testJoinToken}, - kms: stubKeyGetter{getDataKeyErr: someErr}, - ca: stubCA{cert: testCert}, - id: mustMarshalID(testID), - wantErr: true, - }, - "loading IDs fails": { - kubeadm: stubTokenGetter{token: testJoinToken}, - kms: stubKeyGetter{dataKey: testKey}, - ca: stubCA{cert: testCert}, - id: []byte{0x1, 0x2, 0x3}, - wantErr: true, - }, - "no ID file": { - kubeadm: stubTokenGetter{token: testJoinToken}, - kms: stubKeyGetter{dataKey: testKey}, + kms: stubKeyGetter{dataKeys: make(map[string][]byte), getDataKeyErr: someErr}, ca: stubCA{cert: testCert}, wantErr: true, }, "GetJoinToken fails": { kubeadm: stubTokenGetter{getJoinTokenErr: someErr}, - kms: stubKeyGetter{dataKey: testKey}, + kms: stubKeyGetter{dataKeys: map[string][]byte{ + uuid: testKey, + attestation.MeasurementSecretContext: measurementSecret, + }}, ca: stubCA{cert: testCert}, - id: mustMarshalID(testID), wantErr: true, }, "GetCertificate fails": { kubeadm: stubTokenGetter{token: testJoinToken}, - kms: stubKeyGetter{dataKey: testKey}, + kms: stubKeyGetter{dataKeys: map[string][]byte{ + uuid: testKey, + attestation.MeasurementSecretContext: measurementSecret, + }}, ca: stubCA{getCertErr: someErr}, - id: mustMarshalID(testID), wantErr: true, }, "control plane": { @@ -94,17 +83,21 @@ func TestIssueJoinTicket(t *testing.T) { token: testJoinToken, files: map[string][]byte{"test": {0x1, 0x2, 0x3}}, }, - kms: stubKeyGetter{dataKey: testKey}, - ca: stubCA{cert: testCert}, - id: mustMarshalID(testID), + kms: stubKeyGetter{dataKeys: map[string][]byte{ + uuid: testKey, + attestation.MeasurementSecretContext: measurementSecret, + }}, + ca: stubCA{cert: testCert}, }, "GetControlPlaneCertificateKey fails": { isControlPlane: true, kubeadm: stubTokenGetter{token: testJoinToken, certificateKeyErr: someErr}, - kms: stubKeyGetter{dataKey: testKey}, - ca: stubCA{cert: testCert}, - id: mustMarshalID(testID), - wantErr: true, + kms: stubKeyGetter{dataKeys: map[string][]byte{ + uuid: testKey, + attestation.MeasurementSecretContext: measurementSecret, + }}, + ca: stubCA{cert: testCert}, + wantErr: true, }, } @@ -113,20 +106,18 @@ func TestIssueJoinTicket(t *testing.T) { assert := assert.New(t) require := require.New(t) - file := file.NewHandler(afero.NewMemMapFs()) - if len(tc.id) > 0 { - require.NoError(file.Write(filepath.Join(constants.ServiceBasePath, constants.IDFilename), tc.id, 0o644)) - } - + handler := file.NewHandler(afero.NewMemMapFs()) // IssueJoinTicket tries to read the k8s-version ConfigMap from a mounted file. - require.NoError(file.Write(filepath.Join(constants.ServiceBasePath, constants.K8sVersion), []byte(testK8sVersion), 0o644)) + require.NoError(handler.Write(filepath.Join(constants.ServiceBasePath, constants.K8sVersion), []byte(testK8sVersion), file.OptNone)) + salt := []byte{0xA, 0xB, 0xC} api := New( - logger.NewTest(t), - file, + salt, + handler, tc.ca, tc.kubeadm, tc.kms, + logger.NewTest(t), ) req := &joinproto.IssueJoinTicketRequest{ @@ -139,13 +130,10 @@ func TestIssueJoinTicket(t *testing.T) { return } - var expectedIDs attestationtypes.ID - require.NoError(json.Unmarshal(tc.id, &expectedIDs)) - require.NoError(err) - assert.Equal(tc.kms.dataKey, resp.StateDiskKey) - assert.Equal(expectedIDs.Cluster, resp.ClusterId) - assert.Equal(expectedIDs.Owner, resp.OwnerId) + assert.Equal(tc.kms.dataKeys[uuid], resp.StateDiskKey) + assert.Equal(salt, resp.MeasurementSalt) + assert.Equal(tc.kms.dataKeys[attestation.MeasurementSecretContext], resp.MeasurementSecret) assert.Equal(tc.kubeadm.token.APIServerEndpoint, resp.ApiServerEndpoint) assert.Equal(tc.kubeadm.token.CACertHashes[0], resp.DiscoveryTokenCaCertHash) assert.Equal(tc.kubeadm.token.Token, resp.Token) @@ -158,12 +146,58 @@ func TestIssueJoinTicket(t *testing.T) { } } -func mustMarshalID(id attestationtypes.ID) []byte { - b, err := json.Marshal(id) - if err != nil { - panic(err) +func TestIssueRejoinTicker(t *testing.T) { + uuid := "uuid" + + testCases := map[string]struct { + keyGetter stubKeyGetter + wantErr bool + }{ + "success": { + keyGetter: stubKeyGetter{ + dataKeys: map[string][]byte{ + uuid: {0x1, 0x2, 0x3}, + attestation.MeasurementSecretContext: {0x4, 0x5, 0x6}, + }, + }, + }, + "failure": { + keyGetter: stubKeyGetter{ + dataKeys: make(map[string][]byte), + getDataKeyErr: errors.New("error"), + }, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + api := New( + nil, + file.Handler{}, + stubCA{}, + stubTokenGetter{}, + tc.keyGetter, + logger.NewTest(t), + ) + + req := &joinproto.IssueRejoinTicketRequest{ + DiskUuid: uuid, + } + resp, err := api.IssueRejoinTicket(context.Background(), req) + if tc.wantErr { + assert.Error(err) + return + } + + require.NoError(err) + assert.Equal(tc.keyGetter.dataKeys[attestation.MeasurementSecretContext], resp.MeasurementSecret) + assert.Equal(tc.keyGetter.dataKeys[uuid], resp.StateDiskKey) + }) } - return b } type stubTokenGetter struct { @@ -182,12 +216,12 @@ func (f stubTokenGetter) GetControlPlaneCertificatesAndKeys() (map[string][]byte } type stubKeyGetter struct { - dataKey []byte + dataKeys map[string][]byte getDataKeyErr error } -func (f stubKeyGetter) GetDataKey(context.Context, string, int) ([]byte, error) { - return f.dataKey, f.getDataKeyErr +func (f stubKeyGetter) GetDataKey(_ context.Context, name string, _ int) ([]byte, error) { + return f.dataKeys[name], f.getDataKeyErr } type stubCA struct { diff --git a/joinservice/joinproto/join.pb.go b/joinservice/joinproto/join.pb.go index 3782fd06b..2f3b5b4e1 100644 --- a/joinservice/joinproto/join.pb.go +++ b/joinservice/joinproto/join.pb.go @@ -89,8 +89,8 @@ type IssueJoinTicketResponse struct { unknownFields protoimpl.UnknownFields StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"` - OwnerId []byte `protobuf:"bytes,2,opt,name=owner_id,json=ownerId,proto3" json:"owner_id,omitempty"` - ClusterId []byte `protobuf:"bytes,3,opt,name=cluster_id,json=clusterId,proto3" json:"cluster_id,omitempty"` + MeasurementSalt []byte `protobuf:"bytes,2,opt,name=measurement_salt,json=measurementSalt,proto3" json:"measurement_salt,omitempty"` + MeasurementSecret []byte `protobuf:"bytes,3,opt,name=measurement_secret,json=measurementSecret,proto3" json:"measurement_secret,omitempty"` KubeletCert []byte `protobuf:"bytes,4,opt,name=kubelet_cert,json=kubeletCert,proto3" json:"kubelet_cert,omitempty"` ApiServerEndpoint string `protobuf:"bytes,5,opt,name=api_server_endpoint,json=apiServerEndpoint,proto3" json:"api_server_endpoint,omitempty"` Token string `protobuf:"bytes,6,opt,name=token,proto3" json:"token,omitempty"` @@ -138,16 +138,16 @@ func (x *IssueJoinTicketResponse) GetStateDiskKey() []byte { return nil } -func (x *IssueJoinTicketResponse) GetOwnerId() []byte { +func (x *IssueJoinTicketResponse) GetMeasurementSalt() []byte { if x != nil { - return x.OwnerId + return x.MeasurementSalt } return nil } -func (x *IssueJoinTicketResponse) GetClusterId() []byte { +func (x *IssueJoinTicketResponse) GetMeasurementSecret() []byte { if x != nil { - return x.ClusterId + return x.MeasurementSecret } return nil } @@ -249,6 +249,108 @@ func (x *ControlPlaneCertOrKey) GetData() []byte { return nil } +type IssueRejoinTicketRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"` +} + +func (x *IssueRejoinTicketRequest) Reset() { + *x = IssueRejoinTicketRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_join_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IssueRejoinTicketRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IssueRejoinTicketRequest) ProtoMessage() {} + +func (x *IssueRejoinTicketRequest) ProtoReflect() protoreflect.Message { + mi := &file_join_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IssueRejoinTicketRequest.ProtoReflect.Descriptor instead. +func (*IssueRejoinTicketRequest) Descriptor() ([]byte, []int) { + return file_join_proto_rawDescGZIP(), []int{3} +} + +func (x *IssueRejoinTicketRequest) GetDiskUuid() string { + if x != nil { + return x.DiskUuid + } + return "" +} + +type IssueRejoinTicketResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"` + MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3" json:"measurement_secret,omitempty"` +} + +func (x *IssueRejoinTicketResponse) Reset() { + *x = IssueRejoinTicketResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_join_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IssueRejoinTicketResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IssueRejoinTicketResponse) ProtoMessage() {} + +func (x *IssueRejoinTicketResponse) ProtoReflect() protoreflect.Message { + mi := &file_join_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IssueRejoinTicketResponse.ProtoReflect.Descriptor instead. +func (*IssueRejoinTicketResponse) Descriptor() ([]byte, []int) { + return file_join_proto_rawDescGZIP(), []int{4} +} + +func (x *IssueRejoinTicketResponse) GetStateDiskKey() []byte { + if x != nil { + return x.StateDiskKey + } + return nil +} + +func (x *IssueRejoinTicketResponse) GetMeasurementSecret() []byte { + if x != nil { + return x.MeasurementSecret + } + return nil +} + var File_join_proto protoreflect.FileDescriptor var file_join_proto_rawDesc = []byte{ @@ -262,15 +364,17 @@ var file_join_proto_rawDesc = []byte{ 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x28, 0x0a, 0x10, 0x69, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x5f, 0x70, 0x6c, 0x61, 0x6e, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x69, 0x73, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, - 0x50, 0x6c, 0x61, 0x6e, 0x65, 0x22, 0xa2, 0x03, 0x0a, 0x17, 0x49, 0x73, 0x73, 0x75, 0x65, 0x4a, + 0x50, 0x6c, 0x61, 0x6e, 0x65, 0x22, 0xc2, 0x03, 0x0a, 0x17, 0x49, 0x73, 0x73, 0x75, 0x65, 0x4a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, - 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x77, 0x6e, 0x65, 0x72, - 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6f, 0x77, 0x6e, 0x65, 0x72, - 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x5f, 0x69, 0x64, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x49, - 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x6b, 0x75, 0x62, 0x65, 0x6c, 0x65, 0x74, 0x5f, 0x63, 0x65, 0x72, + 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x29, 0x0a, 0x10, 0x6d, 0x65, 0x61, 0x73, 0x75, + 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x0f, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x61, + 0x6c, 0x74, 0x12, 0x2d, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, + 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, + 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x6b, 0x75, 0x62, 0x65, 0x6c, 0x65, 0x74, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0b, 0x6b, 0x75, 0x62, 0x65, 0x6c, 0x65, 0x74, 0x43, 0x65, 0x72, 0x74, 0x12, 0x2e, 0x0a, 0x13, 0x61, 0x70, 0x69, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, @@ -292,17 +396,33 @@ var file_join_proto_rawDesc = []byte{ 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x5f, 0x70, 0x6c, 0x61, 0x6e, 0x65, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x6f, 0x72, 0x5f, 0x6b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x64, - 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x32, - 0x55, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4e, 0x0a, 0x0f, 0x49, 0x73, 0x73, 0x75, 0x65, 0x4a, - 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1c, 0x2e, 0x6a, 0x6f, 0x69, 0x6e, - 0x2e, 0x49, 0x73, 0x73, 0x75, 0x65, 0x4a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x6a, 0x6f, 0x69, 0x6e, 0x2e, 0x49, - 0x73, 0x73, 0x75, 0x65, 0x4a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73, 0x79, 0x73, - 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x6a, - 0x6f, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x6a, 0x6f, 0x69, 0x6e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, + 0x37, 0x0a, 0x18, 0x49, 0x73, 0x73, 0x75, 0x65, 0x52, 0x65, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x69, + 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, + 0x69, 0x73, 0x6b, 0x5f, 0x75, 0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x64, 0x69, 0x73, 0x6b, 0x55, 0x75, 0x69, 0x64, 0x22, 0x70, 0x0a, 0x19, 0x49, 0x73, 0x73, 0x75, + 0x65, 0x52, 0x65, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, + 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2d, 0x0a, 0x12, 0x6d, + 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x32, 0xab, 0x01, 0x0a, 0x03, 0x41, + 0x50, 0x49, 0x12, 0x4e, 0x0a, 0x0f, 0x49, 0x73, 0x73, 0x75, 0x65, 0x4a, 0x6f, 0x69, 0x6e, 0x54, + 0x69, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1c, 0x2e, 0x6a, 0x6f, 0x69, 0x6e, 0x2e, 0x49, 0x73, 0x73, + 0x75, 0x65, 0x4a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x6a, 0x6f, 0x69, 0x6e, 0x2e, 0x49, 0x73, 0x73, 0x75, 0x65, + 0x4a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x54, 0x0a, 0x11, 0x49, 0x73, 0x73, 0x75, 0x65, 0x52, 0x65, 0x6a, 0x6f, 0x69, + 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1e, 0x2e, 0x6a, 0x6f, 0x69, 0x6e, 0x2e, 0x49, + 0x73, 0x73, 0x75, 0x65, 0x52, 0x65, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x6a, 0x6f, 0x69, 0x6e, 0x2e, 0x49, + 0x73, 0x73, 0x75, 0x65, 0x52, 0x65, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x69, 0x63, 0x6b, 0x65, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x3c, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73, + 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x2f, 0x6a, 0x6f, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x6a, 0x6f, 0x69, + 0x6e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -317,18 +437,22 @@ func file_join_proto_rawDescGZIP() []byte { return file_join_proto_rawDescData } -var file_join_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_join_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_join_proto_goTypes = []interface{}{ - (*IssueJoinTicketRequest)(nil), // 0: join.IssueJoinTicketRequest - (*IssueJoinTicketResponse)(nil), // 1: join.IssueJoinTicketResponse - (*ControlPlaneCertOrKey)(nil), // 2: join.control_plane_cert_or_key + (*IssueJoinTicketRequest)(nil), // 0: join.IssueJoinTicketRequest + (*IssueJoinTicketResponse)(nil), // 1: join.IssueJoinTicketResponse + (*ControlPlaneCertOrKey)(nil), // 2: join.control_plane_cert_or_key + (*IssueRejoinTicketRequest)(nil), // 3: join.IssueRejoinTicketRequest + (*IssueRejoinTicketResponse)(nil), // 4: join.IssueRejoinTicketResponse } var file_join_proto_depIdxs = []int32{ 2, // 0: join.IssueJoinTicketResponse.control_plane_files:type_name -> join.control_plane_cert_or_key 0, // 1: join.API.IssueJoinTicket:input_type -> join.IssueJoinTicketRequest - 1, // 2: join.API.IssueJoinTicket:output_type -> join.IssueJoinTicketResponse - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type + 3, // 2: join.API.IssueRejoinTicket:input_type -> join.IssueRejoinTicketRequest + 1, // 3: join.API.IssueJoinTicket:output_type -> join.IssueJoinTicketResponse + 4, // 4: join.API.IssueRejoinTicket:output_type -> join.IssueRejoinTicketResponse + 3, // [3:5] is the sub-list for method output_type + 1, // [1:3] is the sub-list for method input_type 1, // [1:1] is the sub-list for extension type_name 1, // [1:1] is the sub-list for extension extendee 0, // [0:1] is the sub-list for field type_name @@ -376,6 +500,30 @@ func file_join_proto_init() { return nil } } + file_join_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IssueRejoinTicketRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_join_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*IssueRejoinTicketResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -383,7 +531,7 @@ func file_join_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_join_proto_rawDesc, NumEnums: 0, - NumMessages: 3, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/joinservice/joinproto/join.proto b/joinservice/joinproto/join.proto index 8c97f4c12..589b1322a 100644 --- a/joinservice/joinproto/join.proto +++ b/joinservice/joinproto/join.proto @@ -6,6 +6,7 @@ option go_package = "github.com/edgelesssys/constellation/joinservice/joinproto" service API { rpc IssueJoinTicket(IssueJoinTicketRequest) returns (IssueJoinTicketResponse); + rpc IssueRejoinTicket(IssueRejoinTicketRequest) returns (IssueRejoinTicketResponse); } @@ -17,8 +18,8 @@ message IssueJoinTicketRequest { message IssueJoinTicketResponse { bytes state_disk_key = 1; - bytes owner_id = 2; - bytes cluster_id = 3; + bytes measurement_salt = 2; + bytes measurement_secret = 3; bytes kubelet_cert = 4; string api_server_endpoint = 5; string token = 6; @@ -31,3 +32,12 @@ message control_plane_cert_or_key { string name = 1; bytes data = 2; } + +message IssueRejoinTicketRequest { + string disk_uuid = 1; +} + +message IssueRejoinTicketResponse { + bytes state_disk_key = 1; + bytes measurement_secret = 2; +} diff --git a/joinservice/joinproto/join_grpc.pb.go b/joinservice/joinproto/join_grpc.pb.go index 1463f7415..9d8d597c5 100644 --- a/joinservice/joinproto/join_grpc.pb.go +++ b/joinservice/joinproto/join_grpc.pb.go @@ -23,6 +23,7 @@ const _ = grpc.SupportPackageIsVersion7 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type APIClient interface { IssueJoinTicket(ctx context.Context, in *IssueJoinTicketRequest, opts ...grpc.CallOption) (*IssueJoinTicketResponse, error) + IssueRejoinTicket(ctx context.Context, in *IssueRejoinTicketRequest, opts ...grpc.CallOption) (*IssueRejoinTicketResponse, error) } type aPIClient struct { @@ -42,11 +43,21 @@ func (c *aPIClient) IssueJoinTicket(ctx context.Context, in *IssueJoinTicketRequ return out, nil } +func (c *aPIClient) IssueRejoinTicket(ctx context.Context, in *IssueRejoinTicketRequest, opts ...grpc.CallOption) (*IssueRejoinTicketResponse, error) { + out := new(IssueRejoinTicketResponse) + err := c.cc.Invoke(ctx, "/join.API/IssueRejoinTicket", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // APIServer is the server API for API service. // All implementations must embed UnimplementedAPIServer // for forward compatibility type APIServer interface { IssueJoinTicket(context.Context, *IssueJoinTicketRequest) (*IssueJoinTicketResponse, error) + IssueRejoinTicket(context.Context, *IssueRejoinTicketRequest) (*IssueRejoinTicketResponse, error) mustEmbedUnimplementedAPIServer() } @@ -57,6 +68,9 @@ type UnimplementedAPIServer struct { func (UnimplementedAPIServer) IssueJoinTicket(context.Context, *IssueJoinTicketRequest) (*IssueJoinTicketResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method IssueJoinTicket not implemented") } +func (UnimplementedAPIServer) IssueRejoinTicket(context.Context, *IssueRejoinTicketRequest) (*IssueRejoinTicketResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method IssueRejoinTicket not implemented") +} func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {} // UnsafeAPIServer may be embedded to opt out of forward compatibility for this service. @@ -88,6 +102,24 @@ func _API_IssueJoinTicket_Handler(srv interface{}, ctx context.Context, dec func return interceptor(ctx, in, info, handler) } +func _API_IssueRejoinTicket_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(IssueRejoinTicketRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(APIServer).IssueRejoinTicket(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/join.API/IssueRejoinTicket", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(APIServer).IssueRejoinTicket(ctx, req.(*IssueRejoinTicketRequest)) + } + return interceptor(ctx, in, info, handler) +} + // API_ServiceDesc is the grpc.ServiceDesc for API service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -99,6 +131,10 @@ var API_ServiceDesc = grpc.ServiceDesc{ MethodName: "IssueJoinTicket", Handler: _API_IssueJoinTicket_Handler, }, + { + MethodName: "IssueRejoinTicket", + Handler: _API_IssueRejoinTicket_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "join.proto", diff --git a/kms/cmd/main.go b/kms/cmd/main.go index ca5665910..015ece26b 100644 --- a/kms/cmd/main.go +++ b/kms/cmd/main.go @@ -5,17 +5,14 @@ import ( "errors" "flag" "fmt" - "net" "path/filepath" "strconv" "time" - "github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" - "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/logger" - "github.com/edgelesssys/constellation/internal/watcher" "github.com/edgelesssys/constellation/kms/internal/server" "github.com/edgelesssys/constellation/kms/setup" "github.com/spf13/afero" @@ -24,24 +21,15 @@ import ( func main() { port := flag.String("port", strconv.Itoa(constants.KMSPort), "Port gRPC server listens on") - portATLS := flag.String("atls-port", strconv.Itoa(constants.KMSNodePort), "Port aTLS server listens on") - provider := flag.String("cloud-provider", "", "cloud service provider this binary is running on") masterSecretPath := flag.String("master-secret", filepath.Join(constants.ServiceBasePath, constants.MasterSecretFilename), "Path to the Constellation master secret") verbosity := flag.Int("v", 0, logger.CmdLineVerbosityDescription) flag.Parse() log := logger.New(logger.JSONLog, logger.VerbosityFromInt(*verbosity)) - log.With(zap.String("version", constants.VersionInfo), zap.String("cloudProvider", *provider)). + log.With(zap.String("version", constants.VersionInfo)). Infof("Constellation Key Management Service") - validator, err := watcher.NewValidator(log.Named("validator"), *provider, file.NewHandler(afero.NewOsFs())) - if err != nil { - flag.Usage() - log.With(zap.Error(err)).Fatalf("Failed to create validator") - } - creds := atlscredentials.New(nil, []atls.Validator{validator}) - // set up Key Management Service masterKey, err := readMainSecret(*masterSecretPath) if err != nil { @@ -58,32 +46,7 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to create KMS KEK from MasterKey") } - // set up listeners - atlsListener, err := net.Listen("tcp", net.JoinHostPort("", *portATLS)) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to listen on port %s", *portATLS) - } - plainListener, err := net.Listen("tcp", net.JoinHostPort("", *port)) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to listen on port %s", *port) - } - - // start the measurements file watcher - watcher, err := watcher.New(log.Named("fileWatcher"), validator) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to create watcher for measurements updates") - } - defer watcher.Close() - - go func() { - log.Infof("starting file watcher for measurements file %s", filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename)) - if err := watcher.Watch(filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename)); err != nil { - log.With(zap.Error(err)).Fatalf("Failed to watch measurements file") - } - }() - - // start the server - if err := server.New(log.Named("server"), conKMS).Run(atlsListener, plainListener, creds); err != nil { + if err := server.New(log.Named("kms"), conKMS).Run(*port); err != nil { log.With(zap.Error(err)).Fatalf("Failed to run KMS server") } } @@ -100,8 +63,8 @@ func readMainSecret(fileName string) ([]byte, error) { if err != nil { return nil, err } - if len(secretBytes) < constants.MasterSecretLengthMin { - return nil, fmt.Errorf("provided master secret is smaller than the required minimum of %d bytes", constants.MasterSecretLengthMin) + if len(secretBytes) < crypto.MasterSecretLengthMin { + return nil, fmt.Errorf("provided master secret is smaller than the required minimum of %d bytes", crypto.MasterSecretLengthMin) } return secretBytes, nil diff --git a/kms/internal/server/server.go b/kms/internal/server/server.go index 22f5cabd5..4ba92e0b2 100644 --- a/kms/internal/server/server.go +++ b/kms/internal/server/server.go @@ -3,10 +3,10 @@ package server import ( "context" + "fmt" "net" - "sync" - "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/grpc/grpclog" "github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/kms/kms" @@ -35,52 +35,21 @@ func New(log *logger.Logger, conKMS kms.CloudKMS) *Server { } } -// Run starts both the plain gRPC server and the aTLS gRPC server. -// If one of the servers fails, the other server will be closed and the error will be returned. -func (s *Server) Run(atlsListener, plainListener net.Listener, credentials *atlscredentials.Credentials) error { - var err error - var once sync.Once - var wg sync.WaitGroup - - atlsServer := grpc.NewServer( - grpc.Creds(credentials), - s.log.Named("gRPC.aTLS").GetServerUnaryInterceptor(), - ) - kmsproto.RegisterAPIServer(atlsServer, s) - - plainServer := grpc.NewServer(s.log.Named("gRPC.cluster").GetServerUnaryInterceptor()) - kmsproto.RegisterAPIServer(plainServer, s) +// Run starts the gRPC server. +func (s *Server) Run(port string) error { + // set up listener + listener, err := net.Listen("tcp", net.JoinHostPort("", port)) + if err != nil { + return fmt.Errorf("failed to listen on port %s: %v", port, err) + } + server := grpc.NewServer(s.log.Named("gRPC").GetServerUnaryInterceptor()) + kmsproto.RegisterAPIServer(server, s) s.log.Named("gRPC").WithIncreasedLevel(zapcore.WarnLevel).ReplaceGRPCLogger() - // start the plain gRPC server - wg.Add(1) - go func() { - defer wg.Done() - defer atlsServer.GracefulStop() - - s.log.Infof("Starting Constellation key management service on %s", plainListener.Addr().String()) - plainErr := plainServer.Serve(plainListener) - if plainErr != nil { - once.Do(func() { err = plainErr }) - } - }() - - // start the aTLS server - wg.Add(1) - go func() { - defer wg.Done() - defer plainServer.GracefulStop() - - s.log.Infof("Starting Constellation aTLS key management service on %s", atlsListener.Addr().String()) - atlsErr := atlsServer.Serve(atlsListener) - if atlsErr != nil { - once.Do(func() { err = atlsErr }) - } - }() - - wg.Wait() - return err + // start the server + s.log.Infof("Starting Constellation key management service on %s", listener.Addr().String()) + return server.Serve(listener) } // GetDataKey returns a data key. @@ -99,7 +68,7 @@ func (s *Server) GetDataKey(ctx context.Context, in *kmsproto.GetDataKeyRequest) return nil, status.Error(codes.InvalidArgument, "no data key ID specified") } - key, err := s.conKMS.GetDEK(ctx, "Constellation", "key-"+in.DataKeyId, int(in.Length)) + key, err := s.conKMS.GetDEK(ctx, "Constellation", crypto.HKDFInfoPrefix+in.DataKeyId, int(in.Length)) if err != nil { log.With(zap.Error(err)).Errorf("Failed to get data key") return nil, status.Errorf(codes.Internal, "%v", err) diff --git a/kms/internal/server/server_test.go b/kms/internal/server/server_test.go index 8fe35740e..b233f89df 100644 --- a/kms/internal/server/server_test.go +++ b/kms/internal/server/server_test.go @@ -3,12 +3,8 @@ package server import ( "context" "errors" - "net" - "sync" "testing" - "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" - "github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/kms/kmsproto" "github.com/stretchr/testify/assert" @@ -20,48 +16,6 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func TestRun(t *testing.T) { - assert := assert.New(t) - closeErr := errors.New("closed") - - var err error - var wg sync.WaitGroup - server := New(logger.NewTest(t), &stubKMS{}) - - creds := atlscredentials.New(nil, nil) - - atlsListener, plainListener := setUpTestListeners() - wg.Add(1) - go func() { - defer wg.Done() - err = server.Run(atlsListener, plainListener, creds) - }() - assert.NoError(plainListener.Close()) - wg.Wait() - assert.Equal(closeErr, err) - - atlsListener, plainListener = setUpTestListeners() - wg.Add(1) - go func() { - defer wg.Done() - err = server.Run(atlsListener, plainListener, creds) - }() - assert.NoError(atlsListener.Close()) - wg.Wait() - assert.Equal(closeErr, err) - - atlsListener, plainListener = setUpTestListeners() - wg.Add(1) - go func() { - defer wg.Done() - err = server.Run(atlsListener, plainListener, creds) - }() - go assert.NoError(atlsListener.Close()) - go assert.NoError(plainListener.Close()) - wg.Wait() - assert.Equal(closeErr, err) -} - func TestGetDataKey(t *testing.T) { assert := assert.New(t) require := require.New(t) @@ -92,12 +46,6 @@ func TestGetDataKey(t *testing.T) { assert.Nil(res) } -func setUpTestListeners() (net.Listener, net.Listener) { - atlsListener := testdialer.NewBufconnDialer().GetListener(net.JoinHostPort("192.0.2.1", "9001")) - plainListener := testdialer.NewBufconnDialer().GetListener(net.JoinHostPort("192.0.2.1", "9000")) - return atlsListener, plainListener -} - type stubKMS struct { masterKey []byte derivedKey []byte diff --git a/kms/kms/cluster/cluster.go b/kms/kms/cluster/cluster.go index 45c71800b..1f80ddf21 100644 --- a/kms/kms/cluster/cluster.go +++ b/kms/kms/cluster/cluster.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/edgelesssys/constellation/bootstrapper/util" + "github.com/edgelesssys/constellation/internal/crypto" ) // ClusterKMS implements the kms.CloudKMS interface for in cluster key management. @@ -24,5 +24,5 @@ func (c *ClusterKMS) GetDEK(ctx context.Context, kekID string, dekID string, dek return nil, errors.New("master key not set for Constellation KMS") } // TODO: Choose a way to salt key derivation - return util.DeriveKey(c.masterKey, []byte("Constellation"), []byte("key"+dekID), uint(dekSize)) + return crypto.DeriveKey(c.masterKey, []byte("Constellation"), []byte(dekID), uint(dekSize)) } diff --git a/mount/cryptmapper/cryptmapper.go b/mount/cryptmapper/cryptmapper.go index e52a0af8a..8d1816294 100644 --- a/mount/cryptmapper/cryptmapper.go +++ b/mount/cryptmapper/cryptmapper.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" cryptsetup "github.com/martinjungblut/go-cryptsetup" "k8s.io/klog/v2" mount "k8s.io/mount-utils" @@ -179,7 +179,7 @@ func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID stri // ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path. func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (string, error) { - dek, err := c.kms.GetDEK(ctx, volumeID, constants.StateDiskKeyLength) + dek, err := c.kms.GetDEK(ctx, volumeID, crypto.StateDiskKeyLength) if err != nil { return "", err } @@ -283,12 +283,12 @@ func openCryptDevice(ctx context.Context, device DeviceMapper, source, volumeID uuid := device.GetUUID() klog.V(4).Infof("Fetching data encryption key for volume %q", volumeID) - passphrase, err = getKey(ctx, uuid, constants.StateDiskKeyLength) + passphrase, err = getKey(ctx, uuid, crypto.StateDiskKeyLength) if err != nil { return "", err } - if len(passphrase) != constants.StateDiskKeyLength { - return "", fmt.Errorf("expected key length to be [%d] but got [%d]", constants.StateDiskKeyLength, len(passphrase)) + if len(passphrase) != crypto.StateDiskKeyLength { + return "", fmt.Errorf("expected key length to be [%d] but got [%d]", crypto.StateDiskKeyLength, len(passphrase)) } // Add a new keyslot using the internal volume key @@ -304,12 +304,12 @@ func openCryptDevice(ctx context.Context, device DeviceMapper, source, volumeID } else { uuid := device.GetUUID() klog.V(4).Infof("Fetching data encryption key for volume %q", volumeID) - passphrase, err = getKey(ctx, uuid, constants.StateDiskKeyLength) + passphrase, err = getKey(ctx, uuid, crypto.StateDiskKeyLength) if err != nil { return "", err } - if len(passphrase) != constants.StateDiskKeyLength { - return "", fmt.Errorf("expected key length to be [%d] but got [%d]", constants.StateDiskKeyLength, len(passphrase)) + if len(passphrase) != crypto.StateDiskKeyLength { + return "", fmt.Errorf("expected key length to be [%d] but got [%d]", crypto.StateDiskKeyLength, len(passphrase)) } } diff --git a/state/keyservice/keyproto/keyservice.pb.go b/state/keyservice/keyproto/keyservice.pb.go index aba0123cb..32c367c9a 100644 --- a/state/keyservice/keyproto/keyservice.pb.go +++ b/state/keyservice/keyproto/keyservice.pb.go @@ -25,7 +25,8 @@ type PushStateDiskKeyRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"` + StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"` + MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3" json:"measurement_secret,omitempty"` } func (x *PushStateDiskKeyRequest) Reset() { @@ -67,6 +68,13 @@ func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte { return nil } +func (x *PushStateDiskKeyRequest) GetMeasurementSecret() []byte { + if x != nil { + return x.MeasurementSecret + } + return nil +} + type PushStateDiskKeyResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -109,24 +117,27 @@ var File_keyservice_proto protoreflect.FileDescriptor var file_keyservice_proto_rawDesc = []byte{ 0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x3f, 0x0a, 0x17, + 0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x6e, 0x0a, 0x17, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x22, 0x1a, 0x0a, - 0x18, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49, - 0x12, 0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, - 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2d, 0x0a, + 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, + 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, + 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x18, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, - 0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, + 0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, + 0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x50, + 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, + 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, + 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/state/keyservice/keyproto/keyservice.proto b/state/keyservice/keyproto/keyservice.proto index 45befba71..d6b718ae4 100644 --- a/state/keyservice/keyproto/keyservice.proto +++ b/state/keyservice/keyproto/keyservice.proto @@ -10,6 +10,7 @@ service API { message PushStateDiskKeyRequest { bytes state_disk_key = 1; + bytes measurement_secret = 2; } message PushStateDiskKeyResponse { diff --git a/state/keyservice/keyservice.go b/state/keyservice/keyservice.go index 3ddbcc008..a46c6307b 100644 --- a/state/keyservice/keyservice.go +++ b/state/keyservice/keyservice.go @@ -8,11 +8,11 @@ import ( "time" "github.com/edgelesssys/constellation/internal/cloud/metadata" - "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/oid" - "github.com/edgelesssys/constellation/kms/kmsproto" + "github.com/edgelesssys/constellation/joinservice/joinproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto" "go.uber.org/zap" "google.golang.org/grpc" @@ -23,14 +23,15 @@ import ( // KeyAPI is the interface called by control-plane or an admin during restart of a node. type KeyAPI struct { - listenAddr string - log *logger.Logger - mux sync.Mutex - metadata metadata.InstanceLister - issuer QuoteIssuer - key []byte - keyReceived chan struct{} - timeout time.Duration + listenAddr string + log *logger.Logger + mux sync.Mutex + metadata metadata.InstanceLister + issuer QuoteIssuer + key []byte + measurementSecret []byte + keyReceived chan struct{} + timeout time.Duration keyproto.UnimplementedAPIServer } @@ -52,19 +53,23 @@ func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDis if len(a.key) != 0 { return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase") } - if len(in.StateDiskKey) != constants.RNGLengthDefault { - return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", constants.RNGLengthDefault, len(in.StateDiskKey)) + if len(in.StateDiskKey) != crypto.StateDiskKeyLength { + return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", crypto.StateDiskKeyLength, len(in.StateDiskKey)) + } + if len(in.MeasurementSecret) != crypto.RNGLengthDefault { + return nil, status.Errorf(codes.InvalidArgument, "received invalid measurement secret: expected length: %d, but got: %d", crypto.RNGLengthDefault, len(in.MeasurementSecret)) } a.key = in.StateDiskKey + a.measurementSecret = in.MeasurementSecret a.keyReceived <- struct{}{} return &keyproto.PushStateDiskKeyResponse{}, nil } // WaitForDecryptionKey notifies control-plane nodes to send a decryption key and waits until a key is received. -func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) { +func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) (diskKey, measurementSecret []byte, err error) { if uuid == "" { - return nil, errors.New("received no disk UUID") + return nil, nil, errors.New("received no disk UUID") } a.listenAddr = listenAddr creds := atlscredentials.New(a.issuer, nil) @@ -72,7 +77,7 @@ func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) { keyproto.RegisterAPIServer(server, a) listener, err := net.Listen("tcp", listenAddr) if err != nil { - return nil, err + return nil, nil, err } defer listener.Close() @@ -81,7 +86,7 @@ func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) { defer server.GracefulStop() a.requestKeyLoop(uuid) - return a.key, nil + return a.key, a.measurementSecret, nil } // ResetKey resets a previously set key. @@ -118,7 +123,7 @@ func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) { func (a *KeyAPI) requestKey(uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) { // list available control-plane nodes - endpoints, _ := metadata.KMSEndpoints(context.Background(), a.metadata) + endpoints, _ := metadata.JoinServiceEndpoints(context.Background(), a.metadata) a.log.With(zap.Strings("endpoints", endpoints)).Infof("Sending a key request to available control-plane nodes") // notify all available control-plane nodes to send a key to the node @@ -126,24 +131,31 @@ func (a *KeyAPI) requestKey(uuid string, credentials credentials.TransportCreden for _, endpoint := range endpoints { ctx, cancel := context.WithTimeout(context.Background(), a.timeout) defer cancel() + + // request rejoin ticket from JoinService conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials))...) if err != nil { continue } defer conn.Close() - client := kmsproto.NewAPIClient(conn) - response, err := client.GetDataKey(ctx, &kmsproto.GetDataKeyRequest{DataKeyId: uuid, Length: constants.StateDiskKeyLength}) + client := joinproto.NewAPIClient(conn) + response, err := client.IssueRejoinTicket(ctx, &joinproto.IssueRejoinTicketRequest{DiskUuid: uuid}) if err != nil { a.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Warnf("Failed to request key") continue } + + // push key to own gRPC server pushKeyConn, err := grpc.DialContext(ctx, a.listenAddr, append(opts, grpc.WithTransportCredentials(credentials))...) if err != nil { continue } defer pushKeyConn.Close() pushKeyClient := keyproto.NewAPIClient(pushKeyConn) - if _, err := pushKeyClient.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{StateDiskKey: response.DataKey}); err != nil { + if _, err := pushKeyClient.PushStateDiskKey( + ctx, + &keyproto.PushStateDiskKeyRequest{StateDiskKey: response.StateDiskKey, MeasurementSecret: response.MeasurementSecret}, + ); err != nil { a.log.With(zap.Error(err), zap.String("endpoint", a.listenAddr)).Errorf("Failed to push key") continue } diff --git a/state/keyservice/keyservice_test.go b/state/keyservice/keyservice_test.go index 210521d5a..208cc9246 100644 --- a/state/keyservice/keyservice_test.go +++ b/state/keyservice/keyservice_test.go @@ -125,19 +125,24 @@ func TestPushStateDiskKey(t *testing.T) { }{ "success": { testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)}, - request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")}, }, "key already set": { testAPI: &KeyAPI{ keyReceived: make(chan struct{}, 1), key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), }, - request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")}, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"), MeasurementSecret: []byte("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")}, wantErr: true, }, "incorrect size of pushed key": { testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)}, - request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA")}, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")}, + wantErr: true, + }, + "incorrect size of measurement secret": { + testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)}, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBB")}, wantErr: true, }, } diff --git a/state/setup/interface.go b/state/setup/interface.go index d88838dd0..ec820753d 100644 --- a/state/setup/interface.go +++ b/state/setup/interface.go @@ -22,7 +22,7 @@ type DeviceMapper interface { // KeyWaiter is an interface to request and wait for disk decryption keys. type KeyWaiter interface { - WaitForDecryptionKey(uuid, addr string) ([]byte, error) + WaitForDecryptionKey(uuid, addr string) (diskKey, measurementSecret []byte, err error) ResetKey() } diff --git a/state/setup/setup.go b/state/setup/setup.go index d4c4696e8..058e69e3d 100644 --- a/state/setup/setup.go +++ b/state/setup/setup.go @@ -9,8 +9,9 @@ import ( "syscall" "github.com/edgelesssys/constellation/bootstrapper/nodestate" + "github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/attestation/vtpm" - "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/logger" "github.com/spf13/afero" @@ -57,7 +58,7 @@ func (s *SetupManager) PrepareExistingDisk() error { uuid := s.mapper.DiskUUID() getKey: - passphrase, err := s.keyWaiter.WaitForDecryptionKey(uuid, net.JoinHostPort("0.0.0.0", RecoveryPort)) + passphrase, measurementSecret, err := s.keyWaiter.WaitForDecryptionKey(uuid, net.JoinHostPort("0.0.0.0", RecoveryPort)) if err != nil { return err } @@ -77,13 +78,17 @@ getKey: return err } - ownerID, clusterID, err := s.readInitSecrets(stateInfoPath) + measurementSalt, err := s.readMeasurementSalt(stateInfoPath) + if err != nil { + return err + } + clusterID, err := attestation.DeriveClusterID(measurementSalt, measurementSecret) if err != nil { return err } // taint the node as initialized - if err := vtpm.MarkNodeAsBootstrapped(s.openTPM, ownerID, clusterID); err != nil { + if err := vtpm.MarkNodeAsBootstrapped(s.openTPM, clusterID); err != nil { return err } @@ -99,7 +104,7 @@ func (s *SetupManager) PrepareNewDisk() error { return err } - passphrase := make([]byte, constants.RNGLengthDefault) + passphrase := make([]byte, crypto.RNGLengthDefault) if _, err := rand.Read(passphrase); err != nil { return err } @@ -114,16 +119,16 @@ func (s *SetupManager) PrepareNewDisk() error { return s.mapper.MapDisk(stateDiskMappedName, string(passphrase)) } -func (s *SetupManager) readInitSecrets(path string) ([]byte, []byte, error) { +func (s *SetupManager) readMeasurementSalt(path string) ([]byte, error) { handler := file.NewHandler(s.fs) var state nodestate.NodeState if err := handler.ReadJSON(path, &state); err != nil { - return nil, nil, err + return nil, err } - if len(state.ClusterID) == 0 || len(state.OwnerID) == 0 { - return nil, nil, errors.New("missing state information to retaint node") + if len(state.MeasurementSalt) != crypto.RNGLengthDefault { + return nil, errors.New("missing state information to retaint node") } - return state.OwnerID, state.ClusterID, nil + return state.MeasurementSalt, nil } diff --git a/state/setup/setup_test.go b/state/setup/setup_test.go index 89b6fabc9..a2cf0f658 100644 --- a/state/setup/setup_test.go +++ b/state/setup/setup_test.go @@ -9,7 +9,7 @@ import ( "github.com/edgelesssys/constellation/bootstrapper/nodestate" "github.com/edgelesssys/constellation/internal/attestation/vtpm" - "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/crypto" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/logger" "github.com/spf13/afero" @@ -108,9 +108,10 @@ func TestPrepareExistingDisk(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) + salt := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") if !tc.missingState { handler := file.NewHandler(tc.fs) - require.NoError(t, handler.WriteJSON(stateInfoPath, nodestate.NodeState{OwnerID: []byte("ownerID"), ClusterID: []byte("clusterID")}, file.OptMkdirAll)) + require.NoError(t, handler.WriteJSON(stateInfoPath, nodestate.NodeState{MeasurementSalt: salt}, file.OptMkdirAll)) } setupManager := New( @@ -193,43 +194,30 @@ func TestPrepareNewDisk(t *testing.T) { data, err := tc.fs.ReadFile(filepath.Join(keyPath, keyFile)) require.NoError(t, err) - assert.Len(data, constants.RNGLengthDefault) + assert.Len(data, crypto.RNGLengthDefault) } }) } } -func TestReadInitSecrets(t *testing.T) { +func TestReadMeasurementSalt(t *testing.T) { + salt := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") testCases := map[string]struct { fs afero.Afero - ownerID string - clusterID string + salt []byte writeFile bool wantErr bool }{ "success": { fs: afero.Afero{Fs: afero.NewMemMapFs()}, - ownerID: "ownerID", - clusterID: "clusterID", + salt: salt, writeFile: true, }, "no state file": { fs: afero.Afero{Fs: afero.NewMemMapFs()}, wantErr: true, }, - "missing ownerID": { - fs: afero.Afero{Fs: afero.NewMemMapFs()}, - clusterID: "clusterID", - writeFile: true, - wantErr: true, - }, - "missing clusterID": { - fs: afero.Afero{Fs: afero.NewMemMapFs()}, - ownerID: "ownerID", - writeFile: true, - wantErr: true, - }, - "no IDs": { + "missing salt": { fs: afero.Afero{Fs: afero.NewMemMapFs()}, writeFile: true, wantErr: true, @@ -243,19 +231,18 @@ func TestReadInitSecrets(t *testing.T) { if tc.writeFile { handler := file.NewHandler(tc.fs) - state := nodestate.NodeState{ClusterID: []byte(tc.clusterID), OwnerID: []byte(tc.ownerID)} - require.NoError(handler.WriteJSON("/tmp/test-state.json", state, file.OptMkdirAll)) + state := nodestate.NodeState{MeasurementSalt: tc.salt} + require.NoError(handler.WriteJSON("test-state.json", state, file.OptMkdirAll)) } setupManager := New(logger.NewTest(t), "test", tc.fs, nil, nil, nil, nil) - ownerID, clusterID, err := setupManager.readInitSecrets("/tmp/test-state.json") + measurementSalt, err := setupManager.readMeasurementSalt("test-state.json") if tc.wantErr { assert.Error(err) } else { assert.NoError(err) - assert.Equal([]byte(tc.ownerID), ownerID) - assert.Equal([]byte(tc.clusterID), clusterID) + assert.Equal(tc.salt, measurementSalt) } }) } @@ -311,19 +298,20 @@ func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error { } type stubKeyWaiter struct { - receivedUUID string - decryptionKey []byte - waitErr error - waitCalled bool + receivedUUID string + decryptionKey []byte + measurementSecret []byte + waitErr error + waitCalled bool } -func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, error) { +func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, []byte, error) { if s.waitCalled { - return nil, errors.New("wait called before key was reset") + return nil, nil, errors.New("wait called before key was reset") } s.waitCalled = true s.receivedUUID = uuid - return s.decryptionKey, s.waitErr + return s.decryptionKey, s.measurementSecret, s.waitErr } func (s *stubKeyWaiter) ResetKey() { diff --git a/state/test/integration_test.go b/state/test/integration_test.go index 770901344..43b8754d0 100644 --- a/state/test/integration_test.go +++ b/state/test/integration_test.go @@ -86,6 +86,7 @@ func TestKeyAPI(t *testing.T) { assert := assert.New(t) testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + testSecret := []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBB") // get a free port on localhost to run the test on listener, err := net.Listen("tcp", "localhost:0") @@ -114,14 +115,16 @@ func TestKeyAPI(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() _, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{ - StateDiskKey: testKey, + StateDiskKey: testKey, + MeasurementSecret: testSecret, }) require.NoError(err) }() - key, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr) + key, measurementSecret, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr) assert.NoError(err) assert.Equal(testKey, key) + assert.Equal(testSecret, measurementSecret) } type fakeMetadataAPI struct{}