AB#2260 Refactor disk-mapper recovery (#82)

* Refactor disk-mapper recovery

* Adapt constellation recover command to use new disk-mapper recovery API

* Fix Cilium connectivity on rebooting nodes (#89)

* Lower CoreDNS reschedule timeout to 10 seconds (#93)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-09-08 14:45:27 +02:00 committed by GitHub
parent a7b20b2a11
commit 8cb155d5c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1600 additions and 1130 deletions

View File

@ -106,6 +106,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Owner ID and Unique ID are merged into a single value: Cluster ID. - Owner ID and Unique ID are merged into a single value: Cluster ID.
- Streamline logging to only use one logging library, instead of multiple. - Streamline logging to only use one logging library, instead of multiple.
- Replace dependency on github.com/willdonnelly/passwd with own implementation. - Replace dependency on github.com/willdonnelly/passwd with own implementation.
- Refactor disk-mapper to allow a more streamlined node recovery
### Removed ### Removed

View File

@ -71,5 +71,5 @@ add_test(NAME unit-hack COMMAND go test -race -count=3 ./... WORKING_DIRECTORY $
add_test(NAME unit-node-operator COMMAND go test -race -count=3 ./... WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator) add_test(NAME unit-node-operator COMMAND go test -race -count=3 ./... WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator)
add_test(NAME integration-node-operator COMMAND make test WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator) add_test(NAME integration-node-operator COMMAND make test WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator)
add_test(NAME integration-csi COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csi) add_test(NAME integration-csi COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csi)
add_test(NAME integration-dm COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/state/internal) add_test(NAME integration-dm COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/disk-mapper/internal)
add_test(NAME integration-license COMMAND bash -c "go test -tags integration" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/internal/license) add_test(NAME integration-license COMMAND bash -c "go test -tags integration" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/internal/license)

View File

@ -61,7 +61,7 @@ Core components:
* [image](image): Build files for the Constellation disk image * [image](image): Build files for the Constellation disk image
* [kms](kms): Constellation's key management client and server * [kms](kms): Constellation's key management client and server
* [csi](csi): Package used by CSI plugins to create and mount encrypted block devices * [csi](csi): Package used by CSI plugins to create and mount encrypted block devices
* [state](state): Contains the disk-mapper that maps the encrypted node data disk during boot * [disk-mapper](disk-mapper): Contains the disk-mapper that maps the encrypted node data disk during boot
Development components: Development components:

View File

@ -36,7 +36,7 @@ ARG PROJECT_VERSION
RUN go build -o bootstrapper -tags=gcp,disable_tpm_simulator -buildvcs=false -ldflags "-s -w -buildid='' -X main.version=${PROJECT_VERSION}" ./cmd/bootstrapper/ RUN go build -o bootstrapper -tags=gcp,disable_tpm_simulator -buildvcs=false -ldflags "-s -w -buildid='' -X main.version=${PROJECT_VERSION}" ./cmd/bootstrapper/
FROM build AS build-disk-mapper FROM build AS build-disk-mapper
WORKDIR /constellation/state/ WORKDIR /constellation/disk-mapper/
RUN go build -o disk-mapper -ldflags "-s -w" ./cmd/ RUN go build -o disk-mapper -ldflags "-s -w" ./cmd/
@ -44,4 +44,4 @@ FROM scratch AS bootstrapper
COPY --from=build-bootstrapper /constellation/bootstrapper/bootstrapper / COPY --from=build-bootstrapper /constellation/bootstrapper/bootstrapper /
FROM scratch AS disk-mapper FROM scratch AS disk-mapper
COPY --from=build-disk-mapper /constellation/state/disk-mapper / COPY --from=build-disk-mapper /constellation/disk-mapper/disk-mapper /

View File

@ -50,7 +50,7 @@ func run(issuerWrapper initserver.IssuerWrapper, tpm vtpm.TPMOpenFunc, fileHandl
} }
if nodeBootstrapped { if nodeBootstrapped {
if err := kube.StartKubelet(); err != nil { if err := kube.StartKubelet(log); err != nil {
log.With(zap.Error(err)).Fatalf("Failed to restart kubelet") log.With(zap.Error(err)).Fatalf("Failed to restart kubelet")
} }
return return
@ -88,7 +88,7 @@ func getDiskUUID() (string, error) {
type clusterInitJoiner interface { type clusterInitJoiner interface {
joinclient.ClusterJoiner joinclient.ClusterJoiner
initserver.ClusterInitializer initserver.ClusterInitializer
StartKubelet() error StartKubelet(*logger.Logger) error
} }
type metadataAPI interface { type metadataAPI interface {

View File

@ -33,7 +33,7 @@ func (c *clusterFake) JoinCluster(context.Context, *kubeadm.BootstrapTokenDiscov
} }
// StartKubelet starts the kubelet service. // StartKubelet starts the kubelet service.
func (c *clusterFake) StartKubelet() error { func (c *clusterFake) StartKubelet(*logger.Logger) error {
return nil return nil
} }

View File

@ -13,30 +13,6 @@ import (
"github.com/coreos/go-systemd/v22/dbus" "github.com/coreos/go-systemd/v22/dbus"
) )
func restartSystemdUnit(ctx context.Context, unit string) error {
conn, err := dbus.NewSystemdConnectionContext(ctx)
if err != nil {
return fmt.Errorf("establishing systemd connection: %w", err)
}
restartChan := make(chan string)
if _, err := conn.RestartUnitContext(ctx, unit, "replace", restartChan); err != nil {
return fmt.Errorf("restarting systemd unit %q: %w", unit, err)
}
// Wait for the restart to finish and actually check if it was
// successful or not.
result := <-restartChan
switch result {
case "done":
return nil
default:
return fmt.Errorf("restarting systemd unit %q failed: expected %v but received %v", unit, "done", result)
}
}
func startSystemdUnit(ctx context.Context, unit string) error { func startSystemdUnit(ctx context.Context, unit string) error {
conn, err := dbus.NewSystemdConnectionContext(ctx) conn, err := dbus.NewSystemdConnectionContext(ctx)
if err != nil { if err != nil {

View File

@ -264,12 +264,19 @@ func (k *KubernetesUtil) deployCiliumGCP(ctx context.Context, helmClient *action
return err return err
} }
timeoutS := int64(10)
// allow coredns to run on uninitialized nodes (required by cloud-controller-manager) // allow coredns to run on uninitialized nodes (required by cloud-controller-manager)
tolerations := []corev1.Toleration{ tolerations := []corev1.Toleration{
{ {
Key: "node.cloudprovider.kubernetes.io/uninitialized", Key: "node.cloudprovider.kubernetes.io/uninitialized",
Value: "true", Value: "true",
Effect: "NoSchedule", Effect: corev1.TaintEffectNoSchedule,
},
{
Key: "node.kubernetes.io/unreachable",
Operator: corev1.TolerationOpExists,
Effect: corev1.TaintEffectNoExecute,
TolerationSeconds: &timeoutS,
}, },
} }
if err = kubectl.AddTolerationsToDeployment(ctx, tolerations, "coredns", "kube-system"); err != nil { if err = kubectl.AddTolerationsToDeployment(ctx, tolerations, "coredns", "kube-system"); err != nil {
@ -305,7 +312,7 @@ func (k *KubernetesUtil) deployCiliumGCP(ctx context.Context, helmClient *action
// FixCilium fixes https://github.com/cilium/cilium/issues/19958 but instead of a rollout restart of // FixCilium fixes https://github.com/cilium/cilium/issues/19958 but instead of a rollout restart of
// the cilium daemonset, it only restarts the local cilium pod. // the cilium daemonset, it only restarts the local cilium pod.
func (k *KubernetesUtil) FixCilium(nodeNameK8s string, log *logger.Logger) { func (k *KubernetesUtil) FixCilium(log *logger.Logger) {
// wait for cilium pod to be healthy // wait for cilium pod to be healthy
client := http.Client{} client := http.Client{}
for { for {
@ -487,13 +494,6 @@ func (k *KubernetesUtil) StartKubelet() error {
return startSystemdUnit(ctx, "kubelet.service") return startSystemdUnit(ctx, "kubelet.service")
} }
// RestartKubelet restarts a kubelet.
func (k *KubernetesUtil) RestartKubelet() error {
ctx, cancel := context.WithTimeout(context.TODO(), kubeletStartTimeout)
defer cancel()
return restartSystemdUnit(ctx, "kubelet.service")
}
// createSignedKubeletCert manually creates a Kubernetes CA signed kubelet certificate for the bootstrapper node. // createSignedKubeletCert manually creates a Kubernetes CA signed kubelet certificate for the bootstrapper node.
// This is necessary because this node does not request a certificate from the join service. // This is necessary because this node does not request a certificate from the join service.
func (k *KubernetesUtil) createSignedKubeletCert(nodeName string, ips []net.IP) error { func (k *KubernetesUtil) createSignedKubeletCert(nodeName string, ips []net.IP) error {

View File

@ -33,6 +33,5 @@ type clusterUtil interface {
SetupNodeMaintenanceOperator(kubectl k8sapi.Client, nodeMaintenanceOperatorConfiguration kubernetes.Marshaler) error SetupNodeMaintenanceOperator(kubectl k8sapi.Client, nodeMaintenanceOperatorConfiguration kubernetes.Marshaler) error
SetupNodeOperator(ctx context.Context, kubectl k8sapi.Client, nodeOperatorConfiguration kubernetes.Marshaler) error SetupNodeOperator(ctx context.Context, kubectl k8sapi.Client, nodeOperatorConfiguration kubernetes.Marshaler) error
StartKubelet() error StartKubelet() error
RestartKubelet() error FixCilium(log *logger.Logger)
FixCilium(nodeNameK8s string, log *logger.Logger)
} }

View File

@ -229,7 +229,7 @@ func (k *KubeWrapper) InitCluster(
return nil, fmt.Errorf("failed to setup k8s version ConfigMap: %w", err) return nil, fmt.Errorf("failed to setup k8s version ConfigMap: %w", err)
} }
k.clusterUtil.FixCilium(nodeName, log) k.clusterUtil.FixCilium(log)
return k.GetKubeconfig() return k.GetKubeconfig()
} }
@ -309,7 +309,7 @@ func (k *KubeWrapper) JoinCluster(ctx context.Context, args *kubeadm.BootstrapTo
return fmt.Errorf("joining cluster: %v; %w ", string(joinConfigYAML), err) return fmt.Errorf("joining cluster: %v; %w ", string(joinConfigYAML), err)
} }
k.clusterUtil.FixCilium(nodeName, log) k.clusterUtil.FixCilium(log)
return nil return nil
} }
@ -481,8 +481,13 @@ func k8sCompliantHostname(in string) string {
} }
// StartKubelet starts the kubelet service. // StartKubelet starts the kubelet service.
func (k *KubeWrapper) StartKubelet() error { func (k *KubeWrapper) StartKubelet(log *logger.Logger) error {
return k.clusterUtil.StartKubelet() if err := k.clusterUtil.StartKubelet(); err != nil {
return fmt.Errorf("starting kubelet: %w", err)
}
k.clusterUtil.FixCilium(log)
return nil
} }
// getIPAddr retrieves to default sender IP used for outgoing connection. // getIPAddr retrieves to default sender IP used for outgoing connection.

View File

@ -531,7 +531,6 @@ type stubClusterUtil struct {
setupNodeOperatorErr error setupNodeOperatorErr error
joinClusterErr error joinClusterErr error
startKubeletErr error startKubeletErr error
restartKubeletErr error
initConfigs [][]byte initConfigs [][]byte
joinConfigs [][]byte joinConfigs [][]byte
@ -603,11 +602,7 @@ func (s *stubClusterUtil) StartKubelet() error {
return s.startKubeletErr return s.startKubeletErr
} }
func (s *stubClusterUtil) RestartKubelet() error { func (s *stubClusterUtil) FixCilium(log *logger.Logger) {
return s.restartKubeletErr
}
func (s *stubClusterUtil) FixCilium(nodeName string, log *logger.Logger) {
} }
type stubConfigProvider struct { type stubConfigProvider struct {

View File

@ -7,24 +7,26 @@ SPDX-License-Identifier: AGPL-3.0-only
package cmd package cmd
import ( import (
"errors" "context"
"fmt" "fmt"
"regexp" "io"
"strings"
"github.com/edgelesssys/constellation/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/cli/internal/proto" "github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state" "github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var diskUUIDRegexp = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") type recoveryClient interface {
Connect(endpoint string, validators atls.Validator) error
Recover(ctx context.Context, masterSecret, salt []byte) error
io.Closer
}
// NewRecoverCmd returns a new cobra.Command for the recover command. // NewRecoverCmd returns a new cobra.Command for the recover command.
func NewRecoverCmd() *cobra.Command { func NewRecoverCmd() *cobra.Command {
@ -38,15 +40,13 @@ func NewRecoverCmd() *cobra.Command {
} }
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)") cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
must(cmd.MarkFlagRequired("endpoint")) must(cmd.MarkFlagRequired("endpoint"))
cmd.Flags().String("disk-uuid", "", "disk UUID of the encrypted state disk (required)")
must(cmd.MarkFlagRequired("disk-uuid"))
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file") cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file")
return cmd return cmd
} }
func runRecover(cmd *cobra.Command, _ []string) error { func runRecover(cmd *cobra.Command, _ []string) error {
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
recoveryClient := &proto.KeyClient{} recoveryClient := &proto.RecoverClient{}
defer recoveryClient.Close() defer recoveryClient.Close()
return recover(cmd, fileHandler, recoveryClient) return recover(cmd, fileHandler, recoveryClient)
} }
@ -82,17 +82,7 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recove
return err return err
} }
diskKey, err := deriveStateDiskKey(masterSecret.Key, masterSecret.Salt, flags.diskUUID) if err := recoveryClient.Recover(cmd.Context(), masterSecret.Key, masterSecret.Salt); err != nil {
if err != nil {
return err
}
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
if err != nil {
return err
}
if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey, measurementSecret); err != nil {
return err return err
} }
@ -110,15 +100,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err) return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
} }
diskUUID, err := cmd.Flags().GetString("disk-uuid")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing disk-uuid argument: %w", err)
}
if match := diskUUIDRegexp.MatchString(diskUUID); !match {
return recoverFlags{}, errors.New("flag '--disk-uuid' isn't a valid LUKS UUID")
}
diskUUID = strings.ToLower(diskUUID)
masterSecretPath, err := cmd.Flags().GetString("master-secret") masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil { if err != nil {
return recoverFlags{}, fmt.Errorf("parsing master-secret path argument: %w", err) return recoverFlags{}, fmt.Errorf("parsing master-secret path argument: %w", err)
@ -131,7 +112,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
return recoverFlags{ return recoverFlags{
endpoint: endpoint, endpoint: endpoint,
diskUUID: diskUUID,
secretPath: masterSecretPath, secretPath: masterSecretPath,
configPath: configPath, configPath: configPath,
}, nil }, nil
@ -139,12 +119,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
type recoverFlags struct { type recoverFlags struct {
endpoint string endpoint string
diskUUID string
secretPath string secretPath string
configPath string configPath string
} }
// deriveStateDiskKey derives a state disk key from a master key, a salt, and a disk UUID.
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength)
}

View File

@ -8,10 +8,11 @@ package cmd
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"strings"
"testing" "testing"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
@ -57,7 +58,6 @@ func TestRecover(t *testing.T) {
client *stubRecoveryClient client *stubRecoveryClient
masterSecret testvector.HKDF masterSecret testvector.HKDF
endpointFlag string endpointFlag string
diskUUIDFlag string
masterSecretFlag string masterSecretFlag string
configFlag string configFlag string
stateless bool stateless bool
@ -67,29 +67,13 @@ func TestRecover(t *testing.T) {
existingState: validState, existingState: validState,
client: &stubRecoveryClient{}, client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero, masterSecret: testvector.HKDFZero,
}, },
"uppercase disk uuid works": {
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: strings.ToUpper(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF,
},
"lowercase disk uuid results in same key": {
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: strings.ToLower(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF,
},
"missing flags": { "missing flags": {
wantErr: true, wantErr: true,
}, },
"missing config": { "missing config": {
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero, masterSecret: testvector.HKDFZero,
configFlag: "nonexistent-config", configFlag: "nonexistent-config",
wantErr: true, wantErr: true,
@ -97,7 +81,6 @@ func TestRecover(t *testing.T) {
"missing state": { "missing state": {
existingState: validState, existingState: validState,
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero, masterSecret: testvector.HKDFZero,
stateless: true, stateless: true,
wantErr: true, wantErr: true,
@ -105,7 +88,6 @@ func TestRecover(t *testing.T) {
"invalid cloud provider": { "invalid cloud provider": {
existingState: invalidCSPState, existingState: invalidCSPState,
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero, masterSecret: testvector.HKDFZero,
wantErr: true, wantErr: true,
}, },
@ -113,7 +95,6 @@ func TestRecover(t *testing.T) {
existingState: validState, existingState: validState,
client: &stubRecoveryClient{connectErr: errors.New("connect failed")}, client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero, masterSecret: testvector.HKDFZero,
wantErr: true, wantErr: true,
}, },
@ -121,7 +102,6 @@ func TestRecover(t *testing.T) {
existingState: validState, existingState: validState,
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")}, client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero, masterSecret: testvector.HKDFZero,
wantErr: true, wantErr: true,
}, },
@ -140,9 +120,6 @@ func TestRecover(t *testing.T) {
if tc.endpointFlag != "" { if tc.endpointFlag != "" {
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag)) require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
} }
if tc.diskUUIDFlag != "" {
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
}
if tc.masterSecretFlag != "" { if tc.masterSecretFlag != "" {
require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag)) require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag))
} }
@ -170,7 +147,6 @@ func TestRecover(t *testing.T) {
assert.NoError(err) assert.NoError(err)
assert.Contains(out.String(), "Pushed recovery key.") assert.Contains(out.String(), "Pushed recovery key.")
assert.Equal(tc.masterSecret.Output, tc.client.pushStateDiskKeyKey)
}) })
} }
} }
@ -185,38 +161,24 @@ func TestParseRecoverFlags(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"invalid ip": { "invalid ip": {
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, args: []string{"-e", "192.0.2.1:2:2"},
wantErr: true,
},
"invalid disk uuid": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
wantErr: true, wantErr: true,
}, },
"minimal args set": { "minimal args set": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"}, args: []string{"-e", "192.0.2.1:2"},
wantFlags: recoverFlags{ wantFlags: recoverFlags{
endpoint: "192.0.2.1:2", endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
secretPath: "constellation-mastersecret.json", secretPath: "constellation-mastersecret.json",
}, },
}, },
"all args set": { "all args set": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--config", "config-path", "--master-secret", "/path/super-secret.json"}, args: []string{"-e", "192.0.2.1:2", "--config", "config-path", "--master-secret", "/path/super-secret.json"},
wantFlags: recoverFlags{ wantFlags: recoverFlags{
endpoint: "192.0.2.1:2", endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
secretPath: "/path/super-secret.json", secretPath: "/path/super-secret.json",
configPath: "config-path", configPath: "config-path",
}, },
}, },
"uppercase disk-uuid is converted to lowercase": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
secretPath: "constellation-mastersecret.json",
},
},
} }
for name, tc := range testCases { for name, tc := range testCases {
@ -239,26 +201,26 @@ func TestParseRecoverFlags(t *testing.T) {
} }
} }
func TestDeriveStateDiskKey(t *testing.T) { type stubRecoveryClient struct {
testCases := map[string]struct { conn bool
masterSecret testvector.HKDF connectErr error
}{ closeErr error
"all zero": { pushStateDiskKeyErr error
masterSecret: testvector.HKDFZero,
},
"all 0xff": {
masterSecret: testvector.HKDF0xFF,
},
}
for name, tc := range testCases { pushStateDiskKeyKey []byte
t.Run(name, func(t *testing.T) { }
assert := assert.New(t)
func (c *stubRecoveryClient) Connect(string, atls.Validator) error {
stateDiskKey, err := deriveStateDiskKey(tc.masterSecret.Secret, tc.masterSecret.Salt, tc.masterSecret.Info) c.conn = true
return c.connectErr
assert.NoError(err) }
assert.Equal(tc.masterSecret.Output, stateDiskKey)
}) func (c *stubRecoveryClient) Close() error {
} c.conn = false
return c.closeErr
}
func (c *stubRecoveryClient) Recover(_ context.Context, stateDiskKey, _ []byte) error {
c.pushStateDiskKeyKey = stateDiskKey
return c.pushStateDiskKeyErr
} }

View File

@ -1,20 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"io"
"github.com/edgelesssys/constellation/internal/atls"
)
type recoveryClient interface {
Connect(endpoint string, validators atls.Validator) error
PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error
io.Closer
}

View File

@ -1,37 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"github.com/edgelesssys/constellation/internal/atls"
)
type stubRecoveryClient struct {
conn bool
connectErr error
closeErr error
pushStateDiskKeyErr error
pushStateDiskKeyKey []byte
}
func (c *stubRecoveryClient) Connect(_ string, _ atls.Validator) error {
c.conn = true
return c.connectErr
}
func (c *stubRecoveryClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey, _ []byte) error {
c.pushStateDiskKeyKey = stateDiskKey
return c.pushStateDiskKeyErr
}

View File

@ -9,17 +9,21 @@ package proto
import ( import (
"context" "context"
"errors" "errors"
"io"
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/state/keyproto" "go.uber.org/multierr"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
// KeyClient wraps a KeyAPI client and the connection to it. // RecoverClient wraps a recoverAPI client and the connection to it.
type KeyClient struct { type RecoverClient struct {
conn *grpc.ClientConn conn *grpc.ClientConn
keyapi keyproto.APIClient recoverapi recoverproto.APIClient
} }
// Connect connects the client to a given server, using the handed // Connect connects the client to a given server, using the handed
@ -27,7 +31,7 @@ type KeyClient struct {
// The connection must be closed using Close(). If connect is // The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error { func (c *RecoverClient) Connect(endpoint string, validators atls.Validator) error {
creds := atlscredentials.New(nil, []atls.Validator{validators}) creds := atlscredentials.New(nil, []atls.Validator{validators})
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds)) conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
@ -36,14 +40,14 @@ func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
} }
_ = c.Close() _ = c.Close()
c.conn = conn c.conn = conn
c.keyapi = keyproto.NewAPIClient(conn) c.recoverapi = recoverproto.NewAPIClient(conn)
return nil return nil
} }
// Close closes the grpc connection of the client. // Close closes the grpc connection of the client.
// Close is idempotent and can be called on non connected clients // Close is idempotent and can be called on non connected clients
// without returning an error. // without returning an error.
func (c *KeyClient) Close() error { func (c *RecoverClient) Close() error {
if c.conn == nil { if c.conn == nil {
return nil return nil
} }
@ -55,17 +59,58 @@ func (c *KeyClient) Close() error {
} }
// PushStateDiskKey pushes the state disk key to a constellation instance in recovery mode. // PushStateDiskKey pushes the state disk key to a constellation instance in recovery mode.
// The state disk key must be derived from the UUID of the state disk and the master key. func (c *RecoverClient) Recover(ctx context.Context, masterSecret, salt []byte) (retErr error) {
func (c *KeyClient) PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error { if c.recoverapi == nil {
if c.keyapi == nil {
return errors.New("client is not connected") return errors.New("client is not connected")
} }
req := &keyproto.PushStateDiskKeyRequest{ measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret, salt)
StateDiskKey: stateDiskKey, if err != nil {
MeasurementSecret: measurementSecret, return err
} }
_, err := c.keyapi.PushStateDiskKey(ctx, req) recoverclient, err := c.recoverapi.Recover(ctx)
return err if err != nil {
return err
}
defer func() {
if err := recoverclient.CloseSend(); err != nil {
multierr.AppendInto(&retErr, err)
}
}()
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: measurementSecret,
},
}); err != nil {
return err
}
res, err := recoverclient.Recv()
if err != nil {
return err
}
stateDiskKey, err := deriveStateDiskKey(masterSecret, salt, res.DiskUuid)
if err != nil {
return err
}
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: stateDiskKey,
},
}); err != nil {
return err
}
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
return err
}
return nil
}
// deriveStateDiskKey derives a state disk key from a master key, a salt, and a disk UUID.
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength)
} }

View File

@ -0,0 +1,38 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package proto
import (
"testing"
"github.com/edgelesssys/constellation/internal/crypto/testvector"
"github.com/stretchr/testify/assert"
)
func TestDeriveStateDiskKey(t *testing.T) {
testCases := map[string]struct {
masterSecret testvector.HKDF
}{
"all zero": {
masterSecret: testvector.HKDFZero,
},
"all 0xff": {
masterSecret: testvector.HKDF0xFF,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
stateDiskKey, err := deriveStateDiskKey(tc.masterSecret.Secret, tc.masterSecret.Salt, tc.masterSecret.Info)
assert.NoError(err)
assert.Equal(tc.masterSecret.Output, stateDiskKey)
})
}
}

View File

@ -4,11 +4,11 @@ Files and source code for mounting persistent state disks
## Testing ## Testing
Integration test is available in `state/test/integration_test.go`. Integration test is available in `disk-mapper/test/integration_test.go`.
The integration test requires root privileges since it uses dm-crypt. The integration test requires root privileges since it uses dm-crypt.
Build and run the test: Build and run the test:
```bash ```bash
go test -c -tags=integration ./state/test/ go test -c -tags=integration ./disk-mapper/test/
sudo ./test.test sudo ./test.test
``` ```

View File

@ -11,12 +11,17 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"flag" "flag"
"net"
"net/http" "net/http"
"net/url" "net/url"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
"github.com/edgelesssys/constellation/disk-mapper/internal/recoveryserver"
"github.com/edgelesssys/constellation/disk-mapper/internal/rejoinclient"
"github.com/edgelesssys/constellation/disk-mapper/internal/setup"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation/azure" "github.com/edgelesssys/constellation/internal/attestation/azure"
"github.com/edgelesssys/constellation/internal/attestation/gcp" "github.com/edgelesssys/constellation/internal/attestation/gcp"
"github.com/edgelesssys/constellation/internal/attestation/qemu" "github.com/edgelesssys/constellation/internal/attestation/qemu"
@ -26,10 +31,8 @@ import (
"github.com/edgelesssys/constellation/internal/cloud/metadata" "github.com/edgelesssys/constellation/internal/cloud/metadata"
qemucloud "github.com/edgelesssys/constellation/internal/cloud/qemu" qemucloud "github.com/edgelesssys/constellation/internal/cloud/qemu"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/state/internal/keyservice"
"github.com/edgelesssys/constellation/state/internal/mapper"
"github.com/edgelesssys/constellation/state/internal/setup"
tpmClient "github.com/google/go-tpm-tools/client" tpmClient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -54,8 +57,8 @@ func main() {
// set up metadata API and quote issuer for aTLS connections // set up metadata API and quote issuer for aTLS connections
var err error var err error
var diskPath string var diskPath string
var issuer keyservice.QuoteIssuer var issuer atls.Issuer
var metadata metadata.InstanceLister var metadataAPI setup.MetadataAPI
switch strings.ToLower(*csp) { switch strings.ToLower(*csp) {
case "azure": case "azure":
diskPath, err = filepath.EvalSymlinks(azureStateDiskPath) diskPath, err = filepath.EvalSymlinks(azureStateDiskPath)
@ -63,7 +66,7 @@ func main() {
_ = exportPCRs() _ = exportPCRs()
log.With(zap.Error(err)).Fatalf("Unable to resolve Azure state disk path") log.With(zap.Error(err)).Fatalf("Unable to resolve Azure state disk path")
} }
metadata, err = azurecloud.NewMetadata(context.Background()) metadataAPI, err = azurecloud.NewMetadata(context.Background())
if err != nil { if err != nil {
log.With(zap.Error).Fatalf("Failed to create Azure metadata API") log.With(zap.Error).Fatalf("Failed to create Azure metadata API")
} }
@ -81,12 +84,13 @@ func main() {
if err != nil { if err != nil {
log.With(zap.Error).Fatalf("Failed to create GCP client") log.With(zap.Error).Fatalf("Failed to create GCP client")
} }
metadata = gcpcloud.New(gcpClient) metadataAPI = gcpcloud.New(gcpClient)
case "qemu": case "qemu":
diskPath = qemuStateDiskPath diskPath = qemuStateDiskPath
issuer = qemu.NewIssuer() issuer = qemu.NewIssuer()
metadata = &qemucloud.Metadata{} metadataAPI = &qemucloud.Metadata{}
_ = exportPCRs()
default: default:
log.Fatalf("CSP %s is not supported by Constellation", *csp) log.Fatalf("CSP %s is not supported by Constellation", *csp)
@ -104,7 +108,6 @@ func main() {
*csp, *csp,
diskPath, diskPath,
afero.Afero{Fs: afero.NewOsFs()}, afero.Afero{Fs: afero.NewOsFs()},
keyservice.New(log.Named("keyService"), issuer, metadata, 20*time.Second, 20*time.Second), // try to request a key every 20 seconds
mapper, mapper,
setup.DiskMounter{}, setup.DiskMounter{},
vtpm.OpenVTPM, vtpm.OpenVTPM,
@ -112,7 +115,23 @@ func main() {
// prepare the state disk // prepare the state disk
if mapper.IsLUKSDevice() { if mapper.IsLUKSDevice() {
err = setupManger.PrepareExistingDisk() // set up rejoin client
var self metadata.InstanceMetadata
self, err = metadataAPI.Self(context.Background())
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get self metadata")
}
rejoinClient := rejoinclient.New(
dialer.New(issuer, nil, &net.Dialer{}),
self,
metadataAPI,
log.Named("rejoinClient"),
)
// set up recovery server
recoveryServer := recoveryserver.New(issuer, log.Named("recoveryServer"))
err = setupManger.PrepareExistingDisk(setup.NewNodeRecoverer(recoveryServer, rejoinClient))
} else { } else {
err = setupManger.PrepareNewDisk() err = setupManger.PrepareNewDisk()
} }

View File

@ -0,0 +1,134 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package recoveryserver
import (
"context"
"net"
"sync"
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
type RecoveryServer struct {
mux sync.Mutex
diskUUID string
stateDiskKey []byte
measurementSecret []byte
grpcServer server
log *logger.Logger
recoverproto.UnimplementedAPIServer
}
// New returns a new RecoveryServer.
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
server := &RecoveryServer{
log: log,
}
grpcServer := grpc.NewServer(
grpc.Creds(atlscredentials.New(issuer, nil)),
log.Named("gRPC").GetServerStreamInterceptor(),
)
recoverproto.RegisterAPIServer(grpcServer, server)
server.grpcServer = grpcServer
return server
}
// Serve starts the recovery server.
// It blocks until a recover request call is successful.
// The server will shut down when the call is successful and the keys are returned.
// Additionally, the server can be shutdown by canceling the context.
func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskUUID string) (diskKey, measurementSecret []byte, err error) {
s.log.Infof("Starting RecoveryServer")
s.diskUUID = diskUUID
recoveryDone := make(chan struct{}, 1)
var serveErr error
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
serveErr = s.grpcServer.Serve(listener)
recoveryDone <- struct{}{}
}()
for {
select {
case <-ctx.Done():
s.log.Infof("Context canceled, shutting down server")
s.grpcServer.GracefulStop()
return nil, nil, ctx.Err()
case <-recoveryDone:
if serveErr != nil {
return nil, nil, serveErr
}
return s.stateDiskKey, s.measurementSecret, nil
}
}
}
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
s.mux.Lock()
defer s.mux.Unlock()
s.log.Infof("Received recover call")
msg, err := stream.Recv()
if err != nil {
return status.Error(codes.Internal, "failed to receive message")
}
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
if !ok {
s.log.Errorf("Received invalid first message: not a measurement secret")
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
}
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
s.log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
return status.Error(codes.Internal, "failed to send response")
}
msg, err = stream.Recv()
if err != nil {
s.log.With(zap.Error(err)).Errorf("Failed to receive disk key")
return status.Error(codes.Internal, "failed to receive message")
}
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
if !ok {
s.log.Errorf("Received invalid second message: not a state disk key")
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
}
s.stateDiskKey = stateDiskKey.StateDiskKey
s.measurementSecret = measurementSecret.MeasurementSecret
s.log.Infof("Received state disk key and measurement secret, shutting down server")
go s.grpcServer.GracefulStop()
return nil
}
type server interface {
Serve(net.Listener) error
GracefulStop()
}

View File

@ -0,0 +1,194 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package recoveryserver
import (
"context"
"io"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestServe(t *testing.T) {
assert := assert.New(t)
log := logger.NewTest(t)
uuid := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
dialer := testdialer.NewBufconnDialer()
listener := dialer.GetListener("192.0.2.1:1234")
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
// Serve method returns when context is canceled
wg.Add(1)
go func() {
defer wg.Done()
_, _, err := server.Serve(ctx, listener, uuid)
assert.ErrorIs(err, context.Canceled)
}()
time.Sleep(100 * time.Millisecond)
cancel()
wg.Wait()
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
dialer = testdialer.NewBufconnDialer()
listener = dialer.GetListener("192.0.2.1:1234")
// Serve method returns without error when the server is shut down
wg.Add(1)
go func() {
defer wg.Done()
_, _, err := server.Serve(context.Background(), listener, uuid)
assert.NoError(err)
}()
time.Sleep(100 * time.Millisecond)
server.grpcServer.GracefulStop()
wg.Wait()
// Serve method returns an error when serving is unsuccessful
_, _, err := server.Serve(context.Background(), listener, uuid)
assert.Error(err)
}
func TestRecover(t *testing.T) {
testCases := map[string]struct {
initialMsg message
keyMsg message
wantErr bool
}{
"success": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
},
"first message is not a measurement secret": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
wantErr: true,
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
},
"second message is not a state disk key": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
wantErr: true,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
ctx := context.Background()
serverUUID := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
netDialer := testdialer.NewBufconnDialer()
listener := netDialer.GetListener("192.0.2.1:1234")
var diskKey, measurementSecret []byte
var serveErr error
var wg sync.WaitGroup
defer wg.Wait()
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
wg.Add(1)
go func() {
defer wg.Done()
diskKey, measurementSecret, serveErr = server.Serve(serveCtx, listener, serverUUID)
}()
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
require.NoError(err)
defer conn.Close()
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
require.NoError(err)
// Send initial message
err = client.Send(tc.initialMsg.recoverMsg)
require.NoError(err)
// Receive uuid
uuid, err := client.Recv()
if tc.initialMsg.wantErr {
assert.Error(err)
return
}
assert.Equal(serverUUID, uuid.DiskUuid)
// Send key message
err = client.Send(tc.keyMsg.recoverMsg)
require.NoError(err)
_, err = client.Recv()
if tc.keyMsg.wantErr {
assert.Error(err)
return
}
assert.ErrorIs(io.EOF, err)
wg.Wait()
assert.NoError(serveErr)
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
})
}
}
type message struct {
recoverMsg *recoverproto.RecoverMessage
wantErr bool
}

View File

@ -0,0 +1,167 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package rejoinclient
import (
"context"
"errors"
"net"
"time"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"go.uber.org/zap"
"google.golang.org/grpc"
"k8s.io/utils/clock"
)
const (
interval = 30 * time.Second
timeout = 30 * time.Second
)
// RejoinClient is a client for requesting the needed information
// for rejoining a cluster as a restarting worker or control-plane node.
type RejoinClient struct {
diskUUID string
nodeInfo metadata.InstanceMetadata
timeout time.Duration
interval time.Duration
clock clock.WithTicker
dialer grpcDialer
metadataAPI metadataAPI
log *logger.Logger
}
// New returns a new RejoinClient.
func New(dial grpcDialer, nodeInfo metadata.InstanceMetadata,
meta metadataAPI, log *logger.Logger,
) *RejoinClient {
return &RejoinClient{
nodeInfo: nodeInfo,
timeout: timeout,
interval: interval,
clock: clock.RealClock{},
dialer: dial,
metadataAPI: meta,
log: log,
}
}
// Start starts the rejoin client.
// The client will continuously request available control-plane endpoints
// from the metadata API and send rejoin requests to them.
// The function returns after a successful rejoin request has been performed.
func (c *RejoinClient) Start(ctx context.Context, diskUUID string) (diskKey, measurementSecret []byte) {
c.log.Infof("Starting RejoinClient")
c.diskUUID = diskUUID
ticker := c.clock.NewTicker(c.interval)
defer ticker.Stop()
defer c.log.Infof("RejoinClient stopped")
for {
endpoints, err := c.getControlPlaneEndpoints()
if err != nil {
c.log.With(zap.Error(err)).Errorf("Failed to get control-plane endpoints")
} else {
c.log.With(zap.Strings("endpoints", endpoints)).Infof("Received list with JoinService endpoints")
diskKey, measurementSecret, err = c.tryRejoinWithAvailableServices(ctx, endpoints)
if err == nil {
c.log.Infof("Successfully retrieved rejoin ticket")
return diskKey, measurementSecret
}
}
select {
case <-ctx.Done():
return nil, nil
case <-ticker.C():
}
}
}
// tryRejoinWithAvailableServices tries sending rejoin requests to the available endpoints.
func (c *RejoinClient) tryRejoinWithAvailableServices(ctx context.Context, endpoints []string) (diskKey, measurementSecret []byte, err error) {
for _, endpoint := range endpoints {
c.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
rejoinTicket, err := c.requestRejoinTicket(endpoint)
if err == nil {
return rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, nil
}
c.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Warnf("Failed to rejoin on endpoint")
// stop requesting additional endpoints if the context is done
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
default:
}
}
c.log.Errorf("Failed to rejoin on all endpoints")
return nil, nil, errors.New("failed to join on all endpoints")
}
// requestRejoinTicket requests a rejoin ticket from the endpoint.
func (c *RejoinClient) requestRejoinTicket(endpoint string) (*joinproto.IssueRejoinTicketResponse, error) {
ctx, cancel := c.timeoutCtx()
defer cancel()
conn, err := c.dialer.Dial(ctx, endpoint)
if err != nil {
return nil, err
}
defer conn.Close()
return joinproto.NewAPIClient(conn).IssueRejoinTicket(ctx, &joinproto.IssueRejoinTicketRequest{DiskUuid: c.diskUUID})
}
// getControlPlaneEndpoints requests the available control-plane endpoints from the metadata API.
// The list is filtered to remove *this* node if it is a restarting control-plane node.
func (c *RejoinClient) getControlPlaneEndpoints() ([]string, error) {
ctx, cancel := c.timeoutCtx()
defer cancel()
endpoints, err := metadata.JoinServiceEndpoints(ctx, c.metadataAPI)
if err != nil {
return nil, err
}
if c.nodeInfo.Role == role.ControlPlane {
return removeSelfFromEndpoints(c.nodeInfo.VPCIP, endpoints), nil
}
return endpoints, nil
}
// removeSelfFromEndpoints removes *this* node from the list of endpoints.
// If an error occurs, the entry is removed from the list of endpoints.
func removeSelfFromEndpoints(self string, endpoints []string) []string {
var result []string
for _, endpoint := range endpoints {
host, _, err := net.SplitHostPort(endpoint)
if err == nil && host != self {
result = append(result, endpoint)
}
}
return result
}
func (c *RejoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), c.timeout)
}
type grpcDialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
}
type metadataAPI interface {
// List retrieves all instances belonging to the current constellation.
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
}

View File

@ -0,0 +1,308 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package rejoinclient
import (
"context"
"errors"
"net"
"strconv"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
testclock "k8s.io/utils/clock/testing"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestStartCancel(t *testing.T) {
netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer)
clock := testclock.NewFakeClock(time.Time{})
metaAPI := &stubMetadataAPI{
instances: []metadata.InstanceMetadata{
{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
},
}
client := &RejoinClient{
dialer: dialer,
nodeInfo: metadata.InstanceMetadata{Role: role.Worker},
metadataAPI: metaAPI,
log: logger.NewTest(t),
timeout: time.Second * 30,
interval: time.Second,
clock: clock,
}
serverCreds := atlscredentials.New(nil, nil)
rejoinServer := grpc.NewServer(grpc.Creds(serverCreds))
rejoinServiceAPI := &stubRejoinServiceAPI{err: errors.New("error")}
joinproto.RegisterAPIServer(rejoinServer, rejoinServiceAPI)
port := strconv.Itoa(constants.JoinServiceNodePort)
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
go rejoinServer.Serve(listener)
defer rejoinServer.GracefulStop()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
client.Start(ctx, "uuid")
}()
clock.Step(time.Millisecond)
cancel()
wg.Wait()
assert.Equal(t, client.diskUUID, "uuid")
}
func TestRemoveSelfFromEndpoints(t *testing.T) {
testCases := map[string]struct {
self string
endpoints []string
}{
"self is not in endpoints": {
self: "192.0.2.1",
endpoints: []string{
"192.0.2.2:30090",
"192.0.2.3:30090",
"192.0.2.4:30090",
"192.0.2.5:30090",
"192.0.2.6:30090",
},
},
"self is in endpoints": {
self: "192.0.2.1",
endpoints: []string{
"192.0.2.2:30090",
"192.0.2.3:30090",
"192.0.2.4:30090",
"192.0.2.5:30090",
"192.0.2.6:30090",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
got := removeSelfFromEndpoints(tc.self, tc.endpoints)
assert.NotContains(got, tc.self)
})
}
}
func TestGetControlPlaneEndpoints(t *testing.T) {
testInstances := []metadata.InstanceMetadata{
{
Role: role.ControlPlane,
VPCIP: "192.0.2.2",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.3",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.4",
},
{
Role: role.Worker,
VPCIP: "192.0.2.12",
},
{
Role: role.Worker,
VPCIP: "192.0.2.13",
},
{
Role: role.Worker,
VPCIP: "192.0.2.14",
},
}
testCases := map[string]struct {
nodeInfo metadata.InstanceMetadata
meta stubMetadataAPI
wantInstances int
wantErr bool
}{
"worker node": {
nodeInfo: metadata.InstanceMetadata{
Role: role.Worker,
VPCIP: "192.0.2.1",
},
meta: stubMetadataAPI{
instances: testInstances,
},
wantInstances: 3,
},
"control-plane node not in list": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
meta: stubMetadataAPI{
instances: testInstances,
},
wantInstances: 3,
},
"control-plane node in list": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.2",
},
meta: stubMetadataAPI{
instances: testInstances,
},
wantInstances: 2,
},
"metadata error": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
meta: stubMetadataAPI{
err: errors.New("error"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := New(nil, tc.nodeInfo, tc.meta, logger.NewTest(t))
endpoints, err := client.getControlPlaneEndpoints()
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotContains(endpoints, tc.nodeInfo.VPCIP)
assert.Len(endpoints, tc.wantInstances)
}
})
}
}
func TestStart(t *testing.T) {
testCases := map[string]struct {
nodeInfo metadata.InstanceMetadata
}{
"worker node": {
nodeInfo: metadata.InstanceMetadata{
Role: role.Worker,
VPCIP: "192.0.2.99",
},
},
"control-plane node": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.99",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
diskKey := []byte("disk-key")
measurementSecret := []byte("measurement-secret")
netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer)
serverCreds := atlscredentials.New(nil, nil)
rejoinServer := grpc.NewServer(grpc.Creds(serverCreds))
rejoinServiceAPI := &stubRejoinServiceAPI{
rejoinTicketResponse: &joinproto.IssueRejoinTicketResponse{
StateDiskKey: diskKey,
MeasurementSecret: measurementSecret,
},
}
joinproto.RegisterAPIServer(rejoinServer, rejoinServiceAPI)
port := strconv.Itoa(constants.JoinServiceNodePort)
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
go rejoinServer.Serve(listener)
defer rejoinServer.GracefulStop()
meta := stubMetadataAPI{
instances: []metadata.InstanceMetadata{
{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.2",
},
{
Role: role.Worker,
VPCIP: "192.0.2.13",
},
{
Role: role.Worker,
VPCIP: "192.0.2.14",
},
},
}
client := New(dialer, tc.nodeInfo, meta, logger.NewTest(t))
passphrase, secret := client.Start(context.Background(), "uuid")
assert.Equal(diskKey, passphrase)
assert.Equal(measurementSecret, secret)
})
}
}
type stubMetadataAPI struct {
instances []metadata.InstanceMetadata
err error
}
func (s stubMetadataAPI) List(context.Context) ([]metadata.InstanceMetadata, error) {
return s.instances, s.err
}
type stubRejoinServiceAPI struct {
rejoinTicketResponse *joinproto.IssueRejoinTicketResponse
err error
joinproto.UnimplementedAPIServer
}
func (s *stubRejoinServiceAPI) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest,
) (*joinproto.IssueRejoinTicketResponse, error) {
return s.rejoinTicketResponse, s.err
}

View File

@ -10,6 +10,8 @@ import (
"io/fs" "io/fs"
"os" "os"
"syscall" "syscall"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
) )
// Mounter is an interface for mount and unmount operations. // Mounter is an interface for mount and unmount operations.
@ -27,17 +29,23 @@ type DeviceMapper interface {
UnmapDisk(target string) error UnmapDisk(target string) error
} }
// KeyWaiter is an interface to request and wait for disk decryption keys.
type KeyWaiter interface {
WaitForDecryptionKey(uuid, addr string) (diskKey, measurementSecret []byte, err error)
ResetKey()
}
// ConfigurationGenerator is an interface for generating systemd-cryptsetup@.service unit files. // ConfigurationGenerator is an interface for generating systemd-cryptsetup@.service unit files.
type ConfigurationGenerator interface { type ConfigurationGenerator interface {
Generate(volumeName, encryptedDevice, keyFile, options string) error Generate(volumeName, encryptedDevice, keyFile, options string) error
} }
// MetadataAPI is an interface for accessing cloud metadata.
type MetadataAPI interface {
metadata.InstanceSelfer
metadata.InstanceLister
}
// RecoveryDoer is an interface to perform key recovery operations.
// Calls to Do may be blocking, and if successful return a passphrase and measurementSecret.
type RecoveryDoer interface {
Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error)
}
// DiskMounter uses the syscall package to mount disks. // DiskMounter uses the syscall package to mount disks.
type DiskMounter struct{} type DiskMounter struct{}

View File

@ -7,14 +7,18 @@ SPDX-License-Identifier: AGPL-3.0-only
package setup package setup
import ( import (
"context"
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync"
"syscall" "syscall"
"github.com/edgelesssys/constellation/disk-mapper/internal/systemd"
"github.com/edgelesssys/constellation/internal/attestation" "github.com/edgelesssys/constellation/internal/attestation"
"github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
@ -22,9 +26,7 @@ import (
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/nodestate" "github.com/edgelesssys/constellation/internal/nodestate"
"github.com/edgelesssys/constellation/state/internal/systemd"
"github.com/spf13/afero" "github.com/spf13/afero"
"go.uber.org/zap"
) )
const ( const (
@ -38,56 +40,53 @@ const (
// SetupManager handles formatting, mapping, mounting and unmounting of state disks. // SetupManager handles formatting, mapping, mounting and unmounting of state disks.
type SetupManager struct { type SetupManager struct {
log *logger.Logger log *logger.Logger
csp string csp string
diskPath string diskPath string
fs afero.Afero fs afero.Afero
keyWaiter KeyWaiter mapper DeviceMapper
mapper DeviceMapper mounter Mounter
mounter Mounter config ConfigurationGenerator
config ConfigurationGenerator openTPM vtpm.TPMOpenFunc
openTPM vtpm.TPMOpenFunc
} }
// New initializes a SetupManager with the given parameters. // New initializes a SetupManager with the given parameters.
func New(log *logger.Logger, csp string, diskPath string, fs afero.Afero, keyWaiter KeyWaiter, mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc) *SetupManager { func New(log *logger.Logger, csp string, diskPath string, fs afero.Afero,
mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc,
) *SetupManager {
return &SetupManager{ return &SetupManager{
log: log, log: log,
csp: csp, csp: csp,
diskPath: diskPath, diskPath: diskPath,
fs: fs, fs: fs,
keyWaiter: keyWaiter, mapper: mapper,
mapper: mapper, mounter: mounter,
mounter: mounter, config: systemd.New(fs),
config: systemd.New(fs), openTPM: openTPM,
openTPM: openTPM,
} }
} }
// PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk. // PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk.
// Once the disk is mapped, the function taints the node as initialized by updating it's PCRs. // Once the disk is mapped, the function taints the node as initialized by updating it's PCRs.
func (s *SetupManager) PrepareExistingDisk() error { func (s *SetupManager) PrepareExistingDisk(recover RecoveryDoer) error {
s.log.Infof("Preparing existing state disk") s.log.Infof("Preparing existing state disk")
uuid := s.mapper.DiskUUID() uuid := s.mapper.DiskUUID()
endpoint := net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.RecoveryPort)) endpoint := net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.RecoveryPort))
getKey:
passphrase, measurementSecret, err := s.keyWaiter.WaitForDecryptionKey(uuid, endpoint) passphrase, measurementSecret, err := recover.Do(uuid, endpoint)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to perform recovery: %w", err)
} }
if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil { if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil {
// retry key fetching if disk mapping fails return err
s.log.With(zap.Error(err)).Errorf("Failed to map state disk, retrying...")
s.keyWaiter.ResetKey()
goto getKey
} }
if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil { if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil {
return err return err
} }
// we do not care about cleaning up the mount point on error, since any errors returned here should result in a kernel panic in the main function // we do not care about cleaning up the mount point on error, since any errors returned here should cause a boot failure
if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil { if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil {
return err return err
} }
@ -160,3 +159,68 @@ func (s *SetupManager) saveConfiguration(passphrase []byte) error {
// systemd cryptsetup unit // systemd cryptsetup unit
return s.config.Generate(stateDiskMappedName, s.diskPath, filepath.Join(keyPath, keyFile), cryptsetupOptions) return s.config.Generate(stateDiskMappedName, s.diskPath, filepath.Join(keyPath, keyFile), cryptsetupOptions)
} }
type recoveryServer interface {
Serve(context.Context, net.Listener, string) (key, secret []byte, err error)
}
type rejoinClient interface {
Start(context.Context, string) (key, secret []byte)
}
type nodeRecoverer struct {
recoveryServer recoveryServer
rejoinClient rejoinClient
}
// NewNodeRecoverer initializes a new nodeRecoverer.
func NewNodeRecoverer(recoveryServer recoveryServer, rejoinClient rejoinClient) *nodeRecoverer {
return &nodeRecoverer{
recoveryServer: recoveryServer,
rejoinClient: rejoinClient,
}
}
// Do performs a recovery procedure on the given state disk.
// The method starts a gRPC server to allow manual recovery by a user.
// At the same time it tries to request a decryption key from all available Constellation control-plane nodes.
func (r *nodeRecoverer) Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lis, err := net.Listen("tcp", endpoint)
if err != nil {
return nil, nil, err
}
defer lis.Close()
var once sync.Once
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
key, secret, serveErr := r.recoveryServer.Serve(ctx, lis, uuid)
once.Do(func() {
cancel()
passphrase = key
measurementSecret = secret
})
if serveErr != nil && !errors.Is(serveErr, context.Canceled) {
err = serveErr
}
}()
wg.Add(1)
go func() {
defer wg.Done()
key, secret := r.rejoinClient.Start(ctx, uuid)
once.Do(func() {
cancel()
passphrase = key
measurementSecret = secret
})
}()
wg.Wait()
return passphrase, measurementSecret, err
}

View File

@ -7,10 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
package setup package setup
import ( import (
"context"
"errors" "errors"
"io" "io"
"io/fs" "io/fs"
"net"
"path/filepath" "path/filepath"
"sync"
"testing" "testing"
"github.com/edgelesssys/constellation/internal/attestation/vtpm" "github.com/edgelesssys/constellation/internal/attestation/vtpm"
@ -30,9 +33,13 @@ func TestMain(m *testing.M) {
func TestPrepareExistingDisk(t *testing.T) { func TestPrepareExistingDisk(t *testing.T) {
someErr := errors.New("error") someErr := errors.New("error")
testRecoveryDoer := &stubRecoveryDoer{
passphrase: []byte("passphrase"),
secret: []byte("secret"),
}
testCases := map[string]struct { testCases := map[string]struct {
keyWaiter *stubKeyWaiter recoveryDoer *stubRecoveryDoer
mapper *stubMapper mapper *stubMapper
mounter *stubMounter mounter *stubMounter
configGenerator *stubConfigurationGenerator configGenerator *stubConfigurationGenerator
@ -41,34 +48,33 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr bool wantErr bool
}{ }{
"success": { "success": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{}, mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
openTPM: vtpm.OpenNOPTPM, openTPM: vtpm.OpenNOPTPM,
}, },
"WaitForDecryptionKey fails": { "WaitForDecryptionKey fails": {
keyWaiter: &stubKeyWaiter{waitErr: someErr}, recoveryDoer: &stubRecoveryDoer{recoveryErr: someErr},
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{}, mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
openTPM: vtpm.OpenNOPTPM, openTPM: vtpm.OpenNOPTPM,
wantErr: true, wantErr: true,
}, },
"MapDisk fails causes a repeat": { "MapDisk fails": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{ mapper: &stubMapper{
uuid: "test", uuid: "test",
mapDiskErr: someErr, mapDiskErr: someErr,
mapDiskRepeatedCalls: 2,
}, },
mounter: &stubMounter{}, mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
openTPM: vtpm.OpenNOPTPM, openTPM: vtpm.OpenNOPTPM,
wantErr: false, wantErr: true,
}, },
"MkdirAll fails": { "MkdirAll fails": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{mkdirAllErr: someErr}, mounter: &stubMounter{mkdirAllErr: someErr},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
@ -76,7 +82,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"Mount fails": { "Mount fails": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{mountErr: someErr}, mounter: &stubMounter{mountErr: someErr},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
@ -84,7 +90,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"Unmount fails": { "Unmount fails": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{unmountErr: someErr}, mounter: &stubMounter{unmountErr: someErr},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
@ -92,7 +98,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"MarkNodeAsBootstrapped fails": { "MarkNodeAsBootstrapped fails": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{unmountErr: someErr}, mounter: &stubMounter{unmountErr: someErr},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
@ -100,7 +106,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"Generating config fails": { "Generating config fails": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{}, mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{generateErr: someErr}, configGenerator: &stubConfigurationGenerator{generateErr: someErr},
@ -108,7 +114,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"no state file": { "no state file": {
keyWaiter: &stubKeyWaiter{}, recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"}, mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{}, mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
@ -130,23 +136,21 @@ func TestPrepareExistingDisk(t *testing.T) {
} }
setupManager := &SetupManager{ setupManager := &SetupManager{
log: logger.NewTest(t), log: logger.NewTest(t),
csp: "test", csp: "test",
diskPath: "disk-path", diskPath: "disk-path",
fs: fs, fs: fs,
keyWaiter: tc.keyWaiter, mapper: tc.mapper,
mapper: tc.mapper, mounter: tc.mounter,
mounter: tc.mounter, config: tc.configGenerator,
config: tc.configGenerator, openTPM: tc.openTPM,
openTPM: tc.openTPM,
} }
err := setupManager.PrepareExistingDisk() err := setupManager.PrepareExistingDisk(tc.recoveryDoer)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {
assert.NoError(err) assert.NoError(err)
assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
assert.True(tc.mapper.mapDiskCalled) assert.True(tc.mapper.mapDiskCalled)
assert.True(tc.mounter.mountCalled) assert.True(tc.mounter.mountCalled)
assert.True(tc.mounter.unmountCalled) assert.True(tc.mounter.unmountCalled)
@ -191,9 +195,8 @@ func TestPrepareNewDisk(t *testing.T) {
"MapDisk fails": { "MapDisk fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()}, fs: afero.Afero{Fs: afero.NewMemMapFs()},
mapper: &stubMapper{ mapper: &stubMapper{
uuid: "test", uuid: "test",
mapDiskErr: someErr, mapDiskErr: someErr,
mapDiskRepeatedCalls: 1,
}, },
configGenerator: &stubConfigurationGenerator{}, configGenerator: &stubConfigurationGenerator{},
wantErr: true, wantErr: true,
@ -267,7 +270,7 @@ func TestReadMeasurementSalt(t *testing.T) {
require.NoError(handler.WriteJSON("test-state.json", state, file.OptMkdirAll)) require.NoError(handler.WriteJSON("test-state.json", state, file.OptMkdirAll))
} }
setupManager := New(logger.NewTest(t), "test", "disk-path", fs, nil, nil, nil, nil) setupManager := New(logger.NewTest(t), "test", "disk-path", fs, nil, nil, nil)
measurementSalt, err := setupManager.readMeasurementSalt("test-state.json") measurementSalt, err := setupManager.readMeasurementSalt("test-state.json")
if tc.wantErr { if tc.wantErr {
@ -280,15 +283,115 @@ func TestReadMeasurementSalt(t *testing.T) {
} }
} }
func TestRecoveryDoer(t *testing.T) {
assert := assert.New(t)
rejoinClientKey := []byte("rejoinClientKey")
rejoinClientSecret := []byte("rejoinClientSecret")
recoveryServerKey := []byte("recoveryServerKey")
recoveryServerSecret := []byte("recoveryServerSecret")
recoveryServerErr := errors.New("error")
recoveryServer := &stubRecoveryServer{
key: recoveryServerKey,
secret: recoveryServerSecret,
sendKeys: make(chan struct{}, 1),
err: recoveryServerErr,
}
rejoinClient := &stubRejoinClient{
key: rejoinClientKey,
secret: rejoinClientSecret,
sendKeys: make(chan struct{}, 1),
}
recoverer := NewNodeRecoverer(recoveryServer, rejoinClient)
var wg sync.WaitGroup
var key, secret []byte
var err error
// error from recovery server
wg.Add(1)
go func() {
defer wg.Done()
key, secret, err = recoverer.Do("", "")
}()
recoveryServer.sendKeys <- struct{}{}
wg.Wait()
assert.ErrorIs(err, recoveryServerErr)
recoveryServer.err = nil
recoveryServer.sendKeys = make(chan struct{}, 1)
// recovery server returns its key and secret
wg.Add(1)
go func() {
defer wg.Done()
key, secret, err = recoverer.Do("", "")
}()
recoveryServer.sendKeys <- struct{}{}
wg.Wait()
assert.NoError(err)
assert.Equal(recoveryServerKey, key)
assert.Equal(recoveryServerSecret, secret)
recoveryServer.sendKeys = make(chan struct{}, 1)
// rejoin client returns its key and secret
wg.Add(1)
go func() {
defer wg.Done()
key, secret, err = recoverer.Do("", "")
}()
rejoinClient.sendKeys <- struct{}{}
wg.Wait()
assert.NoError(err)
assert.Equal(rejoinClientKey, key)
assert.Equal(rejoinClientSecret, secret)
}
type stubRecoveryServer struct {
key []byte
secret []byte
sendKeys chan struct{}
err error
}
func (s *stubRecoveryServer) Serve(ctx context.Context, _ net.Listener, _ string) ([]byte, []byte, error) {
for {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-s.sendKeys:
return s.key, s.secret, s.err
}
}
}
type stubRejoinClient struct {
key []byte
secret []byte
sendKeys chan struct{}
}
func (s *stubRejoinClient) Start(ctx context.Context, _ string) ([]byte, []byte) {
for {
select {
case <-ctx.Done():
return nil, nil
case <-s.sendKeys:
return s.key, s.secret
}
}
}
type stubMapper struct { type stubMapper struct {
formatDiskCalled bool formatDiskCalled bool
formatDiskErr error formatDiskErr error
mapDiskRepeatedCalls int mapDiskCalled bool
mapDiskCalled bool mapDiskErr error
mapDiskErr error unmapDiskCalled bool
unmapDiskCalled bool unmapDiskErr error
unmapDiskErr error uuid string
uuid string
} }
func (s *stubMapper) DiskUUID() string { func (s *stubMapper) DiskUUID() string {
@ -301,10 +404,6 @@ func (s *stubMapper) FormatDisk(string) error {
} }
func (s *stubMapper) MapDisk(string, string) error { func (s *stubMapper) MapDisk(string, string) error {
if s.mapDiskRepeatedCalls == 0 {
s.mapDiskErr = nil
}
s.mapDiskRepeatedCalls--
s.mapDiskCalled = true s.mapDiskCalled = true
return s.mapDiskErr return s.mapDiskErr
} }
@ -336,25 +435,14 @@ func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
return s.mkdirAllErr return s.mkdirAllErr
} }
type stubKeyWaiter struct { type stubRecoveryDoer struct {
receivedUUID string passphrase []byte
decryptionKey []byte secret []byte
measurementSecret []byte recoveryErr error
waitErr error
waitCalled bool
} }
func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, []byte, error) { func (s *stubRecoveryDoer) Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error) {
if s.waitCalled { return s.passphrase, s.secret, s.recoveryErr
return nil, nil, errors.New("wait called before key was reset")
}
s.waitCalled = true
s.receivedUUID = uuid
return s.decryptionKey, s.measurementSecret, s.waitErr
}
func (s *stubKeyWaiter) ResetKey() {
s.waitCalled = false
} }
type stubConfigurationGenerator struct { type stubConfigurationGenerator struct {

View File

@ -13,8 +13,8 @@ import (
"math" "math"
"testing" "testing"
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
"github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/state/internal/mapper"
"github.com/martinjungblut/go-cryptsetup" "github.com/martinjungblut/go-cryptsetup"
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
) )

View File

@ -9,29 +9,18 @@ SPDX-License-Identifier: AGPL-3.0-only
package integration package integration
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"testing" "testing"
"time"
"github.com/edgelesssys/constellation/internal/atls" "github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/state/internal/keyservice"
"github.com/edgelesssys/constellation/state/internal/mapper"
"github.com/edgelesssys/constellation/state/keyproto"
"github.com/martinjungblut/go-cryptsetup" "github.com/martinjungblut/go-cryptsetup"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"google.golang.org/grpc"
) )
const ( const (
@ -93,53 +82,7 @@ func TestMapper(t *testing.T) {
assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase") assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase")
} }
func TestKeyAPI(t *testing.T) { /*
require := require.New(t)
assert := assert.New(t)
testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
testSecret := []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")
// get a free port on localhost to run the test on
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
apiAddr := listener.Addr().String()
listener.Close()
api := keyservice.New(
logger.NewTest(t),
atls.NewFakeIssuer(oid.Dummy{}),
&fakeMetadataAPI{},
20*time.Second,
time.Second,
)
// send a key to the server
go func() {
// wait 2 seconds before sending the key
time.Sleep(2 * time.Second)
creds := atlscredentials.New(nil, nil)
conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(creds))
require.NoError(err)
defer conn.Close()
client := keyproto.NewAPIClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
_, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{
StateDiskKey: testKey,
MeasurementSecret: testSecret,
})
require.NoError(err)
}()
key, measurementSecret, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr)
assert.NoError(err)
assert.Equal(testKey, key)
assert.Equal(testSecret, measurementSecret)
}
type fakeMetadataAPI struct{} type fakeMetadataAPI struct{}
func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata, error) { func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
@ -152,3 +95,4 @@ func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata
}, },
}, nil }, nil
} }
*/

View File

@ -0,0 +1,258 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc v3.20.1
// source: recover.proto
package recoverproto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type RecoverMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Request:
// *RecoverMessage_StateDiskKey
// *RecoverMessage_MeasurementSecret
Request isRecoverMessage_Request `protobuf_oneof:"request"`
}
func (x *RecoverMessage) Reset() {
*x = RecoverMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_recover_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RecoverMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RecoverMessage) ProtoMessage() {}
func (x *RecoverMessage) ProtoReflect() protoreflect.Message {
mi := &file_recover_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RecoverMessage.ProtoReflect.Descriptor instead.
func (*RecoverMessage) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{0}
}
func (m *RecoverMessage) GetRequest() isRecoverMessage_Request {
if m != nil {
return m.Request
}
return nil
}
func (x *RecoverMessage) GetStateDiskKey() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_StateDiskKey); ok {
return x.StateDiskKey
}
return nil
}
func (x *RecoverMessage) GetMeasurementSecret() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_MeasurementSecret); ok {
return x.MeasurementSecret
}
return nil
}
type isRecoverMessage_Request interface {
isRecoverMessage_Request()
}
type RecoverMessage_StateDiskKey struct {
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3,oneof"`
}
type RecoverMessage_MeasurementSecret struct {
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3,oneof"`
}
func (*RecoverMessage_StateDiskKey) isRecoverMessage_Request() {}
func (*RecoverMessage_MeasurementSecret) isRecoverMessage_Request() {}
type RecoverResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
}
func (x *RecoverResponse) Reset() {
*x = RecoverResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_recover_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RecoverResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RecoverResponse) ProtoMessage() {}
func (x *RecoverResponse) ProtoReflect() protoreflect.Message {
mi := &file_recover_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RecoverResponse.ProtoReflect.Descriptor instead.
func (*RecoverResponse) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{1}
}
func (x *RecoverResponse) GetDiskUuid() string {
if x != nil {
return x.DiskUuid
}
return ""
}
var File_recover_proto protoreflect.FileDescriptor
var file_recover_proto_rawDesc = []byte{
0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a,
0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
0x26, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65,
0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75,
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20,
0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x22, 0x2e, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x75,
0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x6b, 0x55,
0x75, 0x69, 0x64, 0x32, 0x53, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4c, 0x0a, 0x07, 0x52, 0x65,
0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3f, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73,
0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x72, 0x2f, 0x72, 0x65, 0x63,
0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
file_recover_proto_rawDescOnce sync.Once
file_recover_proto_rawDescData = file_recover_proto_rawDesc
)
func file_recover_proto_rawDescGZIP() []byte {
file_recover_proto_rawDescOnce.Do(func() {
file_recover_proto_rawDescData = protoimpl.X.CompressGZIP(file_recover_proto_rawDescData)
})
return file_recover_proto_rawDescData
}
var file_recover_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_recover_proto_goTypes = []interface{}{
(*RecoverMessage)(nil), // 0: recoverproto.RecoverMessage
(*RecoverResponse)(nil), // 1: recoverproto.RecoverResponse
}
var file_recover_proto_depIdxs = []int32{
0, // 0: recoverproto.API.Recover:input_type -> recoverproto.RecoverMessage
1, // 1: recoverproto.API.Recover:output_type -> recoverproto.RecoverResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_recover_proto_init() }
func file_recover_proto_init() {
if File_recover_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_recover_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RecoverMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_recover_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RecoverResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_recover_proto_msgTypes[0].OneofWrappers = []interface{}{
(*RecoverMessage_StateDiskKey)(nil),
(*RecoverMessage_MeasurementSecret)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_recover_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_recover_proto_goTypes,
DependencyIndexes: file_recover_proto_depIdxs,
MessageInfos: file_recover_proto_msgTypes,
}.Build()
File_recover_proto = out.File
file_recover_proto_rawDesc = nil
file_recover_proto_goTypes = nil
file_recover_proto_depIdxs = nil
}

View File

@ -0,0 +1,20 @@
syntax = "proto3";
package recoverproto;
option go_package = "github.com/edgelesssys/constellation/disk-mapper/recoverproto";
service API {
rpc Recover(stream RecoverMessage) returns (stream RecoverResponse) {}
}
message RecoverMessage {
oneof request {
bytes state_disk_key = 1;
bytes measurement_secret = 2;
}
}
message RecoverResponse {
string disk_uuid = 1;
}

View File

@ -2,9 +2,9 @@
// versions: // versions:
// - protoc-gen-go-grpc v1.2.0 // - protoc-gen-go-grpc v1.2.0
// - protoc v3.20.1 // - protoc v3.20.1
// source: keyservice.proto // source: recover.proto
package keyproto package recoverproto
import ( import (
context "context" context "context"
@ -22,7 +22,7 @@ const _ = grpc.SupportPackageIsVersion7
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type APIClient interface { type APIClient interface {
PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error)
} }
type aPIClient struct { type aPIClient struct {
@ -33,20 +33,42 @@ func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
return &aPIClient{cc} return &aPIClient{cc}
} }
func (c *aPIClient) PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) { func (c *aPIClient) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) {
out := new(PushStateDiskKeyResponse) stream, err := c.cc.NewStream(ctx, &API_ServiceDesc.Streams[0], "/recoverproto.API/Recover", opts...)
err := c.cc.Invoke(ctx, "/keyproto.API/PushStateDiskKey", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return out, nil x := &aPIRecoverClient{stream}
return x, nil
}
type API_RecoverClient interface {
Send(*RecoverMessage) error
Recv() (*RecoverResponse, error)
grpc.ClientStream
}
type aPIRecoverClient struct {
grpc.ClientStream
}
func (x *aPIRecoverClient) Send(m *RecoverMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *aPIRecoverClient) Recv() (*RecoverResponse, error) {
m := new(RecoverResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
} }
// APIServer is the server API for API service. // APIServer is the server API for API service.
// All implementations must embed UnimplementedAPIServer // All implementations must embed UnimplementedAPIServer
// for forward compatibility // for forward compatibility
type APIServer interface { type APIServer interface {
PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) Recover(API_RecoverServer) error
mustEmbedUnimplementedAPIServer() mustEmbedUnimplementedAPIServer()
} }
@ -54,8 +76,8 @@ type APIServer interface {
type UnimplementedAPIServer struct { type UnimplementedAPIServer struct {
} }
func (UnimplementedAPIServer) PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) { func (UnimplementedAPIServer) Recover(API_RecoverServer) error {
return nil, status.Errorf(codes.Unimplemented, "method PushStateDiskKey not implemented") return status.Errorf(codes.Unimplemented, "method Recover not implemented")
} }
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {} func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
@ -70,36 +92,46 @@ func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
s.RegisterService(&API_ServiceDesc, srv) s.RegisterService(&API_ServiceDesc, srv)
} }
func _API_PushStateDiskKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _API_Recover_Handler(srv interface{}, stream grpc.ServerStream) error {
in := new(PushStateDiskKeyRequest) return srv.(APIServer).Recover(&aPIRecoverServer{stream})
if err := dec(in); err != nil { }
type API_RecoverServer interface {
Send(*RecoverResponse) error
Recv() (*RecoverMessage, error)
grpc.ServerStream
}
type aPIRecoverServer struct {
grpc.ServerStream
}
func (x *aPIRecoverServer) Send(m *RecoverResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
m := new(RecoverMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err return nil, err
} }
if interceptor == nil { return m, nil
return srv.(APIServer).PushStateDiskKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/keyproto.API/PushStateDiskKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(APIServer).PushStateDiskKey(ctx, req.(*PushStateDiskKeyRequest))
}
return interceptor(ctx, in, info, handler)
} }
// API_ServiceDesc is the grpc.ServiceDesc for API service. // API_ServiceDesc is the grpc.ServiceDesc for API service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
var API_ServiceDesc = grpc.ServiceDesc{ var API_ServiceDesc = grpc.ServiceDesc{
ServiceName: "keyproto.API", ServiceName: "recoverproto.API",
HandlerType: (*APIServer)(nil), HandlerType: (*APIServer)(nil),
Methods: []grpc.MethodDesc{ Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{ {
MethodName: "PushStateDiskKey", StreamName: "Recover",
Handler: _API_PushStateDiskKey_Handler, Handler: _API_Recover_Handler,
ServerStreams: true,
ClientStreams: true,
}, },
}, },
Streams: []grpc.StreamDesc{}, Metadata: "recover.proto",
Metadata: "keyservice.proto",
} }

View File

@ -24,9 +24,9 @@ RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v${GEN_GO_VER} && \
# Generate code for every existing proto file # Generate code for every existing proto file
## disk-mapper keyservice api ## disk-mapper recover api
WORKDIR /disk-mapper WORKDIR /disk-mapper
COPY state/keyproto/*.proto /disk-mapper COPY disk-mapper/recoverproto/*.proto /disk-mapper
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
## debugd service ## debugd service
@ -54,7 +54,7 @@ COPY bootstrapper/initproto/*.proto /init
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
FROM scratch as export FROM scratch as export
COPY --from=build /disk-mapper/*.go state/keyproto/ COPY --from=build /disk-mapper/*.go disk-mapper/recoverproto/
COPY --from=build /service/*.go debugd/service/ COPY --from=build /service/*.go debugd/service/
COPY --from=build /kms/*.go kms/kmsproto/ COPY --from=build /kms/*.go kms/kmsproto/
COPY --from=build /joinservice/*.go joinservice/joinproto/ COPY --from=build /joinservice/*.go joinservice/joinproto/

View File

@ -1,204 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package keyservice
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/edgelesssys/constellation/state/keyproto"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
"k8s.io/utils/clock"
)
// KeyAPI is the interface called by control-plane or an admin during restart of a node.
type KeyAPI struct {
listenAddr string
log *logger.Logger
mux sync.Mutex
metadata metadata.InstanceLister
issuer QuoteIssuer
key []byte
measurementSecret []byte
keyReceived chan struct{}
clock clock.WithTicker
timeout time.Duration
interval time.Duration
keyproto.UnimplementedAPIServer
}
// New initializes a KeyAPI with the given parameters.
func New(log *logger.Logger, issuer QuoteIssuer, metadata metadata.InstanceLister, timeout time.Duration, interval time.Duration) *KeyAPI {
return &KeyAPI{
log: log,
metadata: metadata,
issuer: issuer,
keyReceived: make(chan struct{}, 1),
clock: clock.RealClock{},
timeout: timeout,
interval: interval,
}
}
// PushStateDiskKey is the rpc to push state disk decryption keys to a restarting node.
func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
a.mux.Lock()
defer a.mux.Unlock()
if len(a.key) != 0 {
return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase")
}
if len(in.StateDiskKey) != crypto.StateDiskKeyLength {
return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", crypto.StateDiskKeyLength, len(in.StateDiskKey))
}
if len(in.MeasurementSecret) != crypto.RNGLengthDefault {
return nil, status.Errorf(codes.InvalidArgument, "received invalid measurement secret: expected length: %d, but got: %d", crypto.RNGLengthDefault, len(in.MeasurementSecret))
}
a.key = in.StateDiskKey
a.measurementSecret = in.MeasurementSecret
a.keyReceived <- struct{}{}
return &keyproto.PushStateDiskKeyResponse{}, nil
}
// WaitForDecryptionKey notifies control-plane nodes to send a decryption key and waits until a key is received.
func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) (diskKey, measurementSecret []byte, err error) {
if uuid == "" {
return nil, nil, errors.New("received no disk UUID")
}
a.listenAddr = listenAddr
creds := atlscredentials.New(a.issuer, nil)
server := grpc.NewServer(grpc.Creds(creds))
keyproto.RegisterAPIServer(server, a)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return nil, nil, err
}
defer listener.Close()
a.log.Infof("Waiting for decryption key. Listening on: %s", listener.Addr().String())
go server.Serve(listener)
defer server.GracefulStop()
a.requestKeyLoop(uuid)
return a.key, a.measurementSecret, nil
}
// ResetKey resets a previously set key.
func (a *KeyAPI) ResetKey() {
a.key = nil
}
// requestKeyLoop continuously requests decryption keys from all available control-plane nodes, until the KeyAPI receives a key.
func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) {
// we do not perform attestation, since the restarting node does not need to care about notifying the correct node
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
creds := atlscredentials.New(a.issuer, nil)
ticker := a.clock.NewTicker(a.interval)
defer ticker.Stop()
for {
endpoints, err := a.getJoinServiceEndpoints()
if err != nil {
a.log.With(zap.Error(err)).Errorf("Failed to get JoinService endpoints")
} else {
a.log.Infof("Received list with JoinService endpoints: %v", endpoints)
for _, endpoint := range endpoints {
a.requestKey(endpoint, uuid, creds, opts...)
}
}
select {
case <-a.keyReceived:
// return if a key was received
// a key can be send by
// - a control-plane node, after the request rpc was received
// - by a Constellation admin, at any time this loop is running on a node during boot
return
case <-ticker.C():
}
}
}
func (a *KeyAPI) getJoinServiceEndpoints() ([]string, error) {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
return metadata.JoinServiceEndpoints(ctx, a.metadata)
}
func (a *KeyAPI) requestKey(endpoint, uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) {
opts = append(opts, grpc.WithTransportCredentials(credentials))
a.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
rejoinTicket, err := a.requestRejoinTicket(endpoint, uuid, opts...)
if err != nil {
a.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Errorf("Failed to request rejoin ticket")
return
}
a.log.With(zap.String("endpoint", endpoint)).Infof("Pushing key to own server")
if err := a.pushKeyToOwnServer(rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, opts...); err != nil {
a.log.With(zap.Error(err), zap.String("endpoint", a.listenAddr)).Errorf("Failed to push key to own server")
return
}
}
func (a *KeyAPI) requestRejoinTicket(endpoint, uuid string, opts ...grpc.DialOption) (*joinproto.IssueRejoinTicketResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
conn, err := grpc.DialContext(ctx, endpoint, opts...)
if err != nil {
return nil, fmt.Errorf("dialing gRPC: %w", err)
}
defer conn.Close()
client := joinproto.NewAPIClient(conn)
req := &joinproto.IssueRejoinTicketRequest{DiskUuid: uuid}
return client.IssueRejoinTicket(ctx, req)
}
func (a *KeyAPI) pushKeyToOwnServer(stateDiskKey, measurementSecret []byte, opts ...grpc.DialOption) error {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
conn, err := grpc.DialContext(ctx, a.listenAddr, opts...)
if err != nil {
return fmt.Errorf("dialing gRPC: %w", err)
}
defer conn.Close()
client := keyproto.NewAPIClient(conn)
req := &keyproto.PushStateDiskKeyRequest{StateDiskKey: stateDiskKey, MeasurementSecret: measurementSecret}
_, err = client.PushStateDiskKey(ctx, req)
return err
}
// QuoteValidator validates quotes.
type QuoteValidator interface {
oid.Getter
// Validate validates a quote and returns the user data on success.
Validate(attDoc []byte, nonce []byte) ([]byte, error)
}
// QuoteIssuer issues quotes.
type QuoteIssuer interface {
oid.Getter
// Issue issues a quote for remote attestation for a given message
Issue(userData []byte, nonce []byte) (quote []byte, err error)
}

View File

@ -1,264 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package keyservice
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/edgelesssys/constellation/state/keyproto"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestRequestKeyLoop(t *testing.T) {
clockstep := struct{}{}
someErr := errors.New("failed")
defaultInstance := metadata.InstanceMetadata{
Name: "test-instance",
ProviderID: "/test/provider",
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
}
testCases := map[string]struct {
answers []any
}{
"success": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
"recover metadata list error": {
answers: []any{
listAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
"recover issue rejoin ticket error": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
"recover push key error": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
metadataServer := newStubMetadataServer()
joinServer := newStubJoinAPIServer()
keyServer := newStubKeyAPIServer()
listener := bufconn.Listen(1024)
defer listener.Close()
creds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil)
grpcServer := grpc.NewServer(grpc.Creds(creds))
joinproto.RegisterAPIServer(grpcServer, joinServer)
keyproto.RegisterAPIServer(grpcServer, keyServer)
go grpcServer.Serve(listener)
defer grpcServer.GracefulStop()
clock := testclock.NewFakeClock(time.Now())
keyReceived := make(chan struct{}, 1)
keyWaiter := &KeyAPI{
listenAddr: "192.0.2.1:30090",
log: logger.NewTest(t),
metadata: metadataServer,
keyReceived: keyReceived,
clock: clock,
timeout: 1 * time.Second,
interval: 1 * time.Second,
}
grpcOpts := []grpc.DialOption{
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.DialContext(ctx)
}),
}
// Start the request loop under tests.
done := make(chan struct{})
go func() {
defer close(done)
keyWaiter.requestKeyLoop("1234", grpcOpts...)
}()
// Play test case answers.
for _, answ := range tc.answers {
switch answ := answ.(type) {
case listAnswer:
metadataServer.listAnswerC <- answ
case issueRejoinTicketAnswer:
joinServer.issueRejoinTicketAnswerC <- answ
case pushStateDiskKeyAnswer:
keyServer.pushStateDiskKeyAnswerC <- answ
default:
clock.Step(time.Second)
}
}
// Stop the request loop.
keyReceived <- struct{}{}
})
}
}
func TestPushStateDiskKey(t *testing.T) {
testCases := map[string]struct {
testAPI *KeyAPI
request *keyproto.PushStateDiskKeyRequest
wantErr bool
}{
"success": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
},
"key already set": {
testAPI: &KeyAPI{
keyReceived: make(chan struct{}, 1),
key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"), MeasurementSecret: []byte("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")},
wantErr: true,
},
"incorrect size of pushed key": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
wantErr: true,
},
"incorrect size of measurement secret": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBB")},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.testAPI.log = logger.NewTest(t)
_, err := tc.testAPI.PushStateDiskKey(context.Background(), tc.request)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.request.StateDiskKey, tc.testAPI.key)
}
})
}
}
func TestResetKey(t *testing.T) {
api := New(logger.NewTest(t), nil, nil, time.Second, time.Millisecond)
api.key = []byte{0x1, 0x2, 0x3}
api.ResetKey()
assert.Nil(t, api.key)
}
type stubMetadataServer struct {
listAnswerC chan listAnswer
}
func newStubMetadataServer() *stubMetadataServer {
return &stubMetadataServer{
listAnswerC: make(chan listAnswer),
}
}
func (s *stubMetadataServer) List(context.Context) ([]metadata.InstanceMetadata, error) {
answer := <-s.listAnswerC
return answer.listResponse, answer.err
}
type listAnswer struct {
listResponse []metadata.InstanceMetadata
err error
}
type stubJoinAPIServer struct {
issueRejoinTicketAnswerC chan issueRejoinTicketAnswer
joinproto.UnimplementedAPIServer
}
func newStubJoinAPIServer() *stubJoinAPIServer {
return &stubJoinAPIServer{
issueRejoinTicketAnswerC: make(chan issueRejoinTicketAnswer),
}
}
func (s *stubJoinAPIServer) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest) (*joinproto.IssueRejoinTicketResponse, error) {
answer := <-s.issueRejoinTicketAnswerC
resp := &joinproto.IssueRejoinTicketResponse{
StateDiskKey: answer.stateDiskKey,
MeasurementSecret: answer.measurementSecret,
}
return resp, answer.err
}
type issueRejoinTicketAnswer struct {
stateDiskKey []byte
measurementSecret []byte
err error
}
type stubKeyAPIServer struct {
pushStateDiskKeyAnswerC chan pushStateDiskKeyAnswer
keyproto.UnimplementedAPIServer
}
func newStubKeyAPIServer() *stubKeyAPIServer {
return &stubKeyAPIServer{
pushStateDiskKeyAnswerC: make(chan pushStateDiskKeyAnswer),
}
}
func (s *stubKeyAPIServer) PushStateDiskKey(context.Context, *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
answer := <-s.pushStateDiskKeyAnswerC
return &keyproto.PushStateDiskKeyResponse{}, answer.err
}
type pushStateDiskKeyAnswer struct {
err error
}

View File

@ -1,219 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc v3.20.1
// source: keyservice.proto
package keyproto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type PushStateDiskKeyRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"`
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3" json:"measurement_secret,omitempty"`
}
func (x *PushStateDiskKeyRequest) Reset() {
*x = PushStateDiskKeyRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_keyservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PushStateDiskKeyRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PushStateDiskKeyRequest) ProtoMessage() {}
func (x *PushStateDiskKeyRequest) ProtoReflect() protoreflect.Message {
mi := &file_keyservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PushStateDiskKeyRequest.ProtoReflect.Descriptor instead.
func (*PushStateDiskKeyRequest) Descriptor() ([]byte, []int) {
return file_keyservice_proto_rawDescGZIP(), []int{0}
}
func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte {
if x != nil {
return x.StateDiskKey
}
return nil
}
func (x *PushStateDiskKeyRequest) GetMeasurementSecret() []byte {
if x != nil {
return x.MeasurementSecret
}
return nil
}
type PushStateDiskKeyResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *PushStateDiskKeyResponse) Reset() {
*x = PushStateDiskKeyResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_keyservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PushStateDiskKeyResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PushStateDiskKeyResponse) ProtoMessage() {}
func (x *PushStateDiskKeyResponse) ProtoReflect() protoreflect.Message {
mi := &file_keyservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PushStateDiskKeyResponse.ProtoReflect.Descriptor instead.
func (*PushStateDiskKeyResponse) Descriptor() ([]byte, []int) {
return file_keyservice_proto_rawDescGZIP(), []int{1}
}
var File_keyservice_proto protoreflect.FileDescriptor
var file_keyservice_proto_rawDesc = []byte{
0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x6e, 0x0a, 0x17,
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65,
0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2d, 0x0a,
0x12, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63,
0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75,
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x18,
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12,
0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b,
0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x50,
0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b,
0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69,
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73,
0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
}
var (
file_keyservice_proto_rawDescOnce sync.Once
file_keyservice_proto_rawDescData = file_keyservice_proto_rawDesc
)
func file_keyservice_proto_rawDescGZIP() []byte {
file_keyservice_proto_rawDescOnce.Do(func() {
file_keyservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_keyservice_proto_rawDescData)
})
return file_keyservice_proto_rawDescData
}
var file_keyservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_keyservice_proto_goTypes = []interface{}{
(*PushStateDiskKeyRequest)(nil), // 0: keyproto.PushStateDiskKeyRequest
(*PushStateDiskKeyResponse)(nil), // 1: keyproto.PushStateDiskKeyResponse
}
var file_keyservice_proto_depIdxs = []int32{
0, // 0: keyproto.API.PushStateDiskKey:input_type -> keyproto.PushStateDiskKeyRequest
1, // 1: keyproto.API.PushStateDiskKey:output_type -> keyproto.PushStateDiskKeyResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_keyservice_proto_init() }
func file_keyservice_proto_init() {
if File_keyservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_keyservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PushStateDiskKeyRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_keyservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PushStateDiskKeyResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_keyservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_keyservice_proto_goTypes,
DependencyIndexes: file_keyservice_proto_depIdxs,
MessageInfos: file_keyservice_proto_msgTypes,
}.Build()
File_keyservice_proto = out.File
file_keyservice_proto_rawDesc = nil
file_keyservice_proto_goTypes = nil
file_keyservice_proto_depIdxs = nil
}

View File

@ -1,17 +0,0 @@
syntax = "proto3";
package keyproto;
option go_package = "github.com/edgelesssys/constellation/state/keyservice/keyproto";
service API {
rpc PushStateDiskKey(PushStateDiskKeyRequest) returns (PushStateDiskKeyResponse);
}
message PushStateDiskKeyRequest {
bytes state_disk_key = 1;
bytes measurement_secret = 2;
}
message PushStateDiskKeyResponse {
}