AB#2260 Refactor disk-mapper recovery (#82)

* Refactor disk-mapper recovery

* Adapt constellation recover command to use new disk-mapper recovery API

* Fix Cilium connectivity on rebooting nodes (#89)

* Lower CoreDNS reschedule timeout to 10 seconds (#93)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
Daniel Weiße 2022-09-08 14:45:27 +02:00 committed by GitHub
parent a7b20b2a11
commit 8cb155d5c5
40 changed files with 1600 additions and 1130 deletions


@ -106,6 +106,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Owner ID and Unique ID are merged into a single value: Cluster ID.
- Streamline logging to only use one logging library, instead of multiple.
- Replace dependency on github.com/willdonnelly/passwd with own implementation.
- Refactor disk-mapper to allow a more streamlined node recovery
### Removed


@ -71,5 +71,5 @@ add_test(NAME unit-hack COMMAND go test -race -count=3 ./... WORKING_DIRECTORY $
add_test(NAME unit-node-operator COMMAND go test -race -count=3 ./... WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator)
add_test(NAME integration-node-operator COMMAND make test WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator)
add_test(NAME integration-csi COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csi)
add_test(NAME integration-dm COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/state/internal)
add_test(NAME integration-dm COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/disk-mapper/internal)
add_test(NAME integration-license COMMAND bash -c "go test -tags integration" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/internal/license)


@ -61,7 +61,7 @@ Core components:
* [image](image): Build files for the Constellation disk image
* [kms](kms): Constellation's key management client and server
* [csi](csi): Package used by CSI plugins to create and mount encrypted block devices
* [state](state): Contains the disk-mapper that maps the encrypted node data disk during boot
* [disk-mapper](disk-mapper): Contains the disk-mapper that maps the encrypted node data disk during boot
Development components:


@ -36,7 +36,7 @@ ARG PROJECT_VERSION
RUN go build -o bootstrapper -tags=gcp,disable_tpm_simulator -buildvcs=false -ldflags "-s -w -buildid='' -X main.version=${PROJECT_VERSION}" ./cmd/bootstrapper/
FROM build AS build-disk-mapper
WORKDIR /constellation/state/
WORKDIR /constellation/disk-mapper/
RUN go build -o disk-mapper -ldflags "-s -w" ./cmd/
@ -44,4 +44,4 @@ FROM scratch AS bootstrapper
COPY --from=build-bootstrapper /constellation/bootstrapper/bootstrapper /
FROM scratch AS disk-mapper
COPY --from=build-disk-mapper /constellation/state/disk-mapper /
COPY --from=build-disk-mapper /constellation/disk-mapper/disk-mapper /


@ -50,7 +50,7 @@ func run(issuerWrapper initserver.IssuerWrapper, tpm vtpm.TPMOpenFunc, fileHandl
}
if nodeBootstrapped {
if err := kube.StartKubelet(); err != nil {
if err := kube.StartKubelet(log); err != nil {
log.With(zap.Error(err)).Fatalf("Failed to restart kubelet")
}
return
@ -88,7 +88,7 @@ func getDiskUUID() (string, error) {
type clusterInitJoiner interface {
joinclient.ClusterJoiner
initserver.ClusterInitializer
StartKubelet() error
StartKubelet(*logger.Logger) error
}
type metadataAPI interface {


@ -33,7 +33,7 @@ func (c *clusterFake) JoinCluster(context.Context, *kubeadm.BootstrapTokenDiscov
}
// StartKubelet starts the kubelet service.
func (c *clusterFake) StartKubelet() error {
func (c *clusterFake) StartKubelet(*logger.Logger) error {
return nil
}


@ -13,30 +13,6 @@ import (
"github.com/coreos/go-systemd/v22/dbus"
)
func restartSystemdUnit(ctx context.Context, unit string) error {
conn, err := dbus.NewSystemdConnectionContext(ctx)
if err != nil {
return fmt.Errorf("establishing systemd connection: %w", err)
}
restartChan := make(chan string)
if _, err := conn.RestartUnitContext(ctx, unit, "replace", restartChan); err != nil {
return fmt.Errorf("restarting systemd unit %q: %w", unit, err)
}
// Wait for the restart to finish and actually check if it was
// successful or not.
result := <-restartChan
switch result {
case "done":
return nil
default:
return fmt.Errorf("restarting systemd unit %q failed: expected %v but received %v", unit, "done", result)
}
}
func startSystemdUnit(ctx context.Context, unit string) error {
conn, err := dbus.NewSystemdConnectionContext(ctx)
if err != nil {


@ -264,12 +264,19 @@ func (k *KubernetesUtil) deployCiliumGCP(ctx context.Context, helmClient *action
return err
}
timeoutS := int64(10)
// allow coredns to run on uninitialized nodes (required by cloud-controller-manager)
tolerations := []corev1.Toleration{
{
Key: "node.cloudprovider.kubernetes.io/uninitialized",
Value: "true",
Effect: "NoSchedule",
Effect: corev1.TaintEffectNoSchedule,
},
{
Key: "node.kubernetes.io/unreachable",
Operator: corev1.TolerationOpExists,
Effect: corev1.TaintEffectNoExecute,
TolerationSeconds: &timeoutS,
},
}
if err = kubectl.AddTolerationsToDeployment(ctx, tolerations, "coredns", "kube-system"); err != nil {
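The tolerations above give CoreDNS a NoExecute toleration with a 10 second tolerationSeconds, so its pods are evicted and rescheduled quickly when a node becomes unreachable during a reboot. The sketch below illustrates one way such tolerations could be appended to the kube-system/coredns Deployment with client-go; Constellation's kubectl.AddTolerationsToDeployment helper is not part of this diff, so treat the mechanism shown here as an assumption, not its actual implementation.

```go
package k8sexample

import (
	"context"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)

// addTolerations appends the given tolerations to the coredns Deployment.
func addTolerations(ctx context.Context, client kubernetes.Interface, tolerations []corev1.Toleration) error {
	deployments := client.AppsV1().Deployments("kube-system")

	// Read the current Deployment, extend its pod template tolerations, and write it back.
	coredns, err := deployments.Get(ctx, "coredns", metav1.GetOptions{})
	if err != nil {
		return err
	}
	coredns.Spec.Template.Spec.Tolerations = append(coredns.Spec.Template.Spec.Tolerations, tolerations...)
	_, err = deployments.Update(ctx, coredns, metav1.UpdateOptions{})
	return err
}
```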
@ -305,7 +312,7 @@ func (k *KubernetesUtil) deployCiliumGCP(ctx context.Context, helmClient *action
// FixCilium fixes https://github.com/cilium/cilium/issues/19958 but instead of a rollout restart of
// the cilium daemonset, it only restarts the local cilium pod.
func (k *KubernetesUtil) FixCilium(nodeNameK8s string, log *logger.Logger) {
func (k *KubernetesUtil) FixCilium(log *logger.Logger) {
// wait for cilium pod to be healthy
client := http.Client{}
for {
@ -487,13 +494,6 @@ func (k *KubernetesUtil) StartKubelet() error {
return startSystemdUnit(ctx, "kubelet.service")
}
// RestartKubelet restarts a kubelet.
func (k *KubernetesUtil) RestartKubelet() error {
ctx, cancel := context.WithTimeout(context.TODO(), kubeletStartTimeout)
defer cancel()
return restartSystemdUnit(ctx, "kubelet.service")
}
// createSignedKubeletCert manually creates a Kubernetes CA signed kubelet certificate for the bootstrapper node.
// This is necessary because this node does not request a certificate from the join service.
func (k *KubernetesUtil) createSignedKubeletCert(nodeName string, ips []net.IP) error {


@ -33,6 +33,5 @@ type clusterUtil interface {
SetupNodeMaintenanceOperator(kubectl k8sapi.Client, nodeMaintenanceOperatorConfiguration kubernetes.Marshaler) error
SetupNodeOperator(ctx context.Context, kubectl k8sapi.Client, nodeOperatorConfiguration kubernetes.Marshaler) error
StartKubelet() error
RestartKubelet() error
FixCilium(nodeNameK8s string, log *logger.Logger)
FixCilium(log *logger.Logger)
}


@ -229,7 +229,7 @@ func (k *KubeWrapper) InitCluster(
return nil, fmt.Errorf("failed to setup k8s version ConfigMap: %w", err)
}
k.clusterUtil.FixCilium(nodeName, log)
k.clusterUtil.FixCilium(log)
return k.GetKubeconfig()
}
@ -309,7 +309,7 @@ func (k *KubeWrapper) JoinCluster(ctx context.Context, args *kubeadm.BootstrapTo
return fmt.Errorf("joining cluster: %v; %w ", string(joinConfigYAML), err)
}
k.clusterUtil.FixCilium(nodeName, log)
k.clusterUtil.FixCilium(log)
return nil
}
@ -481,8 +481,13 @@ func k8sCompliantHostname(in string) string {
}
// StartKubelet starts the kubelet service.
func (k *KubeWrapper) StartKubelet() error {
return k.clusterUtil.StartKubelet()
func (k *KubeWrapper) StartKubelet(log *logger.Logger) error {
if err := k.clusterUtil.StartKubelet(); err != nil {
return fmt.Errorf("starting kubelet: %w", err)
}
k.clusterUtil.FixCilium(log)
return nil
}
// getIPAddr retrieves the default sender IP used for outgoing connections.


@ -531,7 +531,6 @@ type stubClusterUtil struct {
setupNodeOperatorErr error
joinClusterErr error
startKubeletErr error
restartKubeletErr error
initConfigs [][]byte
joinConfigs [][]byte
@ -603,11 +602,7 @@ func (s *stubClusterUtil) StartKubelet() error {
return s.startKubeletErr
}
func (s *stubClusterUtil) RestartKubelet() error {
return s.restartKubeletErr
}
func (s *stubClusterUtil) FixCilium(nodeName string, log *logger.Logger) {
func (s *stubClusterUtil) FixCilium(log *logger.Logger) {
}
type stubConfigProvider struct {


@ -7,24 +7,26 @@ SPDX-License-Identifier: AGPL-3.0-only
package cmd
import (
"errors"
"context"
"fmt"
"regexp"
"strings"
"io"
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/internal/attestation"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
var diskUUIDRegexp = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
type recoveryClient interface {
Connect(endpoint string, validators atls.Validator) error
Recover(ctx context.Context, masterSecret, salt []byte) error
io.Closer
}
// NewRecoverCmd returns a new cobra.Command for the recover command.
func NewRecoverCmd() *cobra.Command {
@ -38,15 +40,13 @@ func NewRecoverCmd() *cobra.Command {
}
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
must(cmd.MarkFlagRequired("endpoint"))
cmd.Flags().String("disk-uuid", "", "disk UUID of the encrypted state disk (required)")
must(cmd.MarkFlagRequired("disk-uuid"))
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file")
return cmd
}
func runRecover(cmd *cobra.Command, _ []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
recoveryClient := &proto.KeyClient{}
recoveryClient := &proto.RecoverClient{}
defer recoveryClient.Close()
return recover(cmd, fileHandler, recoveryClient)
}
@ -82,17 +82,7 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recove
return err
}
diskKey, err := deriveStateDiskKey(masterSecret.Key, masterSecret.Salt, flags.diskUUID)
if err != nil {
return err
}
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
if err != nil {
return err
}
if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey, measurementSecret); err != nil {
if err := recoveryClient.Recover(cmd.Context(), masterSecret.Key, masterSecret.Salt); err != nil {
return err
}
@ -110,15 +100,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
}
diskUUID, err := cmd.Flags().GetString("disk-uuid")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing disk-uuid argument: %w", err)
}
if match := diskUUIDRegexp.MatchString(diskUUID); !match {
return recoverFlags{}, errors.New("flag '--disk-uuid' isn't a valid LUKS UUID")
}
diskUUID = strings.ToLower(diskUUID)
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing master-secret path argument: %w", err)
@ -131,7 +112,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
return recoverFlags{
endpoint: endpoint,
diskUUID: diskUUID,
secretPath: masterSecretPath,
configPath: configPath,
}, nil
@ -139,12 +119,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
type recoverFlags struct {
endpoint string
diskUUID string
secretPath string
configPath string
}
// deriveStateDiskKey derives a state disk key from a master key, a salt, and a disk UUID.
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength)
}


@ -8,10 +8,11 @@ package cmd
import (
"bytes"
"context"
"errors"
"strings"
"testing"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
@ -57,7 +58,6 @@ func TestRecover(t *testing.T) {
client *stubRecoveryClient
masterSecret testvector.HKDF
endpointFlag string
diskUUIDFlag string
masterSecretFlag string
configFlag string
stateless bool
@ -67,29 +67,13 @@ func TestRecover(t *testing.T) {
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero,
},
"uppercase disk uuid works": {
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: strings.ToUpper(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF,
},
"lowercase disk uuid results in same key": {
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: strings.ToLower(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF,
},
"missing flags": {
wantErr: true,
},
"missing config": {
endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero,
configFlag: "nonexistent-config",
wantErr: true,
@ -97,7 +81,6 @@ func TestRecover(t *testing.T) {
"missing state": {
existingState: validState,
endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero,
stateless: true,
wantErr: true,
@ -105,7 +88,6 @@ func TestRecover(t *testing.T) {
"invalid cloud provider": {
existingState: invalidCSPState,
endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero,
wantErr: true,
},
@ -113,7 +95,6 @@ func TestRecover(t *testing.T) {
existingState: validState,
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero,
wantErr: true,
},
@ -121,7 +102,6 @@ func TestRecover(t *testing.T) {
existingState: validState,
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
endpointFlag: "192.0.2.1",
diskUUIDFlag: testvector.HKDFZero.Info,
masterSecret: testvector.HKDFZero,
wantErr: true,
},
@ -140,9 +120,6 @@ func TestRecover(t *testing.T) {
if tc.endpointFlag != "" {
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
}
if tc.diskUUIDFlag != "" {
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
}
if tc.masterSecretFlag != "" {
require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag))
}
@ -170,7 +147,6 @@ func TestRecover(t *testing.T) {
assert.NoError(err)
assert.Contains(out.String(), "Pushed recovery key.")
assert.Equal(tc.masterSecret.Output, tc.client.pushStateDiskKeyKey)
})
}
}
@ -185,38 +161,24 @@ func TestParseRecoverFlags(t *testing.T) {
wantErr: true,
},
"invalid ip": {
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantErr: true,
},
"invalid disk uuid": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
args: []string{"-e", "192.0.2.1:2:2"},
wantErr: true,
},
"minimal args set": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
args: []string{"-e", "192.0.2.1:2"},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
secretPath: "constellation-mastersecret.json",
},
},
"all args set": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--config", "config-path", "--master-secret", "/path/super-secret.json"},
args: []string{"-e", "192.0.2.1:2", "--config", "config-path", "--master-secret", "/path/super-secret.json"},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
secretPath: "/path/super-secret.json",
configPath: "config-path",
},
},
"uppercase disk-uuid is converted to lowercase": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
secretPath: "constellation-mastersecret.json",
},
},
}
for name, tc := range testCases {
@ -239,26 +201,26 @@ func TestParseRecoverFlags(t *testing.T) {
}
}
func TestDeriveStateDiskKey(t *testing.T) {
testCases := map[string]struct {
masterSecret testvector.HKDF
}{
"all zero": {
masterSecret: testvector.HKDFZero,
},
"all 0xff": {
masterSecret: testvector.HKDF0xFF,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
stateDiskKey, err := deriveStateDiskKey(tc.masterSecret.Secret, tc.masterSecret.Salt, tc.masterSecret.Info)
assert.NoError(err)
assert.Equal(tc.masterSecret.Output, stateDiskKey)
})
}
}
type stubRecoveryClient struct {
conn bool
connectErr error
closeErr error
pushStateDiskKeyErr error
pushStateDiskKeyKey []byte
}
func (c *stubRecoveryClient) Connect(string, atls.Validator) error {
c.conn = true
return c.connectErr
}
func (c *stubRecoveryClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubRecoveryClient) Recover(_ context.Context, stateDiskKey, _ []byte) error {
c.pushStateDiskKeyKey = stateDiskKey
return c.pushStateDiskKeyErr
}


@ -1,20 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"io"
"github.com/edgelesssys/constellation/internal/atls"
)
type recoveryClient interface {
Connect(endpoint string, validators atls.Validator) error
PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error
io.Closer
}


@ -1,37 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"github.com/edgelesssys/constellation/internal/atls"
)
type stubRecoveryClient struct {
conn bool
connectErr error
closeErr error
pushStateDiskKeyErr error
pushStateDiskKeyKey []byte
}
func (c *stubRecoveryClient) Connect(_ string, _ atls.Validator) error {
c.conn = true
return c.connectErr
}
func (c *stubRecoveryClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey, _ []byte) error {
c.pushStateDiskKeyKey = stateDiskKey
return c.pushStateDiskKeyErr
}


@ -9,17 +9,21 @@ package proto
import (
"context"
"errors"
"io"
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/state/keyproto"
"go.uber.org/multierr"
"google.golang.org/grpc"
)
// KeyClient wraps a KeyAPI client and the connection to it.
type KeyClient struct {
conn *grpc.ClientConn
keyapi keyproto.APIClient
// RecoverClient wraps a recoverAPI client and the connection to it.
type RecoverClient struct {
conn *grpc.ClientConn
recoverapi recoverproto.APIClient
}
// Connect connects the client to a given server, using the handed
@ -27,7 +31,7 @@ type KeyClient struct {
// The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old
// connection is closed.
func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
func (c *RecoverClient) Connect(endpoint string, validators atls.Validator) error {
creds := atlscredentials.New(nil, []atls.Validator{validators})
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
@ -36,14 +40,14 @@ func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
}
_ = c.Close()
c.conn = conn
c.keyapi = keyproto.NewAPIClient(conn)
c.recoverapi = recoverproto.NewAPIClient(conn)
return nil
}
// Close closes the grpc connection of the client.
// Close is idempotent and can be called on non-connected clients
// without returning an error.
func (c *KeyClient) Close() error {
func (c *RecoverClient) Close() error {
if c.conn == nil {
return nil
}
@ -55,17 +59,58 @@ func (c *KeyClient) Close() error {
}
// PushStateDiskKey pushes the state disk key to a constellation instance in recovery mode.
// The state disk key must be derived from the UUID of the state disk and the master key.
func (c *KeyClient) PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error {
if c.keyapi == nil {
func (c *RecoverClient) Recover(ctx context.Context, masterSecret, salt []byte) (retErr error) {
if c.recoverapi == nil {
return errors.New("client is not connected")
}
req := &keyproto.PushStateDiskKeyRequest{
StateDiskKey: stateDiskKey,
MeasurementSecret: measurementSecret,
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret, salt)
if err != nil {
return err
}
_, err := c.keyapi.PushStateDiskKey(ctx, req)
return err
recoverclient, err := c.recoverapi.Recover(ctx)
if err != nil {
return err
}
defer func() {
if err := recoverclient.CloseSend(); err != nil {
multierr.AppendInto(&retErr, err)
}
}()
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: measurementSecret,
},
}); err != nil {
return err
}
res, err := recoverclient.Recv()
if err != nil {
return err
}
stateDiskKey, err := deriveStateDiskKey(masterSecret, salt, res.DiskUuid)
if err != nil {
return err
}
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: stateDiskKey,
},
}); err != nil {
return err
}
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
return err
}
return nil
}
// deriveStateDiskKey derives a state disk key from a master key, a salt, and a disk UUID.
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength)
}
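For orientation, a minimal usage sketch of the reworked client API above: the caller connects over aTLS and hands the master secret and salt to Recover, which drives the whole stream (send measurement secret, receive disk UUID, derive and send the state disk key). The endpoint and validator wiring are assumed to exist elsewhere, as they do in the CLI's recover command.

```go
package recoverexample

import (
	"context"

	"github.com/edgelesssys/constellation/cli/internal/proto"
	"github.com/edgelesssys/constellation/internal/atls"
)

// recoverNode drives the new RecoverClient; deriving the state disk key from
// the disk UUID now happens inside Recover instead of in the CLI.
func recoverNode(ctx context.Context, endpoint string, validator atls.Validator, masterKey, salt []byte) error {
	client := &proto.RecoverClient{}
	defer client.Close()

	// Establish the aTLS-protected gRPC connection to the recovering node.
	if err := client.Connect(endpoint, validator); err != nil {
		return err
	}

	// Recover sends the measurement secret, receives the disk UUID,
	// derives the state disk key, and pushes it back to the node.
	return client.Recover(ctx, masterKey, salt)
}
```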


@ -0,0 +1,38 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package proto
import (
"testing"
"github.com/edgelesssys/constellation/internal/crypto/testvector"
"github.com/stretchr/testify/assert"
)
func TestDeriveStateDiskKey(t *testing.T) {
testCases := map[string]struct {
masterSecret testvector.HKDF
}{
"all zero": {
masterSecret: testvector.HKDFZero,
},
"all 0xff": {
masterSecret: testvector.HKDF0xFF,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
stateDiskKey, err := deriveStateDiskKey(tc.masterSecret.Secret, tc.masterSecret.Salt, tc.masterSecret.Info)
assert.NoError(err)
assert.Equal(tc.masterSecret.Output, stateDiskKey)
})
}
}
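deriveStateDiskKey (via crypto.DeriveKey) is an HKDF expansion of the master secret, salted and bound to the disk UUID through the info parameter, which is what the test vectors above exercise. The sketch below reproduces that shape directly with golang.org/x/crypto/hkdf; the hash function and the exact info encoding inside crypto.DeriveKey are assumptions for illustration.

```go
package hkdfexample

import (
	"crypto/sha256"
	"io"

	"golang.org/x/crypto/hkdf"
)

// deriveKeySketch expands masterKey into a key of the requested length,
// bound to the given info string (e.g. a prefix plus the disk UUID).
func deriveKeySketch(masterKey, salt []byte, info string, length int) ([]byte, error) {
	key := make([]byte, length)
	kdf := hkdf.New(sha256.New, masterKey, salt, []byte(info))
	if _, err := io.ReadFull(kdf, key); err != nil {
		return nil, err
	}
	return key, nil
}
```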


@ -4,11 +4,11 @@ Files and source code for mounting persistent state disks
## Testing
Integration test is available in `state/test/integration_test.go`.
Integration test is available in `disk-mapper/test/integration_test.go`.
The integration test requires root privileges since it uses dm-crypt.
Build and run the test:
```bash
go test -c -tags=integration ./state/test/
go test -c -tags=integration ./disk-mapper/test/
sudo ./test.test
```
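The test only compiles when the integration build tag is set, which is why the command above passes -tags=integration. A minimal sketch of the build-constraint header such a test file carries follows; the test name and body are placeholders, not the actual contents of disk-mapper/test/integration_test.go.

```go
//go:build integration

package integration

import "testing"

// TestDiskMapper is a placeholder; the real integration test exercises
// dm-crypt and therefore needs to run as root.
func TestDiskMapper(t *testing.T) {
	t.Log("integration test body omitted")
}
```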


@ -11,12 +11,17 @@ import (
"context"
"encoding/json"
"flag"
"net"
"net/http"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
"github.com/edgelesssys/constellation/disk-mapper/internal/recoveryserver"
"github.com/edgelesssys/constellation/disk-mapper/internal/rejoinclient"
"github.com/edgelesssys/constellation/disk-mapper/internal/setup"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/attestation/azure"
"github.com/edgelesssys/constellation/internal/attestation/gcp"
"github.com/edgelesssys/constellation/internal/attestation/qemu"
@ -26,10 +31,8 @@ import (
"github.com/edgelesssys/constellation/internal/cloud/metadata"
qemucloud "github.com/edgelesssys/constellation/internal/cloud/qemu"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/state/internal/keyservice"
"github.com/edgelesssys/constellation/state/internal/mapper"
"github.com/edgelesssys/constellation/state/internal/setup"
tpmClient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/tpm2"
"github.com/spf13/afero"
@ -54,8 +57,8 @@ func main() {
// set up metadata API and quote issuer for aTLS connections
var err error
var diskPath string
var issuer keyservice.QuoteIssuer
var metadata metadata.InstanceLister
var issuer atls.Issuer
var metadataAPI setup.MetadataAPI
switch strings.ToLower(*csp) {
case "azure":
diskPath, err = filepath.EvalSymlinks(azureStateDiskPath)
@ -63,7 +66,7 @@ func main() {
_ = exportPCRs()
log.With(zap.Error(err)).Fatalf("Unable to resolve Azure state disk path")
}
metadata, err = azurecloud.NewMetadata(context.Background())
metadataAPI, err = azurecloud.NewMetadata(context.Background())
if err != nil {
log.With(zap.Error).Fatalf("Failed to create Azure metadata API")
}
@ -81,12 +84,13 @@ func main() {
if err != nil {
log.With(zap.Error).Fatalf("Failed to create GCP client")
}
metadata = gcpcloud.New(gcpClient)
metadataAPI = gcpcloud.New(gcpClient)
case "qemu":
diskPath = qemuStateDiskPath
issuer = qemu.NewIssuer()
metadata = &qemucloud.Metadata{}
metadataAPI = &qemucloud.Metadata{}
_ = exportPCRs()
default:
log.Fatalf("CSP %s is not supported by Constellation", *csp)
@ -104,7 +108,6 @@ func main() {
*csp,
diskPath,
afero.Afero{Fs: afero.NewOsFs()},
keyservice.New(log.Named("keyService"), issuer, metadata, 20*time.Second, 20*time.Second), // try to request a key every 20 seconds
mapper,
setup.DiskMounter{},
vtpm.OpenVTPM,
@ -112,7 +115,23 @@ func main() {
// prepare the state disk
if mapper.IsLUKSDevice() {
err = setupManger.PrepareExistingDisk()
// set up rejoin client
var self metadata.InstanceMetadata
self, err = metadataAPI.Self(context.Background())
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get self metadata")
}
rejoinClient := rejoinclient.New(
dialer.New(issuer, nil, &net.Dialer{}),
self,
metadataAPI,
log.Named("rejoinClient"),
)
// set up recovery server
recoveryServer := recoveryserver.New(issuer, log.Named("recoveryServer"))
err = setupManger.PrepareExistingDisk(setup.NewNodeRecoverer(recoveryServer, rejoinClient))
} else {
err = setupManger.PrepareNewDisk()
}


@ -0,0 +1,134 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package recoveryserver
import (
"context"
"net"
"sync"
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
type RecoveryServer struct {
mux sync.Mutex
diskUUID string
stateDiskKey []byte
measurementSecret []byte
grpcServer server
log *logger.Logger
recoverproto.UnimplementedAPIServer
}
// New returns a new RecoveryServer.
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
server := &RecoveryServer{
log: log,
}
grpcServer := grpc.NewServer(
grpc.Creds(atlscredentials.New(issuer, nil)),
log.Named("gRPC").GetServerStreamInterceptor(),
)
recoverproto.RegisterAPIServer(grpcServer, server)
server.grpcServer = grpcServer
return server
}
// Serve starts the recovery server.
// It blocks until a recover call is successful.
// The server shuts down once the call has succeeded and the keys have been returned.
// Additionally, the server can be shut down by canceling the context.
func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskUUID string) (diskKey, measurementSecret []byte, err error) {
s.log.Infof("Starting RecoveryServer")
s.diskUUID = diskUUID
recoveryDone := make(chan struct{}, 1)
var serveErr error
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
serveErr = s.grpcServer.Serve(listener)
recoveryDone <- struct{}{}
}()
for {
select {
case <-ctx.Done():
s.log.Infof("Context canceled, shutting down server")
s.grpcServer.GracefulStop()
return nil, nil, ctx.Err()
case <-recoveryDone:
if serveErr != nil {
return nil, nil, serveErr
}
return s.stateDiskKey, s.measurementSecret, nil
}
}
}
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
s.mux.Lock()
defer s.mux.Unlock()
s.log.Infof("Received recover call")
msg, err := stream.Recv()
if err != nil {
return status.Error(codes.Internal, "failed to receive message")
}
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
if !ok {
s.log.Errorf("Received invalid first message: not a measurement secret")
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
}
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
s.log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
return status.Error(codes.Internal, "failed to send response")
}
msg, err = stream.Recv()
if err != nil {
s.log.With(zap.Error(err)).Errorf("Failed to receive disk key")
return status.Error(codes.Internal, "failed to receive message")
}
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
if !ok {
s.log.Errorf("Received invalid second message: not a state disk key")
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
}
s.stateDiskKey = stateDiskKey.StateDiskKey
s.measurementSecret = measurementSecret.MeasurementSecret
s.log.Infof("Received state disk key and measurement secret, shutting down server")
go s.grpcServer.GracefulStop()
return nil
}
type server interface {
Serve(net.Listener) error
GracefulStop()
}
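A minimal usage sketch for the server above: bind a TCP listener on the recovery endpoint and call Serve, which blocks until an admin completes the Recover stream or the context is canceled. In the disk-mapper this wiring happens through setup.NewNodeRecoverer together with the rejoin client; the listener address below is only an example.

```go
package recoveryexample

import (
	"context"
	"net"

	"github.com/edgelesssys/constellation/disk-mapper/internal/recoveryserver"
	"github.com/edgelesssys/constellation/internal/atls"
	"github.com/edgelesssys/constellation/internal/logger"
)

// serveRecovery blocks until the recovery keys arrive or ctx is canceled.
func serveRecovery(ctx context.Context, issuer atls.Issuer, log *logger.Logger, diskUUID string) (diskKey, measurementSecret []byte, err error) {
	lis, err := net.Listen("tcp", "0.0.0.0:9000") // example port only
	if err != nil {
		return nil, nil, err
	}
	defer lis.Close()

	server := recoveryserver.New(issuer, log)
	return server.Serve(ctx, lis, diskUUID)
}
```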


@ -0,0 +1,194 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package recoveryserver
import (
"context"
"io"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestServe(t *testing.T) {
assert := assert.New(t)
log := logger.NewTest(t)
uuid := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
dialer := testdialer.NewBufconnDialer()
listener := dialer.GetListener("192.0.2.1:1234")
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
// Serve method returns when context is canceled
wg.Add(1)
go func() {
defer wg.Done()
_, _, err := server.Serve(ctx, listener, uuid)
assert.ErrorIs(err, context.Canceled)
}()
time.Sleep(100 * time.Millisecond)
cancel()
wg.Wait()
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
dialer = testdialer.NewBufconnDialer()
listener = dialer.GetListener("192.0.2.1:1234")
// Serve method returns without error when the server is shut down
wg.Add(1)
go func() {
defer wg.Done()
_, _, err := server.Serve(context.Background(), listener, uuid)
assert.NoError(err)
}()
time.Sleep(100 * time.Millisecond)
server.grpcServer.GracefulStop()
wg.Wait()
// Serve method returns an error when serving is unsuccessful
_, _, err := server.Serve(context.Background(), listener, uuid)
assert.Error(err)
}
func TestRecover(t *testing.T) {
testCases := map[string]struct {
initialMsg message
keyMsg message
wantErr bool
}{
"success": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
},
"first message is not a measurement secret": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
wantErr: true,
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
},
"second message is not a state disk key": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
wantErr: true,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
ctx := context.Background()
serverUUID := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
netDialer := testdialer.NewBufconnDialer()
listener := netDialer.GetListener("192.0.2.1:1234")
var diskKey, measurementSecret []byte
var serveErr error
var wg sync.WaitGroup
defer wg.Wait()
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
wg.Add(1)
go func() {
defer wg.Done()
diskKey, measurementSecret, serveErr = server.Serve(serveCtx, listener, serverUUID)
}()
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
require.NoError(err)
defer conn.Close()
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
require.NoError(err)
// Send initial message
err = client.Send(tc.initialMsg.recoverMsg)
require.NoError(err)
// Receive uuid
uuid, err := client.Recv()
if tc.initialMsg.wantErr {
assert.Error(err)
return
}
assert.Equal(serverUUID, uuid.DiskUuid)
// Send key message
err = client.Send(tc.keyMsg.recoverMsg)
require.NoError(err)
_, err = client.Recv()
if tc.keyMsg.wantErr {
assert.Error(err)
return
}
assert.ErrorIs(err, io.EOF)
wg.Wait()
assert.NoError(serveErr)
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
})
}
}
type message struct {
recoverMsg *recoverproto.RecoverMessage
wantErr bool
}


@ -0,0 +1,167 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package rejoinclient
import (
"context"
"errors"
"net"
"time"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"go.uber.org/zap"
"google.golang.org/grpc"
"k8s.io/utils/clock"
)
const (
interval = 30 * time.Second
timeout = 30 * time.Second
)
// RejoinClient is a client for requesting the needed information
// for rejoining a cluster as a restarting worker or control-plane node.
type RejoinClient struct {
diskUUID string
nodeInfo metadata.InstanceMetadata
timeout time.Duration
interval time.Duration
clock clock.WithTicker
dialer grpcDialer
metadataAPI metadataAPI
log *logger.Logger
}
// New returns a new RejoinClient.
func New(dial grpcDialer, nodeInfo metadata.InstanceMetadata,
meta metadataAPI, log *logger.Logger,
) *RejoinClient {
return &RejoinClient{
nodeInfo: nodeInfo,
timeout: timeout,
interval: interval,
clock: clock.RealClock{},
dialer: dial,
metadataAPI: meta,
log: log,
}
}
// Start starts the rejoin client.
// The client will continuously request available control-plane endpoints
// from the metadata API and send rejoin requests to them.
// The function returns after a successful rejoin request has been performed.
func (c *RejoinClient) Start(ctx context.Context, diskUUID string) (diskKey, measurementSecret []byte) {
c.log.Infof("Starting RejoinClient")
c.diskUUID = diskUUID
ticker := c.clock.NewTicker(c.interval)
defer ticker.Stop()
defer c.log.Infof("RejoinClient stopped")
for {
endpoints, err := c.getControlPlaneEndpoints()
if err != nil {
c.log.With(zap.Error(err)).Errorf("Failed to get control-plane endpoints")
} else {
c.log.With(zap.Strings("endpoints", endpoints)).Infof("Received list with JoinService endpoints")
diskKey, measurementSecret, err = c.tryRejoinWithAvailableServices(ctx, endpoints)
if err == nil {
c.log.Infof("Successfully retrieved rejoin ticket")
return diskKey, measurementSecret
}
}
select {
case <-ctx.Done():
return nil, nil
case <-ticker.C():
}
}
}
// tryRejoinWithAvailableServices tries sending rejoin requests to the available endpoints.
func (c *RejoinClient) tryRejoinWithAvailableServices(ctx context.Context, endpoints []string) (diskKey, measurementSecret []byte, err error) {
for _, endpoint := range endpoints {
c.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
rejoinTicket, err := c.requestRejoinTicket(endpoint)
if err == nil {
return rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, nil
}
c.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Warnf("Failed to rejoin on endpoint")
// stop requesting additional endpoints if the context is done
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
default:
}
}
c.log.Errorf("Failed to rejoin on all endpoints")
return nil, nil, errors.New("failed to join on all endpoints")
}
// requestRejoinTicket requests a rejoin ticket from the endpoint.
func (c *RejoinClient) requestRejoinTicket(endpoint string) (*joinproto.IssueRejoinTicketResponse, error) {
ctx, cancel := c.timeoutCtx()
defer cancel()
conn, err := c.dialer.Dial(ctx, endpoint)
if err != nil {
return nil, err
}
defer conn.Close()
return joinproto.NewAPIClient(conn).IssueRejoinTicket(ctx, &joinproto.IssueRejoinTicketRequest{DiskUuid: c.diskUUID})
}
// getControlPlaneEndpoints requests the available control-plane endpoints from the metadata API.
// The list is filtered to remove *this* node if it is a restarting control-plane node.
func (c *RejoinClient) getControlPlaneEndpoints() ([]string, error) {
ctx, cancel := c.timeoutCtx()
defer cancel()
endpoints, err := metadata.JoinServiceEndpoints(ctx, c.metadataAPI)
if err != nil {
return nil, err
}
if c.nodeInfo.Role == role.ControlPlane {
return removeSelfFromEndpoints(c.nodeInfo.VPCIP, endpoints), nil
}
return endpoints, nil
}
// removeSelfFromEndpoints removes *this* node from the list of endpoints.
// If an error occurs, the entry is removed from the list of endpoints.
func removeSelfFromEndpoints(self string, endpoints []string) []string {
var result []string
for _, endpoint := range endpoints {
host, _, err := net.SplitHostPort(endpoint)
if err == nil && host != self {
result = append(result, endpoint)
}
}
return result
}
func (c *RejoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), c.timeout)
}
type grpcDialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
}
type metadataAPI interface {
// List retrieves all instances belonging to the current constellation.
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
}


@ -0,0 +1,308 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package rejoinclient
import (
"context"
"errors"
"net"
"strconv"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
testclock "k8s.io/utils/clock/testing"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestStartCancel(t *testing.T) {
netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer)
clock := testclock.NewFakeClock(time.Time{})
metaAPI := &stubMetadataAPI{
instances: []metadata.InstanceMetadata{
{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
},
}
client := &RejoinClient{
dialer: dialer,
nodeInfo: metadata.InstanceMetadata{Role: role.Worker},
metadataAPI: metaAPI,
log: logger.NewTest(t),
timeout: time.Second * 30,
interval: time.Second,
clock: clock,
}
serverCreds := atlscredentials.New(nil, nil)
rejoinServer := grpc.NewServer(grpc.Creds(serverCreds))
rejoinServiceAPI := &stubRejoinServiceAPI{err: errors.New("error")}
joinproto.RegisterAPIServer(rejoinServer, rejoinServiceAPI)
port := strconv.Itoa(constants.JoinServiceNodePort)
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
go rejoinServer.Serve(listener)
defer rejoinServer.GracefulStop()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
client.Start(ctx, "uuid")
}()
clock.Step(time.Millisecond)
cancel()
wg.Wait()
assert.Equal(t, client.diskUUID, "uuid")
}
func TestRemoveSelfFromEndpoints(t *testing.T) {
testCases := map[string]struct {
self string
endpoints []string
}{
"self is not in endpoints": {
self: "192.0.2.1",
endpoints: []string{
"192.0.2.2:30090",
"192.0.2.3:30090",
"192.0.2.4:30090",
"192.0.2.5:30090",
"192.0.2.6:30090",
},
},
"self is in endpoints": {
self: "192.0.2.1",
endpoints: []string{
"192.0.2.2:30090",
"192.0.2.3:30090",
"192.0.2.4:30090",
"192.0.2.5:30090",
"192.0.2.6:30090",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
got := removeSelfFromEndpoints(tc.self, tc.endpoints)
assert.NotContains(got, tc.self)
})
}
}
func TestGetControlPlaneEndpoints(t *testing.T) {
testInstances := []metadata.InstanceMetadata{
{
Role: role.ControlPlane,
VPCIP: "192.0.2.2",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.3",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.4",
},
{
Role: role.Worker,
VPCIP: "192.0.2.12",
},
{
Role: role.Worker,
VPCIP: "192.0.2.13",
},
{
Role: role.Worker,
VPCIP: "192.0.2.14",
},
}
testCases := map[string]struct {
nodeInfo metadata.InstanceMetadata
meta stubMetadataAPI
wantInstances int
wantErr bool
}{
"worker node": {
nodeInfo: metadata.InstanceMetadata{
Role: role.Worker,
VPCIP: "192.0.2.1",
},
meta: stubMetadataAPI{
instances: testInstances,
},
wantInstances: 3,
},
"control-plane node not in list": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
meta: stubMetadataAPI{
instances: testInstances,
},
wantInstances: 3,
},
"control-plane node in list": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.2",
},
meta: stubMetadataAPI{
instances: testInstances,
},
wantInstances: 2,
},
"metadata error": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
meta: stubMetadataAPI{
err: errors.New("error"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := New(nil, tc.nodeInfo, tc.meta, logger.NewTest(t))
endpoints, err := client.getControlPlaneEndpoints()
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotContains(endpoints, tc.nodeInfo.VPCIP)
assert.Len(endpoints, tc.wantInstances)
}
})
}
}
func TestStart(t *testing.T) {
testCases := map[string]struct {
nodeInfo metadata.InstanceMetadata
}{
"worker node": {
nodeInfo: metadata.InstanceMetadata{
Role: role.Worker,
VPCIP: "192.0.2.99",
},
},
"control-plane node": {
nodeInfo: metadata.InstanceMetadata{
Role: role.ControlPlane,
VPCIP: "192.0.2.99",
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
diskKey := []byte("disk-key")
measurementSecret := []byte("measurement-secret")
netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer)
serverCreds := atlscredentials.New(nil, nil)
rejoinServer := grpc.NewServer(grpc.Creds(serverCreds))
rejoinServiceAPI := &stubRejoinServiceAPI{
rejoinTicketResponse: &joinproto.IssueRejoinTicketResponse{
StateDiskKey: diskKey,
MeasurementSecret: measurementSecret,
},
}
joinproto.RegisterAPIServer(rejoinServer, rejoinServiceAPI)
port := strconv.Itoa(constants.JoinServiceNodePort)
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
go rejoinServer.Serve(listener)
defer rejoinServer.GracefulStop()
meta := stubMetadataAPI{
instances: []metadata.InstanceMetadata{
{
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
},
{
Role: role.ControlPlane,
VPCIP: "192.0.2.2",
},
{
Role: role.Worker,
VPCIP: "192.0.2.13",
},
{
Role: role.Worker,
VPCIP: "192.0.2.14",
},
},
}
client := New(dialer, tc.nodeInfo, meta, logger.NewTest(t))
passphrase, secret := client.Start(context.Background(), "uuid")
assert.Equal(diskKey, passphrase)
assert.Equal(measurementSecret, secret)
})
}
}
type stubMetadataAPI struct {
instances []metadata.InstanceMetadata
err error
}
func (s stubMetadataAPI) List(context.Context) ([]metadata.InstanceMetadata, error) {
return s.instances, s.err
}
type stubRejoinServiceAPI struct {
rejoinTicketResponse *joinproto.IssueRejoinTicketResponse
err error
joinproto.UnimplementedAPIServer
}
func (s *stubRejoinServiceAPI) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest,
) (*joinproto.IssueRejoinTicketResponse, error) {
return s.rejoinTicketResponse, s.err
}


@ -10,6 +10,8 @@ import (
"io/fs"
"os"
"syscall"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
)
// Mounter is an interface for mount and unmount operations.
@ -27,17 +29,23 @@ type DeviceMapper interface {
UnmapDisk(target string) error
}
// KeyWaiter is an interface to request and wait for disk decryption keys.
type KeyWaiter interface {
WaitForDecryptionKey(uuid, addr string) (diskKey, measurementSecret []byte, err error)
ResetKey()
}
// ConfigurationGenerator is an interface for generating systemd-cryptsetup@.service unit files.
type ConfigurationGenerator interface {
Generate(volumeName, encryptedDevice, keyFile, options string) error
}
// MetadataAPI is an interface for accessing cloud metadata.
type MetadataAPI interface {
metadata.InstanceSelfer
metadata.InstanceLister
}
// RecoveryDoer is an interface to perform key recovery operations.
// Calls to Do may be blocking, and if successful return a passphrase and measurementSecret.
type RecoveryDoer interface {
Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error)
}
// DiskMounter uses the syscall package to mount disks.
type DiskMounter struct{}


@ -7,14 +7,18 @@ SPDX-License-Identifier: AGPL-3.0-only
package setup
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"sync"
"syscall"
"github.com/edgelesssys/constellation/disk-mapper/internal/systemd"
"github.com/edgelesssys/constellation/internal/attestation"
"github.com/edgelesssys/constellation/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/constants"
@ -22,9 +26,7 @@ import (
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/nodestate"
"github.com/edgelesssys/constellation/state/internal/systemd"
"github.com/spf13/afero"
"go.uber.org/zap"
)
const (
@ -38,56 +40,53 @@ const (
// SetupManager handles formatting, mapping, mounting and unmounting of state disks.
type SetupManager struct {
log *logger.Logger
csp string
diskPath string
fs afero.Afero
keyWaiter KeyWaiter
mapper DeviceMapper
mounter Mounter
config ConfigurationGenerator
openTPM vtpm.TPMOpenFunc
log *logger.Logger
csp string
diskPath string
fs afero.Afero
mapper DeviceMapper
mounter Mounter
config ConfigurationGenerator
openTPM vtpm.TPMOpenFunc
}
// New initializes a SetupManager with the given parameters.
func New(log *logger.Logger, csp string, diskPath string, fs afero.Afero, keyWaiter KeyWaiter, mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc) *SetupManager {
func New(log *logger.Logger, csp string, diskPath string, fs afero.Afero,
mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc,
) *SetupManager {
return &SetupManager{
log: log,
csp: csp,
diskPath: diskPath,
fs: fs,
keyWaiter: keyWaiter,
mapper: mapper,
mounter: mounter,
config: systemd.New(fs),
openTPM: openTPM,
log: log,
csp: csp,
diskPath: diskPath,
fs: fs,
mapper: mapper,
mounter: mounter,
config: systemd.New(fs),
openTPM: openTPM,
}
}
// PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk.
// Once the disk is mapped, the function taints the node as initialized by updating its PCRs.
func (s *SetupManager) PrepareExistingDisk() error {
func (s *SetupManager) PrepareExistingDisk(recover RecoveryDoer) error {
s.log.Infof("Preparing existing state disk")
uuid := s.mapper.DiskUUID()
endpoint := net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.RecoveryPort))
getKey:
passphrase, measurementSecret, err := s.keyWaiter.WaitForDecryptionKey(uuid, endpoint)
passphrase, measurementSecret, err := recover.Do(uuid, endpoint)
if err != nil {
return err
return fmt.Errorf("failed to perform recovery: %w", err)
}
if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil {
// retry key fetching if disk mapping fails
s.log.With(zap.Error(err)).Errorf("Failed to map state disk, retrying...")
s.keyWaiter.ResetKey()
goto getKey
return err
}
if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil {
return err
}
// we do not care about cleaning up the mount point on error, since any errors returned here should result in a kernel panic in the main function
// we do not care about cleaning up the mount point on error, since any errors returned here should cause a boot failure
if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil {
return err
}
@ -160,3 +159,68 @@ func (s *SetupManager) saveConfiguration(passphrase []byte) error {
// systemd cryptsetup unit
return s.config.Generate(stateDiskMappedName, s.diskPath, filepath.Join(keyPath, keyFile), cryptsetupOptions)
}
type recoveryServer interface {
Serve(context.Context, net.Listener, string) (key, secret []byte, err error)
}
type rejoinClient interface {
Start(context.Context, string) (key, secret []byte)
}
type nodeRecoverer struct {
recoveryServer recoveryServer
rejoinClient rejoinClient
}
// NewNodeRecoverer initializes a new nodeRecoverer.
func NewNodeRecoverer(recoveryServer recoveryServer, rejoinClient rejoinClient) *nodeRecoverer {
return &nodeRecoverer{
recoveryServer: recoveryServer,
rejoinClient: rejoinClient,
}
}
// Do performs a recovery procedure on the given state disk.
// The method starts a gRPC server to allow manual recovery by a user.
// At the same time it tries to request a decryption key from all available Constellation control-plane nodes.
func (r *nodeRecoverer) Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lis, err := net.Listen("tcp", endpoint)
if err != nil {
return nil, nil, err
}
defer lis.Close()
var once sync.Once
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
key, secret, serveErr := r.recoveryServer.Serve(ctx, lis, uuid)
once.Do(func() {
cancel()
passphrase = key
measurementSecret = secret
})
if serveErr != nil && !errors.Is(serveErr, context.Canceled) {
err = serveErr
}
}()
wg.Add(1)
go func() {
defer wg.Done()
key, secret := r.rejoinClient.Start(ctx, uuid)
once.Do(func() {
cancel()
passphrase = key
measurementSecret = secret
})
}()
wg.Wait()
return passphrase, measurementSecret, err
}


@ -7,10 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
package setup
import (
"context"
"errors"
"io"
"io/fs"
"net"
"path/filepath"
"sync"
"testing"
"github.com/edgelesssys/constellation/internal/attestation/vtpm"
@ -30,9 +33,13 @@ func TestMain(m *testing.M) {
func TestPrepareExistingDisk(t *testing.T) {
someErr := errors.New("error")
testRecoveryDoer := &stubRecoveryDoer{
passphrase: []byte("passphrase"),
secret: []byte("secret"),
}
testCases := map[string]struct {
keyWaiter *stubKeyWaiter
recoveryDoer *stubRecoveryDoer
mapper *stubMapper
mounter *stubMounter
configGenerator *stubConfigurationGenerator
@ -41,34 +48,33 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr bool
}{
"success": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{},
openTPM: vtpm.OpenNOPTPM,
},
"WaitForDecryptionKey fails": {
keyWaiter: &stubKeyWaiter{waitErr: someErr},
recoveryDoer: &stubRecoveryDoer{recoveryErr: someErr},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{},
openTPM: vtpm.OpenNOPTPM,
wantErr: true,
},
"MapDisk fails causes a repeat": {
keyWaiter: &stubKeyWaiter{},
"MapDisk fails": {
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{
uuid: "test",
mapDiskErr: someErr,
mapDiskRepeatedCalls: 2,
uuid: "test",
mapDiskErr: someErr,
},
mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{},
openTPM: vtpm.OpenNOPTPM,
wantErr: false,
wantErr: true,
},
"MkdirAll fails": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{mkdirAllErr: someErr},
configGenerator: &stubConfigurationGenerator{},
@ -76,7 +82,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true,
},
"Mount fails": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{mountErr: someErr},
configGenerator: &stubConfigurationGenerator{},
@ -84,7 +90,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true,
},
"Unmount fails": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{unmountErr: someErr},
configGenerator: &stubConfigurationGenerator{},
@ -92,7 +98,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true,
},
"MarkNodeAsBootstrapped fails": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{unmountErr: someErr},
configGenerator: &stubConfigurationGenerator{},
@ -100,7 +106,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true,
},
"Generating config fails": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{generateErr: someErr},
@ -108,7 +114,7 @@ func TestPrepareExistingDisk(t *testing.T) {
wantErr: true,
},
"no state file": {
keyWaiter: &stubKeyWaiter{},
recoveryDoer: testRecoveryDoer,
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
configGenerator: &stubConfigurationGenerator{},
@ -130,23 +136,21 @@ func TestPrepareExistingDisk(t *testing.T) {
}
setupManager := &SetupManager{
log: logger.NewTest(t),
csp: "test",
diskPath: "disk-path",
fs: fs,
keyWaiter: tc.keyWaiter,
mapper: tc.mapper,
mounter: tc.mounter,
config: tc.configGenerator,
openTPM: tc.openTPM,
log: logger.NewTest(t),
csp: "test",
diskPath: "disk-path",
fs: fs,
mapper: tc.mapper,
mounter: tc.mounter,
config: tc.configGenerator,
openTPM: tc.openTPM,
}
err := setupManager.PrepareExistingDisk()
err := setupManager.PrepareExistingDisk(tc.recoveryDoer)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
assert.True(tc.mapper.mapDiskCalled)
assert.True(tc.mounter.mountCalled)
assert.True(tc.mounter.unmountCalled)
@ -191,9 +195,8 @@ func TestPrepareNewDisk(t *testing.T) {
"MapDisk fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
mapper: &stubMapper{
uuid: "test",
mapDiskErr: someErr,
mapDiskRepeatedCalls: 1,
uuid: "test",
mapDiskErr: someErr,
},
configGenerator: &stubConfigurationGenerator{},
wantErr: true,
@ -267,7 +270,7 @@ func TestReadMeasurementSalt(t *testing.T) {
require.NoError(handler.WriteJSON("test-state.json", state, file.OptMkdirAll))
}
setupManager := New(logger.NewTest(t), "test", "disk-path", fs, nil, nil, nil, nil)
setupManager := New(logger.NewTest(t), "test", "disk-path", fs, nil, nil, nil)
measurementSalt, err := setupManager.readMeasurementSalt("test-state.json")
if tc.wantErr {
@ -280,15 +283,115 @@ func TestReadMeasurementSalt(t *testing.T) {
}
}
func TestRecoveryDoer(t *testing.T) {
assert := assert.New(t)
rejoinClientKey := []byte("rejoinClientKey")
rejoinClientSecret := []byte("rejoinClientSecret")
recoveryServerKey := []byte("recoveryServerKey")
recoveryServerSecret := []byte("recoveryServerSecret")
recoveryServerErr := errors.New("error")
recoveryServer := &stubRecoveryServer{
key: recoveryServerKey,
secret: recoveryServerSecret,
sendKeys: make(chan struct{}, 1),
err: recoveryServerErr,
}
rejoinClient := &stubRejoinClient{
key: rejoinClientKey,
secret: rejoinClientSecret,
sendKeys: make(chan struct{}, 1),
}
recoverer := NewNodeRecoverer(recoveryServer, rejoinClient)
var wg sync.WaitGroup
var key, secret []byte
var err error
// error from recovery server
wg.Add(1)
go func() {
defer wg.Done()
key, secret, err = recoverer.Do("", "")
}()
recoveryServer.sendKeys <- struct{}{}
wg.Wait()
assert.ErrorIs(err, recoveryServerErr)
recoveryServer.err = nil
recoveryServer.sendKeys = make(chan struct{}, 1)
// recovery server returns its key and secret
wg.Add(1)
go func() {
defer wg.Done()
key, secret, err = recoverer.Do("", "")
}()
recoveryServer.sendKeys <- struct{}{}
wg.Wait()
assert.NoError(err)
assert.Equal(recoveryServerKey, key)
assert.Equal(recoveryServerSecret, secret)
recoveryServer.sendKeys = make(chan struct{}, 1)
// rejoin client returns its key and secret
wg.Add(1)
go func() {
defer wg.Done()
key, secret, err = recoverer.Do("", "")
}()
rejoinClient.sendKeys <- struct{}{}
wg.Wait()
assert.NoError(err)
assert.Equal(rejoinClientKey, key)
assert.Equal(rejoinClientSecret, secret)
}
type stubRecoveryServer struct {
key []byte
secret []byte
sendKeys chan struct{}
err error
}
func (s *stubRecoveryServer) Serve(ctx context.Context, _ net.Listener, _ string) ([]byte, []byte, error) {
for {
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case <-s.sendKeys:
return s.key, s.secret, s.err
}
}
}
type stubRejoinClient struct {
key []byte
secret []byte
sendKeys chan struct{}
}
func (s *stubRejoinClient) Start(ctx context.Context, _ string) ([]byte, []byte) {
for {
select {
case <-ctx.Done():
return nil, nil
case <-s.sendKeys:
return s.key, s.secret
}
}
}
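The test and stubs above pin down the new recoverer's behavior: the recovery server and the rejoin client run concurrently, whichever delivers a key first wins, and an error from the recovery server is surfaced directly. As a rough illustration only, a recoverer with that behavior could be raced as in the sketch below; the type names, listener handling, and error policy are assumptions matched to the stub signatures above, not the disk-mapper's actual NodeRecoverer.

package recoverysketch

import (
	"context"
	"net"
)

// recoveryServer and rejoinClient mirror the stub signatures used in the test above.
type recoveryServer interface {
	Serve(ctx context.Context, listener net.Listener, diskUUID string) (key, secret []byte, err error)
}

type rejoinClient interface {
	Start(ctx context.Context, diskUUID string) (key, secret []byte)
}

// sketchRecoverer races both key sources and returns the first result.
type sketchRecoverer struct {
	server recoveryServer
	client rejoinClient
}

func (r *sketchRecoverer) Do(diskUUID, listenAddr string) (key, secret []byte, err error) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // stop the losing source once one result is in

	// Assumption: the recovery server listens on the given address.
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return nil, nil, err
	}
	defer listener.Close()

	type result struct {
		key, secret []byte
		err         error
	}
	results := make(chan result, 2)

	go func() {
		k, s, serveErr := r.server.Serve(ctx, listener, diskUUID)
		results <- result{k, s, serveErr}
	}()
	go func() {
		k, s := r.client.Start(ctx, diskUUID)
		results <- result{k, s, nil}
	}()

	// Return whichever source answers first.
	first := <-results
	return first.key, first.secret, first.err
}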
type stubMapper struct {
formatDiskCalled bool
formatDiskErr error
mapDiskRepeatedCalls int
mapDiskCalled bool
mapDiskErr error
unmapDiskCalled bool
unmapDiskErr error
uuid string
formatDiskCalled bool
formatDiskErr error
mapDiskCalled bool
mapDiskErr error
unmapDiskCalled bool
unmapDiskErr error
uuid string
}
func (s *stubMapper) DiskUUID() string {
@ -301,10 +404,6 @@ func (s *stubMapper) FormatDisk(string) error {
}
func (s *stubMapper) MapDisk(string, string) error {
if s.mapDiskRepeatedCalls == 0 {
s.mapDiskErr = nil
}
s.mapDiskRepeatedCalls--
s.mapDiskCalled = true
return s.mapDiskErr
}
@ -336,25 +435,14 @@ func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
return s.mkdirAllErr
}
type stubKeyWaiter struct {
receivedUUID string
decryptionKey []byte
measurementSecret []byte
waitErr error
waitCalled bool
type stubRecoveryDoer struct {
passphrase []byte
secret []byte
recoveryErr error
}
func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, []byte, error) {
if s.waitCalled {
return nil, nil, errors.New("wait called before key was reset")
}
s.waitCalled = true
s.receivedUUID = uuid
return s.decryptionKey, s.measurementSecret, s.waitErr
}
func (s *stubKeyWaiter) ResetKey() {
s.waitCalled = false
func (s *stubRecoveryDoer) Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error) {
return s.passphrase, s.secret, s.recoveryErr
}
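PrepareExistingDisk now takes the recovery doer as an argument, and the stub above implies an interface along these lines (inferred from the stub's Do signature; the real declaration in the disk-mapper package may use a different name):

// RecoveryDoer is the shape implied by stubRecoveryDoer; the name is assumed.
type RecoveryDoer interface {
	Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error)
}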
type stubConfigurationGenerator struct {

View File

"math"
"testing"
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/state/internal/mapper"
"github.com/martinjungblut/go-cryptsetup"
"go.uber.org/zap/zapcore"
)
View File
@ -9,29 +9,18 @@ SPDX-License-Identifier: AGPL-3.0-only
package integration
import (
"context"
"flag"
"fmt"
"net"
"os"
"os/exec"
"testing"
"time"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/state/internal/keyservice"
"github.com/edgelesssys/constellation/state/internal/mapper"
"github.com/edgelesssys/constellation/state/keyproto"
"github.com/martinjungblut/go-cryptsetup"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc"
)
const (
@ -93,53 +82,7 @@ func TestMapper(t *testing.T) {
assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase")
}
func TestKeyAPI(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
testSecret := []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")
// get a free port on localhost to run the test on
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
apiAddr := listener.Addr().String()
listener.Close()
api := keyservice.New(
logger.NewTest(t),
atls.NewFakeIssuer(oid.Dummy{}),
&fakeMetadataAPI{},
20*time.Second,
time.Second,
)
// send a key to the server
go func() {
// wait 2 seconds before sending the key
time.Sleep(2 * time.Second)
creds := atlscredentials.New(nil, nil)
conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(creds))
require.NoError(err)
defer conn.Close()
client := keyproto.NewAPIClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
_, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{
StateDiskKey: testKey,
MeasurementSecret: testSecret,
})
require.NoError(err)
}()
key, measurementSecret, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr)
assert.NoError(err)
assert.Equal(testKey, key)
assert.Equal(testSecret, measurementSecret)
}
/*
type fakeMetadataAPI struct{}
func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
@ -152,3 +95,4 @@ func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata
},
}, nil
}
*/
View File
@ -0,0 +1,258 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc v3.20.1
// source: recover.proto
package recoverproto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type RecoverMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Request:
// *RecoverMessage_StateDiskKey
// *RecoverMessage_MeasurementSecret
Request isRecoverMessage_Request `protobuf_oneof:"request"`
}
func (x *RecoverMessage) Reset() {
*x = RecoverMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_recover_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RecoverMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RecoverMessage) ProtoMessage() {}
func (x *RecoverMessage) ProtoReflect() protoreflect.Message {
mi := &file_recover_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RecoverMessage.ProtoReflect.Descriptor instead.
func (*RecoverMessage) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{0}
}
func (m *RecoverMessage) GetRequest() isRecoverMessage_Request {
if m != nil {
return m.Request
}
return nil
}
func (x *RecoverMessage) GetStateDiskKey() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_StateDiskKey); ok {
return x.StateDiskKey
}
return nil
}
func (x *RecoverMessage) GetMeasurementSecret() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_MeasurementSecret); ok {
return x.MeasurementSecret
}
return nil
}
type isRecoverMessage_Request interface {
isRecoverMessage_Request()
}
type RecoverMessage_StateDiskKey struct {
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3,oneof"`
}
type RecoverMessage_MeasurementSecret struct {
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3,oneof"`
}
func (*RecoverMessage_StateDiskKey) isRecoverMessage_Request() {}
func (*RecoverMessage_MeasurementSecret) isRecoverMessage_Request() {}
type RecoverResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
}
func (x *RecoverResponse) Reset() {
*x = RecoverResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_recover_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RecoverResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RecoverResponse) ProtoMessage() {}
func (x *RecoverResponse) ProtoReflect() protoreflect.Message {
mi := &file_recover_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RecoverResponse.ProtoReflect.Descriptor instead.
func (*RecoverResponse) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{1}
}
func (x *RecoverResponse) GetDiskUuid() string {
if x != nil {
return x.DiskUuid
}
return ""
}
var File_recover_proto protoreflect.FileDescriptor
var file_recover_proto_rawDesc = []byte{
0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a,
0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
0x26, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65,
0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75,
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20,
0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x22, 0x2e, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x75,
0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x6b, 0x55,
0x75, 0x69, 0x64, 0x32, 0x53, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4c, 0x0a, 0x07, 0x52, 0x65,
0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3f, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73,
0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x72, 0x2f, 0x72, 0x65, 0x63,
0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
file_recover_proto_rawDescOnce sync.Once
file_recover_proto_rawDescData = file_recover_proto_rawDesc
)
func file_recover_proto_rawDescGZIP() []byte {
file_recover_proto_rawDescOnce.Do(func() {
file_recover_proto_rawDescData = protoimpl.X.CompressGZIP(file_recover_proto_rawDescData)
})
return file_recover_proto_rawDescData
}
var file_recover_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_recover_proto_goTypes = []interface{}{
(*RecoverMessage)(nil), // 0: recoverproto.RecoverMessage
(*RecoverResponse)(nil), // 1: recoverproto.RecoverResponse
}
var file_recover_proto_depIdxs = []int32{
0, // 0: recoverproto.API.Recover:input_type -> recoverproto.RecoverMessage
1, // 1: recoverproto.API.Recover:output_type -> recoverproto.RecoverResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_recover_proto_init() }
func file_recover_proto_init() {
if File_recover_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_recover_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RecoverMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_recover_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RecoverResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_recover_proto_msgTypes[0].OneofWrappers = []interface{}{
(*RecoverMessage_StateDiskKey)(nil),
(*RecoverMessage_MeasurementSecret)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_recover_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_recover_proto_goTypes,
DependencyIndexes: file_recover_proto_depIdxs,
MessageInfos: file_recover_proto_msgTypes,
}.Build()
File_recover_proto = out.File
file_recover_proto_rawDesc = nil
file_recover_proto_goTypes = nil
file_recover_proto_depIdxs = nil
}
View File
@ -0,0 +1,20 @@
syntax = "proto3";
package recoverproto;
option go_package = "github.com/edgelesssys/constellation/disk-mapper/recoverproto";
service API {
rpc Recover(stream RecoverMessage) returns (stream RecoverResponse) {}
}
message RecoverMessage {
oneof request {
bytes state_disk_key = 1;
bytes measurement_secret = 2;
}
}
message RecoverResponse {
string disk_uuid = 1;
}
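The generated Go bindings added in this commit expose Recover as a bidirectional stream (Send on RecoverMessage, Recv on RecoverResponse). A rough sketch of how the admin side might drive that stream follows; the message order (measurement secret, then disk UUID, then state disk key) and the key-derivation callback are assumptions for illustration, not necessarily what the constellation recover command does.

package recoverclient

import (
	"context"

	"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
	"google.golang.org/grpc"
)

// recoverNode pushes the measurement secret, learns the node's disk UUID,
// and then pushes the matching state disk key. getDiskKey is a hypothetical
// callback supplied by the caller; it is not an API from this commit.
func recoverNode(ctx context.Context, conn *grpc.ClientConn, measurementSecret []byte,
	getDiskKey func(diskUUID string) ([]byte, error),
) error {
	client := recoverproto.NewAPIClient(conn)
	stream, err := client.Recover(ctx)
	if err != nil {
		return err
	}

	// Send the measurement secret first ...
	if err := stream.Send(&recoverproto.RecoverMessage{
		Request: &recoverproto.RecoverMessage_MeasurementSecret{MeasurementSecret: measurementSecret},
	}); err != nil {
		return err
	}

	// ... receive the node's disk UUID ...
	resp, err := stream.Recv()
	if err != nil {
		return err
	}

	// ... and answer with the state disk key derived for that UUID.
	diskKey, err := getDiskKey(resp.GetDiskUuid())
	if err != nil {
		return err
	}
	return stream.Send(&recoverproto.RecoverMessage{
		Request: &recoverproto.RecoverMessage_StateDiskKey{StateDiskKey: diskKey},
	})
}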
View File
@ -2,9 +2,9 @@
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.20.1
// source: keyservice.proto
// source: recover.proto
package keyproto
package recoverproto
import (
context "context"
@ -22,7 +22,7 @@ const _ = grpc.SupportPackageIsVersion7
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type APIClient interface {
PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error)
Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error)
}
type aPIClient struct {
@ -33,20 +33,42 @@ func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
return &aPIClient{cc}
}
func (c *aPIClient) PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) {
out := new(PushStateDiskKeyResponse)
err := c.cc.Invoke(ctx, "/keyproto.API/PushStateDiskKey", in, out, opts...)
func (c *aPIClient) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) {
stream, err := c.cc.NewStream(ctx, &API_ServiceDesc.Streams[0], "/recoverproto.API/Recover", opts...)
if err != nil {
return nil, err
}
return out, nil
x := &aPIRecoverClient{stream}
return x, nil
}
type API_RecoverClient interface {
Send(*RecoverMessage) error
Recv() (*RecoverResponse, error)
grpc.ClientStream
}
type aPIRecoverClient struct {
grpc.ClientStream
}
func (x *aPIRecoverClient) Send(m *RecoverMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *aPIRecoverClient) Recv() (*RecoverResponse, error) {
m := new(RecoverResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// APIServer is the server API for API service.
// All implementations must embed UnimplementedAPIServer
// for forward compatibility
type APIServer interface {
PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error)
Recover(API_RecoverServer) error
mustEmbedUnimplementedAPIServer()
}
@ -54,8 +76,8 @@ type APIServer interface {
type UnimplementedAPIServer struct {
}
func (UnimplementedAPIServer) PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method PushStateDiskKey not implemented")
func (UnimplementedAPIServer) Recover(API_RecoverServer) error {
return status.Errorf(codes.Unimplemented, "method Recover not implemented")
}
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
@ -70,36 +92,46 @@ func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
s.RegisterService(&API_ServiceDesc, srv)
}
func _API_PushStateDiskKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PushStateDiskKeyRequest)
if err := dec(in); err != nil {
func _API_Recover_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(APIServer).Recover(&aPIRecoverServer{stream})
}
type API_RecoverServer interface {
Send(*RecoverResponse) error
Recv() (*RecoverMessage, error)
grpc.ServerStream
}
type aPIRecoverServer struct {
grpc.ServerStream
}
func (x *aPIRecoverServer) Send(m *RecoverResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
m := new(RecoverMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(APIServer).PushStateDiskKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/keyproto.API/PushStateDiskKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(APIServer).PushStateDiskKey(ctx, req.(*PushStateDiskKeyRequest))
}
return interceptor(ctx, in, info, handler)
return m, nil
}
// API_ServiceDesc is the grpc.ServiceDesc for API service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var API_ServiceDesc = grpc.ServiceDesc{
ServiceName: "keyproto.API",
ServiceName: "recoverproto.API",
HandlerType: (*APIServer)(nil),
Methods: []grpc.MethodDesc{
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
MethodName: "PushStateDiskKey",
Handler: _API_PushStateDiskKey_Handler,
StreamName: "Recover",
Handler: _API_Recover_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "keyservice.proto",
Metadata: "recover.proto",
}
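On the node side, an implementation of the generated APIServer could look roughly like the following. The message order mirrors the client sketch earlier and the hand-off channel is an assumption, so treat this as a shape for the handler, not the disk-mapper's actual RecoveryServer.

package recoverserversketch

import (
	"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
)

// recoveredKeys is a hypothetical hand-off type for the received secrets.
type recoveredKeys struct {
	stateDiskKey, measurementSecret []byte
}

type sketchServer struct {
	recoverproto.UnimplementedAPIServer
	diskUUID string
	keys     chan<- recoveredKeys
}

// Recover receives the measurement secret, announces the disk UUID so the
// caller can derive the matching passphrase, and then receives the disk key.
func (s *sketchServer) Recover(stream recoverproto.API_RecoverServer) error {
	msg, err := stream.Recv()
	if err != nil {
		return err
	}
	measurementSecret := msg.GetMeasurementSecret()

	if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
		return err
	}

	msg, err = stream.Recv()
	if err != nil {
		return err
	}

	s.keys <- recoveredKeys{
		stateDiskKey:      msg.GetStateDiskKey(),
		measurementSecret: measurementSecret,
	}
	return nil
}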
View File
@ -24,9 +24,9 @@ RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v${GEN_GO_VER} && \
# Generate code for every existing proto file
## disk-mapper keyservice api
## disk-mapper recover api
WORKDIR /disk-mapper
COPY state/keyproto/*.proto /disk-mapper
COPY disk-mapper/recoverproto/*.proto /disk-mapper
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
## debugd service
@ -54,7 +54,7 @@ COPY bootstrapper/initproto/*.proto /init
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
FROM scratch as export
COPY --from=build /disk-mapper/*.go state/keyproto/
COPY --from=build /disk-mapper/*.go disk-mapper/recoverproto/
COPY --from=build /service/*.go debugd/service/
COPY --from=build /kms/*.go kms/kmsproto/
COPY --from=build /joinservice/*.go joinservice/joinproto/
View File
@ -1,204 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package keyservice
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/edgelesssys/constellation/state/keyproto"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
"k8s.io/utils/clock"
)
// KeyAPI is the interface called by control-plane or an admin during restart of a node.
type KeyAPI struct {
listenAddr string
log *logger.Logger
mux sync.Mutex
metadata metadata.InstanceLister
issuer QuoteIssuer
key []byte
measurementSecret []byte
keyReceived chan struct{}
clock clock.WithTicker
timeout time.Duration
interval time.Duration
keyproto.UnimplementedAPIServer
}
// New initializes a KeyAPI with the given parameters.
func New(log *logger.Logger, issuer QuoteIssuer, metadata metadata.InstanceLister, timeout time.Duration, interval time.Duration) *KeyAPI {
return &KeyAPI{
log: log,
metadata: metadata,
issuer: issuer,
keyReceived: make(chan struct{}, 1),
clock: clock.RealClock{},
timeout: timeout,
interval: interval,
}
}
// PushStateDiskKey is the rpc to push state disk decryption keys to a restarting node.
func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
a.mux.Lock()
defer a.mux.Unlock()
if len(a.key) != 0 {
return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase")
}
if len(in.StateDiskKey) != crypto.StateDiskKeyLength {
return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", crypto.StateDiskKeyLength, len(in.StateDiskKey))
}
if len(in.MeasurementSecret) != crypto.RNGLengthDefault {
return nil, status.Errorf(codes.InvalidArgument, "received invalid measurement secret: expected length: %d, but got: %d", crypto.RNGLengthDefault, len(in.MeasurementSecret))
}
a.key = in.StateDiskKey
a.measurementSecret = in.MeasurementSecret
a.keyReceived <- struct{}{}
return &keyproto.PushStateDiskKeyResponse{}, nil
}
// WaitForDecryptionKey notifies control-plane nodes to send a decryption key and waits until a key is received.
func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) (diskKey, measurementSecret []byte, err error) {
if uuid == "" {
return nil, nil, errors.New("received no disk UUID")
}
a.listenAddr = listenAddr
creds := atlscredentials.New(a.issuer, nil)
server := grpc.NewServer(grpc.Creds(creds))
keyproto.RegisterAPIServer(server, a)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return nil, nil, err
}
defer listener.Close()
a.log.Infof("Waiting for decryption key. Listening on: %s", listener.Addr().String())
go server.Serve(listener)
defer server.GracefulStop()
a.requestKeyLoop(uuid)
return a.key, a.measurementSecret, nil
}
// ResetKey resets a previously set key.
func (a *KeyAPI) ResetKey() {
a.key = nil
}
// requestKeyLoop continuously requests decryption keys from all available control-plane nodes, until the KeyAPI receives a key.
func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) {
// we do not perform attestation, since the restarting node does not need to care about notifying the correct node
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
creds := atlscredentials.New(a.issuer, nil)
ticker := a.clock.NewTicker(a.interval)
defer ticker.Stop()
for {
endpoints, err := a.getJoinServiceEndpoints()
if err != nil {
a.log.With(zap.Error(err)).Errorf("Failed to get JoinService endpoints")
} else {
a.log.Infof("Received list with JoinService endpoints: %v", endpoints)
for _, endpoint := range endpoints {
a.requestKey(endpoint, uuid, creds, opts...)
}
}
select {
case <-a.keyReceived:
// return if a key was received
// a key can be send by
// - a control-plane node, after the request rpc was received
// - by a Constellation admin, at any time this loop is running on a node during boot
return
case <-ticker.C():
}
}
}
func (a *KeyAPI) getJoinServiceEndpoints() ([]string, error) {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
return metadata.JoinServiceEndpoints(ctx, a.metadata)
}
func (a *KeyAPI) requestKey(endpoint, uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) {
opts = append(opts, grpc.WithTransportCredentials(credentials))
a.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
rejoinTicket, err := a.requestRejoinTicket(endpoint, uuid, opts...)
if err != nil {
a.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Errorf("Failed to request rejoin ticket")
return
}
a.log.With(zap.String("endpoint", endpoint)).Infof("Pushing key to own server")
if err := a.pushKeyToOwnServer(rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, opts...); err != nil {
a.log.With(zap.Error(err), zap.String("endpoint", a.listenAddr)).Errorf("Failed to push key to own server")
return
}
}
func (a *KeyAPI) requestRejoinTicket(endpoint, uuid string, opts ...grpc.DialOption) (*joinproto.IssueRejoinTicketResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
conn, err := grpc.DialContext(ctx, endpoint, opts...)
if err != nil {
return nil, fmt.Errorf("dialing gRPC: %w", err)
}
defer conn.Close()
client := joinproto.NewAPIClient(conn)
req := &joinproto.IssueRejoinTicketRequest{DiskUuid: uuid}
return client.IssueRejoinTicket(ctx, req)
}
func (a *KeyAPI) pushKeyToOwnServer(stateDiskKey, measurementSecret []byte, opts ...grpc.DialOption) error {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
conn, err := grpc.DialContext(ctx, a.listenAddr, opts...)
if err != nil {
return fmt.Errorf("dialing gRPC: %w", err)
}
defer conn.Close()
client := keyproto.NewAPIClient(conn)
req := &keyproto.PushStateDiskKeyRequest{StateDiskKey: stateDiskKey, MeasurementSecret: measurementSecret}
_, err = client.PushStateDiskKey(ctx, req)
return err
}
// QuoteValidator validates quotes.
type QuoteValidator interface {
oid.Getter
// Validate validates a quote and returns the user data on success.
Validate(attDoc []byte, nonce []byte) ([]byte, error)
}
// QuoteIssuer issues quotes.
type QuoteIssuer interface {
oid.Getter
// Issue issues a quote for remote attestation for a given message
Issue(userData []byte, nonce []byte) (quote []byte, err error)
}
View File
@ -1,264 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package keyservice
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/internal/role"
"github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/edgelesssys/constellation/state/keyproto"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestRequestKeyLoop(t *testing.T) {
clockstep := struct{}{}
someErr := errors.New("failed")
defaultInstance := metadata.InstanceMetadata{
Name: "test-instance",
ProviderID: "/test/provider",
Role: role.ControlPlane,
VPCIP: "192.0.2.1",
}
testCases := map[string]struct {
answers []any
}{
"success": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
"recover metadata list error": {
answers: []any{
listAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
"recover issue rejoin ticket error": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
"recover push key error": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
metadataServer := newStubMetadataServer()
joinServer := newStubJoinAPIServer()
keyServer := newStubKeyAPIServer()
listener := bufconn.Listen(1024)
defer listener.Close()
creds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil)
grpcServer := grpc.NewServer(grpc.Creds(creds))
joinproto.RegisterAPIServer(grpcServer, joinServer)
keyproto.RegisterAPIServer(grpcServer, keyServer)
go grpcServer.Serve(listener)
defer grpcServer.GracefulStop()
clock := testclock.NewFakeClock(time.Now())
keyReceived := make(chan struct{}, 1)
keyWaiter := &KeyAPI{
listenAddr: "192.0.2.1:30090",
log: logger.NewTest(t),
metadata: metadataServer,
keyReceived: keyReceived,
clock: clock,
timeout: 1 * time.Second,
interval: 1 * time.Second,
}
grpcOpts := []grpc.DialOption{
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.DialContext(ctx)
}),
}
// Start the request loop under tests.
done := make(chan struct{})
go func() {
defer close(done)
keyWaiter.requestKeyLoop("1234", grpcOpts...)
}()
// Play test case answers.
for _, answ := range tc.answers {
switch answ := answ.(type) {
case listAnswer:
metadataServer.listAnswerC <- answ
case issueRejoinTicketAnswer:
joinServer.issueRejoinTicketAnswerC <- answ
case pushStateDiskKeyAnswer:
keyServer.pushStateDiskKeyAnswerC <- answ
default:
clock.Step(time.Second)
}
}
// Stop the request loop.
keyReceived <- struct{}{}
})
}
}
func TestPushStateDiskKey(t *testing.T) {
testCases := map[string]struct {
testAPI *KeyAPI
request *keyproto.PushStateDiskKeyRequest
wantErr bool
}{
"success": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
},
"key already set": {
testAPI: &KeyAPI{
keyReceived: make(chan struct{}, 1),
key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"), MeasurementSecret: []byte("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")},
wantErr: true,
},
"incorrect size of pushed key": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
wantErr: true,
},
"incorrect size of measurement secret": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBB")},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.testAPI.log = logger.NewTest(t)
_, err := tc.testAPI.PushStateDiskKey(context.Background(), tc.request)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.request.StateDiskKey, tc.testAPI.key)
}
})
}
}
func TestResetKey(t *testing.T) {
api := New(logger.NewTest(t), nil, nil, time.Second, time.Millisecond)
api.key = []byte{0x1, 0x2, 0x3}
api.ResetKey()
assert.Nil(t, api.key)
}
type stubMetadataServer struct {
listAnswerC chan listAnswer
}
func newStubMetadataServer() *stubMetadataServer {
return &stubMetadataServer{
listAnswerC: make(chan listAnswer),
}
}
func (s *stubMetadataServer) List(context.Context) ([]metadata.InstanceMetadata, error) {
answer := <-s.listAnswerC
return answer.listResponse, answer.err
}
type listAnswer struct {
listResponse []metadata.InstanceMetadata
err error
}
type stubJoinAPIServer struct {
issueRejoinTicketAnswerC chan issueRejoinTicketAnswer
joinproto.UnimplementedAPIServer
}
func newStubJoinAPIServer() *stubJoinAPIServer {
return &stubJoinAPIServer{
issueRejoinTicketAnswerC: make(chan issueRejoinTicketAnswer),
}
}
func (s *stubJoinAPIServer) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest) (*joinproto.IssueRejoinTicketResponse, error) {
answer := <-s.issueRejoinTicketAnswerC
resp := &joinproto.IssueRejoinTicketResponse{
StateDiskKey: answer.stateDiskKey,
MeasurementSecret: answer.measurementSecret,
}
return resp, answer.err
}
type issueRejoinTicketAnswer struct {
stateDiskKey []byte
measurementSecret []byte
err error
}
type stubKeyAPIServer struct {
pushStateDiskKeyAnswerC chan pushStateDiskKeyAnswer
keyproto.UnimplementedAPIServer
}
func newStubKeyAPIServer() *stubKeyAPIServer {
return &stubKeyAPIServer{
pushStateDiskKeyAnswerC: make(chan pushStateDiskKeyAnswer),
}
}
func (s *stubKeyAPIServer) PushStateDiskKey(context.Context, *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
answer := <-s.pushStateDiskKeyAnswerC
return &keyproto.PushStateDiskKeyResponse{}, answer.err
}
type pushStateDiskKeyAnswer struct {
err error
}
View File
@ -1,219 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc v3.20.1
// source: keyservice.proto
package keyproto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type PushStateDiskKeyRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"`
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3" json:"measurement_secret,omitempty"`
}
func (x *PushStateDiskKeyRequest) Reset() {
*x = PushStateDiskKeyRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_keyservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PushStateDiskKeyRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PushStateDiskKeyRequest) ProtoMessage() {}
func (x *PushStateDiskKeyRequest) ProtoReflect() protoreflect.Message {
mi := &file_keyservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PushStateDiskKeyRequest.ProtoReflect.Descriptor instead.
func (*PushStateDiskKeyRequest) Descriptor() ([]byte, []int) {
return file_keyservice_proto_rawDescGZIP(), []int{0}
}
func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte {
if x != nil {
return x.StateDiskKey
}
return nil
}
func (x *PushStateDiskKeyRequest) GetMeasurementSecret() []byte {
if x != nil {
return x.MeasurementSecret
}
return nil
}
type PushStateDiskKeyResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *PushStateDiskKeyResponse) Reset() {
*x = PushStateDiskKeyResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_keyservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PushStateDiskKeyResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PushStateDiskKeyResponse) ProtoMessage() {}
func (x *PushStateDiskKeyResponse) ProtoReflect() protoreflect.Message {
mi := &file_keyservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PushStateDiskKeyResponse.ProtoReflect.Descriptor instead.
func (*PushStateDiskKeyResponse) Descriptor() ([]byte, []int) {
return file_keyservice_proto_rawDescGZIP(), []int{1}
}
var File_keyservice_proto protoreflect.FileDescriptor
var file_keyservice_proto_rawDesc = []byte{
0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x6e, 0x0a, 0x17,
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65,
0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2d, 0x0a,
0x12, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63,
0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75,
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x18,
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12,
0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b,
0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x50,
0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b,
0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69,
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73,
0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
}
var (
file_keyservice_proto_rawDescOnce sync.Once
file_keyservice_proto_rawDescData = file_keyservice_proto_rawDesc
)
func file_keyservice_proto_rawDescGZIP() []byte {
file_keyservice_proto_rawDescOnce.Do(func() {
file_keyservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_keyservice_proto_rawDescData)
})
return file_keyservice_proto_rawDescData
}
var file_keyservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_keyservice_proto_goTypes = []interface{}{
(*PushStateDiskKeyRequest)(nil), // 0: keyproto.PushStateDiskKeyRequest
(*PushStateDiskKeyResponse)(nil), // 1: keyproto.PushStateDiskKeyResponse
}
var file_keyservice_proto_depIdxs = []int32{
0, // 0: keyproto.API.PushStateDiskKey:input_type -> keyproto.PushStateDiskKeyRequest
1, // 1: keyproto.API.PushStateDiskKey:output_type -> keyproto.PushStateDiskKeyResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_keyservice_proto_init() }
func file_keyservice_proto_init() {
if File_keyservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_keyservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PushStateDiskKeyRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_keyservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PushStateDiskKeyResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_keyservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_keyservice_proto_goTypes,
DependencyIndexes: file_keyservice_proto_depIdxs,
MessageInfos: file_keyservice_proto_msgTypes,
}.Build()
File_keyservice_proto = out.File
file_keyservice_proto_rawDesc = nil
file_keyservice_proto_goTypes = nil
file_keyservice_proto_depIdxs = nil
}
View File
@ -1,17 +0,0 @@
syntax = "proto3";
package keyproto;
option go_package = "github.com/edgelesssys/constellation/state/keyservice/keyproto";
service API {
rpc PushStateDiskKey(PushStateDiskKeyRequest) returns (PushStateDiskKeyResponse);
}
message PushStateDiskKeyRequest {
bytes state_disk_key = 1;
bytes measurement_secret = 2;
}
message PushStateDiskKeyResponse {
}