mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
AB#2260 Refactor disk-mapper recovery (#82)
* Refactor disk-mapper recovery * Adapt constellation recover command to use new disk-mapper recovery API * Fix Cilium connectivity on rebooting nodes (#89) * Lower CoreDNS reschedule timeout to 10 seconds (#93) Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
a7b20b2a11
commit
8cb155d5c5
@ -106,6 +106,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Owner ID and Unique ID are merged into a single value: Cluster ID.
|
||||
- Streamline logging to only use one logging library, instead of multiple.
|
||||
- Replace dependency on github.com/willdonnelly/passwd with own implementation.
|
||||
- Refactor disk-mapper to allow a more streamlined node recovery
|
||||
|
||||
### Removed
|
||||
|
||||
|
@ -71,5 +71,5 @@ add_test(NAME unit-hack COMMAND go test -race -count=3 ./... WORKING_DIRECTORY $
|
||||
add_test(NAME unit-node-operator COMMAND go test -race -count=3 ./... WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator)
|
||||
add_test(NAME integration-node-operator COMMAND make test WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/operators/constellation-node-operator)
|
||||
add_test(NAME integration-csi COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csi)
|
||||
add_test(NAME integration-dm COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/state/internal)
|
||||
add_test(NAME integration-dm COMMAND bash -c "go test -tags integration -c ./test/ && sudo ./test.test -test.v" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/disk-mapper/internal)
|
||||
add_test(NAME integration-license COMMAND bash -c "go test -tags integration" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/internal/license)
|
||||
|
@ -61,7 +61,7 @@ Core components:
|
||||
* [image](image): Build files for the Constellation disk image
|
||||
* [kms](kms): Constellation's key management client and server
|
||||
* [csi](csi): Package used by CSI plugins to create and mount encrypted block devices
|
||||
* [state](state): Contains the disk-mapper that maps the encrypted node data disk during boot
|
||||
* [disk-mapper](disk-mapper): Contains the disk-mapper that maps the encrypted node data disk during boot
|
||||
|
||||
Development components:
|
||||
|
||||
|
@ -36,7 +36,7 @@ ARG PROJECT_VERSION
|
||||
RUN go build -o bootstrapper -tags=gcp,disable_tpm_simulator -buildvcs=false -ldflags "-s -w -buildid='' -X main.version=${PROJECT_VERSION}" ./cmd/bootstrapper/
|
||||
|
||||
FROM build AS build-disk-mapper
|
||||
WORKDIR /constellation/state/
|
||||
WORKDIR /constellation/disk-mapper/
|
||||
|
||||
RUN go build -o disk-mapper -ldflags "-s -w" ./cmd/
|
||||
|
||||
@ -44,4 +44,4 @@ FROM scratch AS bootstrapper
|
||||
COPY --from=build-bootstrapper /constellation/bootstrapper/bootstrapper /
|
||||
|
||||
FROM scratch AS disk-mapper
|
||||
COPY --from=build-disk-mapper /constellation/state/disk-mapper /
|
||||
COPY --from=build-disk-mapper /constellation/disk-mapper/disk-mapper /
|
||||
|
@ -50,7 +50,7 @@ func run(issuerWrapper initserver.IssuerWrapper, tpm vtpm.TPMOpenFunc, fileHandl
|
||||
}
|
||||
|
||||
if nodeBootstrapped {
|
||||
if err := kube.StartKubelet(); err != nil {
|
||||
if err := kube.StartKubelet(log); err != nil {
|
||||
log.With(zap.Error(err)).Fatalf("Failed to restart kubelet")
|
||||
}
|
||||
return
|
||||
@ -88,7 +88,7 @@ func getDiskUUID() (string, error) {
|
||||
type clusterInitJoiner interface {
|
||||
joinclient.ClusterJoiner
|
||||
initserver.ClusterInitializer
|
||||
StartKubelet() error
|
||||
StartKubelet(*logger.Logger) error
|
||||
}
|
||||
|
||||
type metadataAPI interface {
|
||||
|
@ -33,7 +33,7 @@ func (c *clusterFake) JoinCluster(context.Context, *kubeadm.BootstrapTokenDiscov
|
||||
}
|
||||
|
||||
// StartKubelet starts the kubelet service.
|
||||
func (c *clusterFake) StartKubelet() error {
|
||||
func (c *clusterFake) StartKubelet(*logger.Logger) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -13,30 +13,6 @@ import (
|
||||
"github.com/coreos/go-systemd/v22/dbus"
|
||||
)
|
||||
|
||||
func restartSystemdUnit(ctx context.Context, unit string) error {
|
||||
conn, err := dbus.NewSystemdConnectionContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("establishing systemd connection: %w", err)
|
||||
}
|
||||
|
||||
restartChan := make(chan string)
|
||||
if _, err := conn.RestartUnitContext(ctx, unit, "replace", restartChan); err != nil {
|
||||
return fmt.Errorf("restarting systemd unit %q: %w", unit, err)
|
||||
}
|
||||
|
||||
// Wait for the restart to finish and actually check if it was
|
||||
// successful or not.
|
||||
result := <-restartChan
|
||||
|
||||
switch result {
|
||||
case "done":
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("restarting systemd unit %q failed: expected %v but received %v", unit, "done", result)
|
||||
}
|
||||
}
|
||||
|
||||
func startSystemdUnit(ctx context.Context, unit string) error {
|
||||
conn, err := dbus.NewSystemdConnectionContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -264,12 +264,19 @@ func (k *KubernetesUtil) deployCiliumGCP(ctx context.Context, helmClient *action
|
||||
return err
|
||||
}
|
||||
|
||||
timeoutS := int64(10)
|
||||
// allow coredns to run on uninitialized nodes (required by cloud-controller-manager)
|
||||
tolerations := []corev1.Toleration{
|
||||
{
|
||||
Key: "node.cloudprovider.kubernetes.io/uninitialized",
|
||||
Value: "true",
|
||||
Effect: "NoSchedule",
|
||||
Effect: corev1.TaintEffectNoSchedule,
|
||||
},
|
||||
{
|
||||
Key: "node.kubernetes.io/unreachable",
|
||||
Operator: corev1.TolerationOpExists,
|
||||
Effect: corev1.TaintEffectNoExecute,
|
||||
TolerationSeconds: &timeoutS,
|
||||
},
|
||||
}
|
||||
if err = kubectl.AddTolerationsToDeployment(ctx, tolerations, "coredns", "kube-system"); err != nil {
|
||||
@ -305,7 +312,7 @@ func (k *KubernetesUtil) deployCiliumGCP(ctx context.Context, helmClient *action
|
||||
|
||||
// FixCilium fixes https://github.com/cilium/cilium/issues/19958 but instead of a rollout restart of
|
||||
// the cilium daemonset, it only restarts the local cilium pod.
|
||||
func (k *KubernetesUtil) FixCilium(nodeNameK8s string, log *logger.Logger) {
|
||||
func (k *KubernetesUtil) FixCilium(log *logger.Logger) {
|
||||
// wait for cilium pod to be healthy
|
||||
client := http.Client{}
|
||||
for {
|
||||
@ -487,13 +494,6 @@ func (k *KubernetesUtil) StartKubelet() error {
|
||||
return startSystemdUnit(ctx, "kubelet.service")
|
||||
}
|
||||
|
||||
// RestartKubelet restarts a kubelet.
|
||||
func (k *KubernetesUtil) RestartKubelet() error {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), kubeletStartTimeout)
|
||||
defer cancel()
|
||||
return restartSystemdUnit(ctx, "kubelet.service")
|
||||
}
|
||||
|
||||
// createSignedKubeletCert manually creates a Kubernetes CA signed kubelet certificate for the bootstrapper node.
|
||||
// This is necessary because this node does not request a certificate from the join service.
|
||||
func (k *KubernetesUtil) createSignedKubeletCert(nodeName string, ips []net.IP) error {
|
||||
|
@ -33,6 +33,5 @@ type clusterUtil interface {
|
||||
SetupNodeMaintenanceOperator(kubectl k8sapi.Client, nodeMaintenanceOperatorConfiguration kubernetes.Marshaler) error
|
||||
SetupNodeOperator(ctx context.Context, kubectl k8sapi.Client, nodeOperatorConfiguration kubernetes.Marshaler) error
|
||||
StartKubelet() error
|
||||
RestartKubelet() error
|
||||
FixCilium(nodeNameK8s string, log *logger.Logger)
|
||||
FixCilium(log *logger.Logger)
|
||||
}
|
||||
|
@ -229,7 +229,7 @@ func (k *KubeWrapper) InitCluster(
|
||||
return nil, fmt.Errorf("failed to setup k8s version ConfigMap: %w", err)
|
||||
}
|
||||
|
||||
k.clusterUtil.FixCilium(nodeName, log)
|
||||
k.clusterUtil.FixCilium(log)
|
||||
|
||||
return k.GetKubeconfig()
|
||||
}
|
||||
@ -309,7 +309,7 @@ func (k *KubeWrapper) JoinCluster(ctx context.Context, args *kubeadm.BootstrapTo
|
||||
return fmt.Errorf("joining cluster: %v; %w ", string(joinConfigYAML), err)
|
||||
}
|
||||
|
||||
k.clusterUtil.FixCilium(nodeName, log)
|
||||
k.clusterUtil.FixCilium(log)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -481,8 +481,13 @@ func k8sCompliantHostname(in string) string {
|
||||
}
|
||||
|
||||
// StartKubelet starts the kubelet service.
|
||||
func (k *KubeWrapper) StartKubelet() error {
|
||||
return k.clusterUtil.StartKubelet()
|
||||
func (k *KubeWrapper) StartKubelet(log *logger.Logger) error {
|
||||
if err := k.clusterUtil.StartKubelet(); err != nil {
|
||||
return fmt.Errorf("starting kubelet: %w", err)
|
||||
}
|
||||
|
||||
k.clusterUtil.FixCilium(log)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getIPAddr retrieves to default sender IP used for outgoing connection.
|
||||
|
@ -531,7 +531,6 @@ type stubClusterUtil struct {
|
||||
setupNodeOperatorErr error
|
||||
joinClusterErr error
|
||||
startKubeletErr error
|
||||
restartKubeletErr error
|
||||
|
||||
initConfigs [][]byte
|
||||
joinConfigs [][]byte
|
||||
@ -603,11 +602,7 @@ func (s *stubClusterUtil) StartKubelet() error {
|
||||
return s.startKubeletErr
|
||||
}
|
||||
|
||||
func (s *stubClusterUtil) RestartKubelet() error {
|
||||
return s.restartKubeletErr
|
||||
}
|
||||
|
||||
func (s *stubClusterUtil) FixCilium(nodeName string, log *logger.Logger) {
|
||||
func (s *stubClusterUtil) FixCilium(log *logger.Logger) {
|
||||
}
|
||||
|
||||
type stubConfigProvider struct {
|
||||
|
@ -7,24 +7,26 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/internal/proto"
|
||||
"github.com/edgelesssys/constellation/internal/attestation"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var diskUUIDRegexp = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
type recoveryClient interface {
|
||||
Connect(endpoint string, validators atls.Validator) error
|
||||
Recover(ctx context.Context, masterSecret, salt []byte) error
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// NewRecoverCmd returns a new cobra.Command for the recover command.
|
||||
func NewRecoverCmd() *cobra.Command {
|
||||
@ -38,15 +40,13 @@ func NewRecoverCmd() *cobra.Command {
|
||||
}
|
||||
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
|
||||
must(cmd.MarkFlagRequired("endpoint"))
|
||||
cmd.Flags().String("disk-uuid", "", "disk UUID of the encrypted state disk (required)")
|
||||
must(cmd.MarkFlagRequired("disk-uuid"))
|
||||
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runRecover(cmd *cobra.Command, _ []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
recoveryClient := &proto.KeyClient{}
|
||||
recoveryClient := &proto.RecoverClient{}
|
||||
defer recoveryClient.Close()
|
||||
return recover(cmd, fileHandler, recoveryClient)
|
||||
}
|
||||
@ -82,17 +82,7 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recove
|
||||
return err
|
||||
}
|
||||
|
||||
diskKey, err := deriveStateDiskKey(masterSecret.Key, masterSecret.Salt, flags.diskUUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey, measurementSecret); err != nil {
|
||||
if err := recoveryClient.Recover(cmd.Context(), masterSecret.Key, masterSecret.Salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -110,15 +100,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
|
||||
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
|
||||
}
|
||||
|
||||
diskUUID, err := cmd.Flags().GetString("disk-uuid")
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("parsing disk-uuid argument: %w", err)
|
||||
}
|
||||
if match := diskUUIDRegexp.MatchString(diskUUID); !match {
|
||||
return recoverFlags{}, errors.New("flag '--disk-uuid' isn't a valid LUKS UUID")
|
||||
}
|
||||
diskUUID = strings.ToLower(diskUUID)
|
||||
|
||||
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("parsing master-secret path argument: %w", err)
|
||||
@ -131,7 +112,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
|
||||
|
||||
return recoverFlags{
|
||||
endpoint: endpoint,
|
||||
diskUUID: diskUUID,
|
||||
secretPath: masterSecretPath,
|
||||
configPath: configPath,
|
||||
}, nil
|
||||
@ -139,12 +119,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
|
||||
|
||||
type recoverFlags struct {
|
||||
endpoint string
|
||||
diskUUID string
|
||||
secretPath string
|
||||
configPath string
|
||||
}
|
||||
|
||||
// deriveStateDiskKey derives a state disk key from a master key, a salt, and a disk UUID.
|
||||
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength)
|
||||
}
|
||||
|
@ -8,10 +8,11 @@ package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
@ -57,7 +58,6 @@ func TestRecover(t *testing.T) {
|
||||
client *stubRecoveryClient
|
||||
masterSecret testvector.HKDF
|
||||
endpointFlag string
|
||||
diskUUIDFlag string
|
||||
masterSecretFlag string
|
||||
configFlag string
|
||||
stateless bool
|
||||
@ -67,29 +67,13 @@ func TestRecover(t *testing.T) {
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: testvector.HKDFZero.Info,
|
||||
masterSecret: testvector.HKDFZero,
|
||||
},
|
||||
"uppercase disk uuid works": {
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: strings.ToUpper(testvector.HKDF0xFF.Info),
|
||||
masterSecret: testvector.HKDF0xFF,
|
||||
},
|
||||
"lowercase disk uuid results in same key": {
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: strings.ToLower(testvector.HKDF0xFF.Info),
|
||||
masterSecret: testvector.HKDF0xFF,
|
||||
},
|
||||
"missing flags": {
|
||||
wantErr: true,
|
||||
},
|
||||
"missing config": {
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: testvector.HKDFZero.Info,
|
||||
masterSecret: testvector.HKDFZero,
|
||||
configFlag: "nonexistent-config",
|
||||
wantErr: true,
|
||||
@ -97,7 +81,6 @@ func TestRecover(t *testing.T) {
|
||||
"missing state": {
|
||||
existingState: validState,
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: testvector.HKDFZero.Info,
|
||||
masterSecret: testvector.HKDFZero,
|
||||
stateless: true,
|
||||
wantErr: true,
|
||||
@ -105,7 +88,6 @@ func TestRecover(t *testing.T) {
|
||||
"invalid cloud provider": {
|
||||
existingState: invalidCSPState,
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: testvector.HKDFZero.Info,
|
||||
masterSecret: testvector.HKDFZero,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -113,7 +95,6 @@ func TestRecover(t *testing.T) {
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: testvector.HKDFZero.Info,
|
||||
masterSecret: testvector.HKDFZero,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -121,7 +102,6 @@ func TestRecover(t *testing.T) {
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: testvector.HKDFZero.Info,
|
||||
masterSecret: testvector.HKDFZero,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -140,9 +120,6 @@ func TestRecover(t *testing.T) {
|
||||
if tc.endpointFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
|
||||
}
|
||||
if tc.diskUUIDFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
|
||||
}
|
||||
if tc.masterSecretFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag))
|
||||
}
|
||||
@ -170,7 +147,6 @@ func TestRecover(t *testing.T) {
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Contains(out.String(), "Pushed recovery key.")
|
||||
assert.Equal(tc.masterSecret.Output, tc.client.pushStateDiskKeyKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -185,38 +161,24 @@ func TestParseRecoverFlags(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid ip": {
|
||||
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid disk uuid": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
|
||||
args: []string{"-e", "192.0.2.1:2:2"},
|
||||
wantErr: true,
|
||||
},
|
||||
"minimal args set": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
args: []string{"-e", "192.0.2.1:2"},
|
||||
wantFlags: recoverFlags{
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||
secretPath: "constellation-mastersecret.json",
|
||||
},
|
||||
},
|
||||
"all args set": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--config", "config-path", "--master-secret", "/path/super-secret.json"},
|
||||
args: []string{"-e", "192.0.2.1:2", "--config", "config-path", "--master-secret", "/path/super-secret.json"},
|
||||
wantFlags: recoverFlags{
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||
secretPath: "/path/super-secret.json",
|
||||
configPath: "config-path",
|
||||
},
|
||||
},
|
||||
"uppercase disk-uuid is converted to lowercase": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
|
||||
wantFlags: recoverFlags{
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||
secretPath: "constellation-mastersecret.json",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
@ -239,26 +201,26 @@ func TestParseRecoverFlags(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveStateDiskKey(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
masterSecret testvector.HKDF
|
||||
}{
|
||||
"all zero": {
|
||||
masterSecret: testvector.HKDFZero,
|
||||
},
|
||||
"all 0xff": {
|
||||
masterSecret: testvector.HKDF0xFF,
|
||||
},
|
||||
}
|
||||
type stubRecoveryClient struct {
|
||||
conn bool
|
||||
connectErr error
|
||||
closeErr error
|
||||
pushStateDiskKeyErr error
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
stateDiskKey, err := deriveStateDiskKey(tc.masterSecret.Secret, tc.masterSecret.Salt, tc.masterSecret.Info)
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.masterSecret.Output, stateDiskKey)
|
||||
})
|
||||
}
|
||||
pushStateDiskKeyKey []byte
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Connect(string, atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Close() error {
|
||||
c.conn = false
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Recover(_ context.Context, stateDiskKey, _ []byte) error {
|
||||
c.pushStateDiskKeyKey = stateDiskKey
|
||||
return c.pushStateDiskKeyErr
|
||||
}
|
||||
|
@ -1,20 +0,0 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
)
|
||||
|
||||
type recoveryClient interface {
|
||||
Connect(endpoint string, validators atls.Validator) error
|
||||
PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error
|
||||
io.Closer
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
)
|
||||
|
||||
type stubRecoveryClient struct {
|
||||
conn bool
|
||||
connectErr error
|
||||
closeErr error
|
||||
pushStateDiskKeyErr error
|
||||
|
||||
pushStateDiskKeyKey []byte
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Connect(_ string, _ atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Close() error {
|
||||
c.conn = false
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey, _ []byte) error {
|
||||
c.pushStateDiskKeyKey = stateDiskKey
|
||||
return c.pushStateDiskKeyErr
|
||||
}
|
@ -9,17 +9,21 @@ package proto
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/attestation"
|
||||
"github.com/edgelesssys/constellation/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/state/keyproto"
|
||||
"go.uber.org/multierr"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// KeyClient wraps a KeyAPI client and the connection to it.
|
||||
type KeyClient struct {
|
||||
conn *grpc.ClientConn
|
||||
keyapi keyproto.APIClient
|
||||
// RecoverClient wraps a recoverAPI client and the connection to it.
|
||||
type RecoverClient struct {
|
||||
conn *grpc.ClientConn
|
||||
recoverapi recoverproto.APIClient
|
||||
}
|
||||
|
||||
// Connect connects the client to a given server, using the handed
|
||||
@ -27,7 +31,7 @@ type KeyClient struct {
|
||||
// The connection must be closed using Close(). If connect is
|
||||
// called on a client that already has a connection, the old
|
||||
// connection is closed.
|
||||
func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
|
||||
func (c *RecoverClient) Connect(endpoint string, validators atls.Validator) error {
|
||||
creds := atlscredentials.New(nil, []atls.Validator{validators})
|
||||
|
||||
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
|
||||
@ -36,14 +40,14 @@ func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
|
||||
}
|
||||
_ = c.Close()
|
||||
c.conn = conn
|
||||
c.keyapi = keyproto.NewAPIClient(conn)
|
||||
c.recoverapi = recoverproto.NewAPIClient(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the grpc connection of the client.
|
||||
// Close is idempotent and can be called on non connected clients
|
||||
// without returning an error.
|
||||
func (c *KeyClient) Close() error {
|
||||
func (c *RecoverClient) Close() error {
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
@ -55,17 +59,58 @@ func (c *KeyClient) Close() error {
|
||||
}
|
||||
|
||||
// PushStateDiskKey pushes the state disk key to a constellation instance in recovery mode.
|
||||
// The state disk key must be derived from the UUID of the state disk and the master key.
|
||||
func (c *KeyClient) PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error {
|
||||
if c.keyapi == nil {
|
||||
func (c *RecoverClient) Recover(ctx context.Context, masterSecret, salt []byte) (retErr error) {
|
||||
if c.recoverapi == nil {
|
||||
return errors.New("client is not connected")
|
||||
}
|
||||
|
||||
req := &keyproto.PushStateDiskKeyRequest{
|
||||
StateDiskKey: stateDiskKey,
|
||||
MeasurementSecret: measurementSecret,
|
||||
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := c.keyapi.PushStateDiskKey(ctx, req)
|
||||
return err
|
||||
recoverclient, err := c.recoverapi.Recover(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := recoverclient.CloseSend(); err != nil {
|
||||
multierr.AppendInto(&retErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := recoverclient.Send(&recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: measurementSecret,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := recoverclient.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stateDiskKey, err := deriveStateDiskKey(masterSecret, salt, res.DiskUuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := recoverclient.Send(&recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: stateDiskKey,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deriveStateDiskKey derives a state disk key from a master key, a salt, and a disk UUID.
|
||||
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+diskUUID), crypto.StateDiskKeyLength)
|
||||
}
|
||||
|
38
cli/internal/proto/recover_test.go
Normal file
38
cli/internal/proto/recover_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/crypto/testvector"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeriveStateDiskKey(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
masterSecret testvector.HKDF
|
||||
}{
|
||||
"all zero": {
|
||||
masterSecret: testvector.HKDFZero,
|
||||
},
|
||||
"all 0xff": {
|
||||
masterSecret: testvector.HKDF0xFF,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
stateDiskKey, err := deriveStateDiskKey(tc.masterSecret.Secret, tc.masterSecret.Salt, tc.masterSecret.Info)
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.masterSecret.Output, stateDiskKey)
|
||||
})
|
||||
}
|
||||
}
|
@ -4,11 +4,11 @@ Files and source code for mounting persistent state disks
|
||||
|
||||
## Testing
|
||||
|
||||
Integration test is available in `state/test/integration_test.go`.
|
||||
Integration test is available in `disk-mapper/test/integration_test.go`.
|
||||
The integration test requires root privileges since it uses dm-crypt.
|
||||
Build and run the test:
|
||||
|
||||
```bash
|
||||
go test -c -tags=integration ./state/test/
|
||||
go test -c -tags=integration ./disk-mapper/test/
|
||||
sudo ./test.test
|
||||
```
|
@ -11,12 +11,17 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/recoveryserver"
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/rejoinclient"
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/setup"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/internal/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/attestation/qemu"
|
||||
@ -26,10 +31,8 @@ import (
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
qemucloud "github.com/edgelesssys/constellation/internal/cloud/qemu"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/state/internal/keyservice"
|
||||
"github.com/edgelesssys/constellation/state/internal/mapper"
|
||||
"github.com/edgelesssys/constellation/state/internal/setup"
|
||||
tpmClient "github.com/google/go-tpm-tools/client"
|
||||
"github.com/google/go-tpm/tpm2"
|
||||
"github.com/spf13/afero"
|
||||
@ -54,8 +57,8 @@ func main() {
|
||||
// set up metadata API and quote issuer for aTLS connections
|
||||
var err error
|
||||
var diskPath string
|
||||
var issuer keyservice.QuoteIssuer
|
||||
var metadata metadata.InstanceLister
|
||||
var issuer atls.Issuer
|
||||
var metadataAPI setup.MetadataAPI
|
||||
switch strings.ToLower(*csp) {
|
||||
case "azure":
|
||||
diskPath, err = filepath.EvalSymlinks(azureStateDiskPath)
|
||||
@ -63,7 +66,7 @@ func main() {
|
||||
_ = exportPCRs()
|
||||
log.With(zap.Error(err)).Fatalf("Unable to resolve Azure state disk path")
|
||||
}
|
||||
metadata, err = azurecloud.NewMetadata(context.Background())
|
||||
metadataAPI, err = azurecloud.NewMetadata(context.Background())
|
||||
if err != nil {
|
||||
log.With(zap.Error).Fatalf("Failed to create Azure metadata API")
|
||||
}
|
||||
@ -81,12 +84,13 @@ func main() {
|
||||
if err != nil {
|
||||
log.With(zap.Error).Fatalf("Failed to create GCP client")
|
||||
}
|
||||
metadata = gcpcloud.New(gcpClient)
|
||||
metadataAPI = gcpcloud.New(gcpClient)
|
||||
|
||||
case "qemu":
|
||||
diskPath = qemuStateDiskPath
|
||||
issuer = qemu.NewIssuer()
|
||||
metadata = &qemucloud.Metadata{}
|
||||
metadataAPI = &qemucloud.Metadata{}
|
||||
_ = exportPCRs()
|
||||
|
||||
default:
|
||||
log.Fatalf("CSP %s is not supported by Constellation", *csp)
|
||||
@ -104,7 +108,6 @@ func main() {
|
||||
*csp,
|
||||
diskPath,
|
||||
afero.Afero{Fs: afero.NewOsFs()},
|
||||
keyservice.New(log.Named("keyService"), issuer, metadata, 20*time.Second, 20*time.Second), // try to request a key every 20 seconds
|
||||
mapper,
|
||||
setup.DiskMounter{},
|
||||
vtpm.OpenVTPM,
|
||||
@ -112,7 +115,23 @@ func main() {
|
||||
|
||||
// prepare the state disk
|
||||
if mapper.IsLUKSDevice() {
|
||||
err = setupManger.PrepareExistingDisk()
|
||||
// set up rejoin client
|
||||
var self metadata.InstanceMetadata
|
||||
self, err = metadataAPI.Self(context.Background())
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Fatalf("Failed to get self metadata")
|
||||
}
|
||||
rejoinClient := rejoinclient.New(
|
||||
dialer.New(issuer, nil, &net.Dialer{}),
|
||||
self,
|
||||
metadataAPI,
|
||||
log.Named("rejoinClient"),
|
||||
)
|
||||
|
||||
// set up recovery server
|
||||
recoveryServer := recoveryserver.New(issuer, log.Named("recoveryServer"))
|
||||
|
||||
err = setupManger.PrepareExistingDisk(setup.NewNodeRecoverer(recoveryServer, rejoinClient))
|
||||
} else {
|
||||
err = setupManger.PrepareNewDisk()
|
||||
}
|
134
disk-mapper/internal/recoveryserver/server.go
Normal file
134
disk-mapper/internal/recoveryserver/server.go
Normal file
@ -0,0 +1,134 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package recoveryserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
|
||||
type RecoveryServer struct {
|
||||
mux sync.Mutex
|
||||
|
||||
diskUUID string
|
||||
stateDiskKey []byte
|
||||
measurementSecret []byte
|
||||
grpcServer server
|
||||
|
||||
log *logger.Logger
|
||||
|
||||
recoverproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
// New returns a new RecoveryServer.
|
||||
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
|
||||
server := &RecoveryServer{
|
||||
log: log,
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
grpc.Creds(atlscredentials.New(issuer, nil)),
|
||||
log.Named("gRPC").GetServerStreamInterceptor(),
|
||||
)
|
||||
recoverproto.RegisterAPIServer(grpcServer, server)
|
||||
|
||||
server.grpcServer = grpcServer
|
||||
return server
|
||||
}
|
||||
|
||||
// Serve starts the recovery server.
|
||||
// It blocks until a recover request call is successful.
|
||||
// The server will shut down when the call is successful and the keys are returned.
|
||||
// Additionally, the server can be shutdown by canceling the context.
|
||||
func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskUUID string) (diskKey, measurementSecret []byte, err error) {
|
||||
s.log.Infof("Starting RecoveryServer")
|
||||
s.diskUUID = diskUUID
|
||||
recoveryDone := make(chan struct{}, 1)
|
||||
var serveErr error
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
defer wg.Wait()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
serveErr = s.grpcServer.Serve(listener)
|
||||
recoveryDone <- struct{}{}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Infof("Context canceled, shutting down server")
|
||||
s.grpcServer.GracefulStop()
|
||||
return nil, nil, ctx.Err()
|
||||
case <-recoveryDone:
|
||||
if serveErr != nil {
|
||||
return nil, nil, serveErr
|
||||
}
|
||||
return s.stateDiskKey, s.measurementSecret, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
|
||||
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.log.Infof("Received recover call")
|
||||
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
}
|
||||
|
||||
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
|
||||
if !ok {
|
||||
s.log.Errorf("Received invalid first message: not a measurement secret")
|
||||
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
|
||||
}
|
||||
|
||||
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
|
||||
return status.Error(codes.Internal, "failed to send response")
|
||||
}
|
||||
|
||||
msg, err = stream.Recv()
|
||||
if err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Failed to receive disk key")
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
}
|
||||
|
||||
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
|
||||
if !ok {
|
||||
s.log.Errorf("Received invalid second message: not a state disk key")
|
||||
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
|
||||
}
|
||||
|
||||
s.stateDiskKey = stateDiskKey.StateDiskKey
|
||||
s.measurementSecret = measurementSecret.MeasurementSecret
|
||||
s.log.Infof("Received state disk key and measurement secret, shutting down server")
|
||||
|
||||
go s.grpcServer.GracefulStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
type server interface {
|
||||
Serve(net.Listener) error
|
||||
GracefulStop()
|
||||
}
|
194
disk-mapper/internal/recoveryserver/server_test.go
Normal file
194
disk-mapper/internal/recoveryserver/server_test.go
Normal file
@ -0,0 +1,194 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package recoveryserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
log := logger.NewTest(t)
|
||||
uuid := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
listener := dialer.GetListener("192.0.2.1:1234")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Serve method returns when context is canceled
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _, err := server.Serve(ctx, listener, uuid)
|
||||
assert.ErrorIs(err, context.Canceled)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
wg.Wait()
|
||||
|
||||
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
dialer = testdialer.NewBufconnDialer()
|
||||
listener = dialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
// Serve method returns without error when the server is shut down
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _, err := server.Serve(context.Background(), listener, uuid)
|
||||
assert.NoError(err)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
server.grpcServer.GracefulStop()
|
||||
wg.Wait()
|
||||
|
||||
// Serve method returns an error when serving is unsuccessful
|
||||
_, _, err := server.Serve(context.Background(), listener, uuid)
|
||||
assert.Error(err)
|
||||
}
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
initialMsg message
|
||||
keyMsg message
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"first message is not a measurement secret": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"second message is not a state disk key": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
serverUUID := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
listener := netDialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
var diskKey, measurementSecret []byte
|
||||
var serveErr error
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
serveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
diskKey, measurementSecret, serveErr = server.Serve(serveCtx, listener, serverUUID)
|
||||
}()
|
||||
|
||||
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
|
||||
require.NoError(err)
|
||||
|
||||
// Send initial message
|
||||
err = client.Send(tc.initialMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
|
||||
// Receive uuid
|
||||
uuid, err := client.Recv()
|
||||
if tc.initialMsg.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(serverUUID, uuid.DiskUuid)
|
||||
|
||||
// Send key message
|
||||
err = client.Send(tc.keyMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
|
||||
_, err = client.Recv()
|
||||
if tc.keyMsg.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.ErrorIs(io.EOF, err)
|
||||
|
||||
wg.Wait()
|
||||
assert.NoError(serveErr)
|
||||
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
|
||||
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type message struct {
|
||||
recoverMsg *recoverproto.RecoverMessage
|
||||
wantErr bool
|
||||
}
|
167
disk-mapper/internal/rejoinclient/client.go
Normal file
167
disk-mapper/internal/rejoinclient/client.go
Normal file
@ -0,0 +1,167 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package rejoinclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/role"
|
||||
"github.com/edgelesssys/constellation/joinservice/joinproto"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
const (
|
||||
interval = 30 * time.Second
|
||||
timeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// RejoinClient is a client for requesting the needed information
|
||||
// for rejoining a cluster as a restarting worker or control-plane node.
|
||||
type RejoinClient struct {
|
||||
diskUUID string
|
||||
nodeInfo metadata.InstanceMetadata
|
||||
|
||||
timeout time.Duration
|
||||
interval time.Duration
|
||||
clock clock.WithTicker
|
||||
|
||||
dialer grpcDialer
|
||||
metadataAPI metadataAPI
|
||||
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
// New returns a new RejoinClient.
|
||||
func New(dial grpcDialer, nodeInfo metadata.InstanceMetadata,
|
||||
meta metadataAPI, log *logger.Logger,
|
||||
) *RejoinClient {
|
||||
return &RejoinClient{
|
||||
nodeInfo: nodeInfo,
|
||||
timeout: timeout,
|
||||
interval: interval,
|
||||
clock: clock.RealClock{},
|
||||
dialer: dial,
|
||||
metadataAPI: meta,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the rejoin client.
|
||||
// The client will continuously request available control-plane endpoints
|
||||
// from the metadata API and send rejoin requests to them.
|
||||
// The function returns after a successful rejoin request has been performed.
|
||||
func (c *RejoinClient) Start(ctx context.Context, diskUUID string) (diskKey, measurementSecret []byte) {
|
||||
c.log.Infof("Starting RejoinClient")
|
||||
c.diskUUID = diskUUID
|
||||
ticker := c.clock.NewTicker(c.interval)
|
||||
|
||||
defer ticker.Stop()
|
||||
defer c.log.Infof("RejoinClient stopped")
|
||||
|
||||
for {
|
||||
endpoints, err := c.getControlPlaneEndpoints()
|
||||
if err != nil {
|
||||
c.log.With(zap.Error(err)).Errorf("Failed to get control-plane endpoints")
|
||||
} else {
|
||||
c.log.With(zap.Strings("endpoints", endpoints)).Infof("Received list with JoinService endpoints")
|
||||
diskKey, measurementSecret, err = c.tryRejoinWithAvailableServices(ctx, endpoints)
|
||||
if err == nil {
|
||||
c.log.Infof("Successfully retrieved rejoin ticket")
|
||||
return diskKey, measurementSecret
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil
|
||||
case <-ticker.C():
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryRejoinWithAvailableServices tries sending rejoin requests to the available endpoints.
|
||||
func (c *RejoinClient) tryRejoinWithAvailableServices(ctx context.Context, endpoints []string) (diskKey, measurementSecret []byte, err error) {
|
||||
for _, endpoint := range endpoints {
|
||||
c.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
|
||||
rejoinTicket, err := c.requestRejoinTicket(endpoint)
|
||||
if err == nil {
|
||||
return rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, nil
|
||||
}
|
||||
c.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Warnf("Failed to rejoin on endpoint")
|
||||
|
||||
// stop requesting additional endpoints if the context is done
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
c.log.Errorf("Failed to rejoin on all endpoints")
|
||||
return nil, nil, errors.New("failed to join on all endpoints")
|
||||
}
|
||||
|
||||
// requestRejoinTicket requests a rejoin ticket from the endpoint.
|
||||
func (c *RejoinClient) requestRejoinTicket(endpoint string) (*joinproto.IssueRejoinTicketResponse, error) {
|
||||
ctx, cancel := c.timeoutCtx()
|
||||
defer cancel()
|
||||
|
||||
conn, err := c.dialer.Dial(ctx, endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return joinproto.NewAPIClient(conn).IssueRejoinTicket(ctx, &joinproto.IssueRejoinTicketRequest{DiskUuid: c.diskUUID})
|
||||
}
|
||||
|
||||
// getControlPlaneEndpoints requests the available control-plane endpoints from the metadata API.
|
||||
// The list is filtered to remove *this* node if it is a restarting control-plane node.
|
||||
func (c *RejoinClient) getControlPlaneEndpoints() ([]string, error) {
|
||||
ctx, cancel := c.timeoutCtx()
|
||||
defer cancel()
|
||||
endpoints, err := metadata.JoinServiceEndpoints(ctx, c.metadataAPI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.nodeInfo.Role == role.ControlPlane {
|
||||
return removeSelfFromEndpoints(c.nodeInfo.VPCIP, endpoints), nil
|
||||
}
|
||||
return endpoints, nil
|
||||
}
|
||||
|
||||
// removeSelfFromEndpoints removes *this* node from the list of endpoints.
|
||||
// If an error occurs, the entry is removed from the list of endpoints.
|
||||
func removeSelfFromEndpoints(self string, endpoints []string) []string {
|
||||
var result []string
|
||||
for _, endpoint := range endpoints {
|
||||
host, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil && host != self {
|
||||
result = append(result, endpoint)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *RejoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), c.timeout)
|
||||
}
|
||||
|
||||
type grpcDialer interface {
|
||||
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||
}
|
||||
|
||||
type metadataAPI interface {
|
||||
// List retrieves all instances belonging to the current constellation.
|
||||
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
|
||||
}
|
308
disk-mapper/internal/rejoinclient/client_test.go
Normal file
308
disk-mapper/internal/rejoinclient/client_test.go
Normal file
@ -0,0 +1,308 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package rejoinclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/role"
|
||||
"github.com/edgelesssys/constellation/joinservice/joinproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/goleak"
|
||||
"google.golang.org/grpc"
|
||||
testclock "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestStartCancel(t *testing.T) {
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := dialer.New(nil, nil, netDialer)
|
||||
|
||||
clock := testclock.NewFakeClock(time.Time{})
|
||||
|
||||
metaAPI := &stubMetadataAPI{
|
||||
instances: []metadata.InstanceMetadata{
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := &RejoinClient{
|
||||
dialer: dialer,
|
||||
nodeInfo: metadata.InstanceMetadata{Role: role.Worker},
|
||||
metadataAPI: metaAPI,
|
||||
log: logger.NewTest(t),
|
||||
timeout: time.Second * 30,
|
||||
interval: time.Second,
|
||||
clock: clock,
|
||||
}
|
||||
|
||||
serverCreds := atlscredentials.New(nil, nil)
|
||||
rejoinServer := grpc.NewServer(grpc.Creds(serverCreds))
|
||||
rejoinServiceAPI := &stubRejoinServiceAPI{err: errors.New("error")}
|
||||
joinproto.RegisterAPIServer(rejoinServer, rejoinServiceAPI)
|
||||
port := strconv.Itoa(constants.JoinServiceNodePort)
|
||||
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
|
||||
go rejoinServer.Serve(listener)
|
||||
defer rejoinServer.GracefulStop()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
client.Start(ctx, "uuid")
|
||||
}()
|
||||
|
||||
clock.Step(time.Millisecond)
|
||||
cancel()
|
||||
wg.Wait()
|
||||
assert.Equal(t, client.diskUUID, "uuid")
|
||||
}
|
||||
|
||||
func TestRemoveSelfFromEndpoints(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
self string
|
||||
endpoints []string
|
||||
}{
|
||||
"self is not in endpoints": {
|
||||
self: "192.0.2.1",
|
||||
endpoints: []string{
|
||||
"192.0.2.2:30090",
|
||||
"192.0.2.3:30090",
|
||||
"192.0.2.4:30090",
|
||||
"192.0.2.5:30090",
|
||||
"192.0.2.6:30090",
|
||||
},
|
||||
},
|
||||
"self is in endpoints": {
|
||||
self: "192.0.2.1",
|
||||
endpoints: []string{
|
||||
"192.0.2.2:30090",
|
||||
"192.0.2.3:30090",
|
||||
"192.0.2.4:30090",
|
||||
"192.0.2.5:30090",
|
||||
"192.0.2.6:30090",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
got := removeSelfFromEndpoints(tc.self, tc.endpoints)
|
||||
assert.NotContains(got, tc.self)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetControlPlaneEndpoints(t *testing.T) {
|
||||
testInstances := []metadata.InstanceMetadata{
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.2",
|
||||
},
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.3",
|
||||
},
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.4",
|
||||
},
|
||||
{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.12",
|
||||
},
|
||||
{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.13",
|
||||
},
|
||||
{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.14",
|
||||
},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
nodeInfo metadata.InstanceMetadata
|
||||
meta stubMetadataAPI
|
||||
wantInstances int
|
||||
wantErr bool
|
||||
}{
|
||||
"worker node": {
|
||||
nodeInfo: metadata.InstanceMetadata{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
meta: stubMetadataAPI{
|
||||
instances: testInstances,
|
||||
},
|
||||
wantInstances: 3,
|
||||
},
|
||||
"control-plane node not in list": {
|
||||
nodeInfo: metadata.InstanceMetadata{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
meta: stubMetadataAPI{
|
||||
instances: testInstances,
|
||||
},
|
||||
wantInstances: 3,
|
||||
},
|
||||
"control-plane node in list": {
|
||||
nodeInfo: metadata.InstanceMetadata{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.2",
|
||||
},
|
||||
meta: stubMetadataAPI{
|
||||
instances: testInstances,
|
||||
},
|
||||
wantInstances: 2,
|
||||
},
|
||||
"metadata error": {
|
||||
nodeInfo: metadata.InstanceMetadata{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
meta: stubMetadataAPI{
|
||||
err: errors.New("error"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := New(nil, tc.nodeInfo, tc.meta, logger.NewTest(t))
|
||||
|
||||
endpoints, err := client.getControlPlaneEndpoints()
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.NotContains(endpoints, tc.nodeInfo.VPCIP)
|
||||
assert.Len(endpoints, tc.wantInstances)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
nodeInfo metadata.InstanceMetadata
|
||||
}{
|
||||
"worker node": {
|
||||
nodeInfo: metadata.InstanceMetadata{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.99",
|
||||
},
|
||||
},
|
||||
"control-plane node": {
|
||||
nodeInfo: metadata.InstanceMetadata{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.99",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
diskKey := []byte("disk-key")
|
||||
measurementSecret := []byte("measurement-secret")
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := dialer.New(nil, nil, netDialer)
|
||||
serverCreds := atlscredentials.New(nil, nil)
|
||||
rejoinServer := grpc.NewServer(grpc.Creds(serverCreds))
|
||||
rejoinServiceAPI := &stubRejoinServiceAPI{
|
||||
rejoinTicketResponse: &joinproto.IssueRejoinTicketResponse{
|
||||
StateDiskKey: diskKey,
|
||||
MeasurementSecret: measurementSecret,
|
||||
},
|
||||
}
|
||||
joinproto.RegisterAPIServer(rejoinServer, rejoinServiceAPI)
|
||||
port := strconv.Itoa(constants.JoinServiceNodePort)
|
||||
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
|
||||
go rejoinServer.Serve(listener)
|
||||
defer rejoinServer.GracefulStop()
|
||||
|
||||
meta := stubMetadataAPI{
|
||||
instances: []metadata.InstanceMetadata{
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
{
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.2",
|
||||
},
|
||||
{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.13",
|
||||
},
|
||||
{
|
||||
Role: role.Worker,
|
||||
VPCIP: "192.0.2.14",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := New(dialer, tc.nodeInfo, meta, logger.NewTest(t))
|
||||
|
||||
passphrase, secret := client.Start(context.Background(), "uuid")
|
||||
assert.Equal(diskKey, passphrase)
|
||||
assert.Equal(measurementSecret, secret)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubMetadataAPI struct {
|
||||
instances []metadata.InstanceMetadata
|
||||
err error
|
||||
}
|
||||
|
||||
func (s stubMetadataAPI) List(context.Context) ([]metadata.InstanceMetadata, error) {
|
||||
return s.instances, s.err
|
||||
}
|
||||
|
||||
type stubRejoinServiceAPI struct {
|
||||
rejoinTicketResponse *joinproto.IssueRejoinTicketResponse
|
||||
err error
|
||||
joinproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func (s *stubRejoinServiceAPI) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest,
|
||||
) (*joinproto.IssueRejoinTicketResponse, error) {
|
||||
return s.rejoinTicketResponse, s.err
|
||||
}
|
@ -10,6 +10,8 @@ import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
)
|
||||
|
||||
// Mounter is an interface for mount and unmount operations.
|
||||
@ -27,17 +29,23 @@ type DeviceMapper interface {
|
||||
UnmapDisk(target string) error
|
||||
}
|
||||
|
||||
// KeyWaiter is an interface to request and wait for disk decryption keys.
|
||||
type KeyWaiter interface {
|
||||
WaitForDecryptionKey(uuid, addr string) (diskKey, measurementSecret []byte, err error)
|
||||
ResetKey()
|
||||
}
|
||||
|
||||
// ConfigurationGenerator is an interface for generating systemd-cryptsetup@.service unit files.
|
||||
type ConfigurationGenerator interface {
|
||||
Generate(volumeName, encryptedDevice, keyFile, options string) error
|
||||
}
|
||||
|
||||
// MetadataAPI is an interface for accessing cloud metadata.
|
||||
type MetadataAPI interface {
|
||||
metadata.InstanceSelfer
|
||||
metadata.InstanceLister
|
||||
}
|
||||
|
||||
// RecoveryDoer is an interface to perform key recovery operations.
|
||||
// Calls to Do may be blocking, and if successful return a passphrase and measurementSecret.
|
||||
type RecoveryDoer interface {
|
||||
Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error)
|
||||
}
|
||||
|
||||
// DiskMounter uses the syscall package to mount disks.
|
||||
type DiskMounter struct{}
|
||||
|
@ -7,14 +7,18 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/systemd"
|
||||
"github.com/edgelesssys/constellation/internal/attestation"
|
||||
"github.com/edgelesssys/constellation/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
@ -22,9 +26,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/nodestate"
|
||||
"github.com/edgelesssys/constellation/state/internal/systemd"
|
||||
"github.com/spf13/afero"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -38,56 +40,53 @@ const (
|
||||
|
||||
// SetupManager handles formatting, mapping, mounting and unmounting of state disks.
|
||||
type SetupManager struct {
|
||||
log *logger.Logger
|
||||
csp string
|
||||
diskPath string
|
||||
fs afero.Afero
|
||||
keyWaiter KeyWaiter
|
||||
mapper DeviceMapper
|
||||
mounter Mounter
|
||||
config ConfigurationGenerator
|
||||
openTPM vtpm.TPMOpenFunc
|
||||
log *logger.Logger
|
||||
csp string
|
||||
diskPath string
|
||||
fs afero.Afero
|
||||
mapper DeviceMapper
|
||||
mounter Mounter
|
||||
config ConfigurationGenerator
|
||||
openTPM vtpm.TPMOpenFunc
|
||||
}
|
||||
|
||||
// New initializes a SetupManager with the given parameters.
|
||||
func New(log *logger.Logger, csp string, diskPath string, fs afero.Afero, keyWaiter KeyWaiter, mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc) *SetupManager {
|
||||
func New(log *logger.Logger, csp string, diskPath string, fs afero.Afero,
|
||||
mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc,
|
||||
) *SetupManager {
|
||||
return &SetupManager{
|
||||
log: log,
|
||||
csp: csp,
|
||||
diskPath: diskPath,
|
||||
fs: fs,
|
||||
keyWaiter: keyWaiter,
|
||||
mapper: mapper,
|
||||
mounter: mounter,
|
||||
config: systemd.New(fs),
|
||||
openTPM: openTPM,
|
||||
log: log,
|
||||
csp: csp,
|
||||
diskPath: diskPath,
|
||||
fs: fs,
|
||||
mapper: mapper,
|
||||
mounter: mounter,
|
||||
config: systemd.New(fs),
|
||||
openTPM: openTPM,
|
||||
}
|
||||
}
|
||||
|
||||
// PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk.
|
||||
// Once the disk is mapped, the function taints the node as initialized by updating it's PCRs.
|
||||
func (s *SetupManager) PrepareExistingDisk() error {
|
||||
func (s *SetupManager) PrepareExistingDisk(recover RecoveryDoer) error {
|
||||
s.log.Infof("Preparing existing state disk")
|
||||
uuid := s.mapper.DiskUUID()
|
||||
|
||||
endpoint := net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.RecoveryPort))
|
||||
getKey:
|
||||
passphrase, measurementSecret, err := s.keyWaiter.WaitForDecryptionKey(uuid, endpoint)
|
||||
|
||||
passphrase, measurementSecret, err := recover.Do(uuid, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to perform recovery: %w", err)
|
||||
}
|
||||
|
||||
if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil {
|
||||
// retry key fetching if disk mapping fails
|
||||
s.log.With(zap.Error(err)).Errorf("Failed to map state disk, retrying...")
|
||||
s.keyWaiter.ResetKey()
|
||||
goto getKey
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
// we do not care about cleaning up the mount point on error, since any errors returned here should result in a kernel panic in the main function
|
||||
// we do not care about cleaning up the mount point on error, since any errors returned here should cause a boot failure
|
||||
if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -160,3 +159,68 @@ func (s *SetupManager) saveConfiguration(passphrase []byte) error {
|
||||
// systemd cryptsetup unit
|
||||
return s.config.Generate(stateDiskMappedName, s.diskPath, filepath.Join(keyPath, keyFile), cryptsetupOptions)
|
||||
}
|
||||
|
||||
type recoveryServer interface {
|
||||
Serve(context.Context, net.Listener, string) (key, secret []byte, err error)
|
||||
}
|
||||
|
||||
type rejoinClient interface {
|
||||
Start(context.Context, string) (key, secret []byte)
|
||||
}
|
||||
|
||||
type nodeRecoverer struct {
|
||||
recoveryServer recoveryServer
|
||||
rejoinClient rejoinClient
|
||||
}
|
||||
|
||||
// NewNodeRecoverer initializes a new nodeRecoverer.
|
||||
func NewNodeRecoverer(recoveryServer recoveryServer, rejoinClient rejoinClient) *nodeRecoverer {
|
||||
return &nodeRecoverer{
|
||||
recoveryServer: recoveryServer,
|
||||
rejoinClient: rejoinClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Do performs a recovery procedure on the given state disk.
|
||||
// The method starts a gRPC server to allow manual recovery by a user.
|
||||
// At the same time it tries to request a decryption key from all available Constellation control-plane nodes.
|
||||
func (r *nodeRecoverer) Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
lis, err := net.Listen("tcp", endpoint)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer lis.Close()
|
||||
|
||||
var once sync.Once
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key, secret, serveErr := r.recoveryServer.Serve(ctx, lis, uuid)
|
||||
once.Do(func() {
|
||||
cancel()
|
||||
passphrase = key
|
||||
measurementSecret = secret
|
||||
})
|
||||
if serveErr != nil && !errors.Is(serveErr, context.Canceled) {
|
||||
err = serveErr
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key, secret := r.rejoinClient.Start(ctx, uuid)
|
||||
once.Do(func() {
|
||||
cancel()
|
||||
passphrase = key
|
||||
measurementSecret = secret
|
||||
})
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
return passphrase, measurementSecret, err
|
||||
}
|
@ -7,10 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/attestation/vtpm"
|
||||
@ -30,9 +33,13 @@ func TestMain(m *testing.M) {
|
||||
|
||||
func TestPrepareExistingDisk(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
testRecoveryDoer := &stubRecoveryDoer{
|
||||
passphrase: []byte("passphrase"),
|
||||
secret: []byte("secret"),
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
keyWaiter *stubKeyWaiter
|
||||
recoveryDoer *stubRecoveryDoer
|
||||
mapper *stubMapper
|
||||
mounter *stubMounter
|
||||
configGenerator *stubConfigurationGenerator
|
||||
@ -41,34 +48,33 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
},
|
||||
"WaitForDecryptionKey fails": {
|
||||
keyWaiter: &stubKeyWaiter{waitErr: someErr},
|
||||
recoveryDoer: &stubRecoveryDoer{recoveryErr: someErr},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
wantErr: true,
|
||||
},
|
||||
"MapDisk fails causes a repeat": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
"MapDisk fails": {
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{
|
||||
uuid: "test",
|
||||
mapDiskErr: someErr,
|
||||
mapDiskRepeatedCalls: 2,
|
||||
uuid: "test",
|
||||
mapDiskErr: someErr,
|
||||
},
|
||||
mounter: &stubMounter{},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
wantErr: false,
|
||||
wantErr: true,
|
||||
},
|
||||
"MkdirAll fails": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{mkdirAllErr: someErr},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
@ -76,7 +82,7 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"Mount fails": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{mountErr: someErr},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
@ -84,7 +90,7 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"Unmount fails": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{unmountErr: someErr},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
@ -92,7 +98,7 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"MarkNodeAsBootstrapped fails": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{unmountErr: someErr},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
@ -100,7 +106,7 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"Generating config fails": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
configGenerator: &stubConfigurationGenerator{generateErr: someErr},
|
||||
@ -108,7 +114,7 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"no state file": {
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
recoveryDoer: testRecoveryDoer,
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
@ -130,23 +136,21 @@ func TestPrepareExistingDisk(t *testing.T) {
|
||||
}
|
||||
|
||||
setupManager := &SetupManager{
|
||||
log: logger.NewTest(t),
|
||||
csp: "test",
|
||||
diskPath: "disk-path",
|
||||
fs: fs,
|
||||
keyWaiter: tc.keyWaiter,
|
||||
mapper: tc.mapper,
|
||||
mounter: tc.mounter,
|
||||
config: tc.configGenerator,
|
||||
openTPM: tc.openTPM,
|
||||
log: logger.NewTest(t),
|
||||
csp: "test",
|
||||
diskPath: "disk-path",
|
||||
fs: fs,
|
||||
mapper: tc.mapper,
|
||||
mounter: tc.mounter,
|
||||
config: tc.configGenerator,
|
||||
openTPM: tc.openTPM,
|
||||
}
|
||||
|
||||
err := setupManager.PrepareExistingDisk()
|
||||
err := setupManager.PrepareExistingDisk(tc.recoveryDoer)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
|
||||
assert.True(tc.mapper.mapDiskCalled)
|
||||
assert.True(tc.mounter.mountCalled)
|
||||
assert.True(tc.mounter.unmountCalled)
|
||||
@ -191,9 +195,8 @@ func TestPrepareNewDisk(t *testing.T) {
|
||||
"MapDisk fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
mapper: &stubMapper{
|
||||
uuid: "test",
|
||||
mapDiskErr: someErr,
|
||||
mapDiskRepeatedCalls: 1,
|
||||
uuid: "test",
|
||||
mapDiskErr: someErr,
|
||||
},
|
||||
configGenerator: &stubConfigurationGenerator{},
|
||||
wantErr: true,
|
||||
@ -267,7 +270,7 @@ func TestReadMeasurementSalt(t *testing.T) {
|
||||
require.NoError(handler.WriteJSON("test-state.json", state, file.OptMkdirAll))
|
||||
}
|
||||
|
||||
setupManager := New(logger.NewTest(t), "test", "disk-path", fs, nil, nil, nil, nil)
|
||||
setupManager := New(logger.NewTest(t), "test", "disk-path", fs, nil, nil, nil)
|
||||
|
||||
measurementSalt, err := setupManager.readMeasurementSalt("test-state.json")
|
||||
if tc.wantErr {
|
||||
@ -280,15 +283,115 @@ func TestReadMeasurementSalt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryDoer(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
rejoinClientKey := []byte("rejoinClientKey")
|
||||
rejoinClientSecret := []byte("rejoinClientSecret")
|
||||
recoveryServerKey := []byte("recoveryServerKey")
|
||||
recoveryServerSecret := []byte("recoveryServerSecret")
|
||||
|
||||
recoveryServerErr := errors.New("error")
|
||||
recoveryServer := &stubRecoveryServer{
|
||||
key: recoveryServerKey,
|
||||
secret: recoveryServerSecret,
|
||||
sendKeys: make(chan struct{}, 1),
|
||||
err: recoveryServerErr,
|
||||
}
|
||||
rejoinClient := &stubRejoinClient{
|
||||
key: rejoinClientKey,
|
||||
secret: rejoinClientSecret,
|
||||
sendKeys: make(chan struct{}, 1),
|
||||
}
|
||||
recoverer := NewNodeRecoverer(recoveryServer, rejoinClient)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var key, secret []byte
|
||||
var err error
|
||||
|
||||
// error from recovery server
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key, secret, err = recoverer.Do("", "")
|
||||
}()
|
||||
recoveryServer.sendKeys <- struct{}{}
|
||||
wg.Wait()
|
||||
assert.ErrorIs(err, recoveryServerErr)
|
||||
|
||||
recoveryServer.err = nil
|
||||
recoveryServer.sendKeys = make(chan struct{}, 1)
|
||||
|
||||
// recovery server returns its key and secret
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key, secret, err = recoverer.Do("", "")
|
||||
}()
|
||||
recoveryServer.sendKeys <- struct{}{}
|
||||
wg.Wait()
|
||||
assert.NoError(err)
|
||||
assert.Equal(recoveryServerKey, key)
|
||||
assert.Equal(recoveryServerSecret, secret)
|
||||
|
||||
recoveryServer.sendKeys = make(chan struct{}, 1)
|
||||
|
||||
// rejoin client returns its key and secret
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key, secret, err = recoverer.Do("", "")
|
||||
}()
|
||||
rejoinClient.sendKeys <- struct{}{}
|
||||
wg.Wait()
|
||||
assert.NoError(err)
|
||||
assert.Equal(rejoinClientKey, key)
|
||||
assert.Equal(rejoinClientSecret, secret)
|
||||
}
|
||||
|
||||
type stubRecoveryServer struct {
|
||||
key []byte
|
||||
secret []byte
|
||||
sendKeys chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubRecoveryServer) Serve(ctx context.Context, _ net.Listener, _ string) ([]byte, []byte, error) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
case <-s.sendKeys:
|
||||
return s.key, s.secret, s.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stubRejoinClient struct {
|
||||
key []byte
|
||||
secret []byte
|
||||
sendKeys chan struct{}
|
||||
}
|
||||
|
||||
func (s *stubRejoinClient) Start(ctx context.Context, _ string) ([]byte, []byte) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, nil
|
||||
case <-s.sendKeys:
|
||||
return s.key, s.secret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stubMapper struct {
|
||||
formatDiskCalled bool
|
||||
formatDiskErr error
|
||||
mapDiskRepeatedCalls int
|
||||
mapDiskCalled bool
|
||||
mapDiskErr error
|
||||
unmapDiskCalled bool
|
||||
unmapDiskErr error
|
||||
uuid string
|
||||
formatDiskCalled bool
|
||||
formatDiskErr error
|
||||
mapDiskCalled bool
|
||||
mapDiskErr error
|
||||
unmapDiskCalled bool
|
||||
unmapDiskErr error
|
||||
uuid string
|
||||
}
|
||||
|
||||
func (s *stubMapper) DiskUUID() string {
|
||||
@ -301,10 +404,6 @@ func (s *stubMapper) FormatDisk(string) error {
|
||||
}
|
||||
|
||||
func (s *stubMapper) MapDisk(string, string) error {
|
||||
if s.mapDiskRepeatedCalls == 0 {
|
||||
s.mapDiskErr = nil
|
||||
}
|
||||
s.mapDiskRepeatedCalls--
|
||||
s.mapDiskCalled = true
|
||||
return s.mapDiskErr
|
||||
}
|
||||
@ -336,25 +435,14 @@ func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
|
||||
return s.mkdirAllErr
|
||||
}
|
||||
|
||||
type stubKeyWaiter struct {
|
||||
receivedUUID string
|
||||
decryptionKey []byte
|
||||
measurementSecret []byte
|
||||
waitErr error
|
||||
waitCalled bool
|
||||
type stubRecoveryDoer struct {
|
||||
passphrase []byte
|
||||
secret []byte
|
||||
recoveryErr error
|
||||
}
|
||||
|
||||
func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, []byte, error) {
|
||||
if s.waitCalled {
|
||||
return nil, nil, errors.New("wait called before key was reset")
|
||||
}
|
||||
s.waitCalled = true
|
||||
s.receivedUUID = uuid
|
||||
return s.decryptionKey, s.measurementSecret, s.waitErr
|
||||
}
|
||||
|
||||
func (s *stubKeyWaiter) ResetKey() {
|
||||
s.waitCalled = false
|
||||
func (s *stubRecoveryDoer) Do(uuid, endpoint string) (passphrase, measurementSecret []byte, err error) {
|
||||
return s.passphrase, s.secret, s.recoveryErr
|
||||
}
|
||||
|
||||
type stubConfigurationGenerator struct {
|
@ -13,8 +13,8 @@ import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/state/internal/mapper"
|
||||
"github.com/martinjungblut/go-cryptsetup"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
@ -9,29 +9,18 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/disk-mapper/internal/mapper"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/edgelesssys/constellation/internal/role"
|
||||
"github.com/edgelesssys/constellation/state/internal/keyservice"
|
||||
"github.com/edgelesssys/constellation/state/internal/mapper"
|
||||
"github.com/edgelesssys/constellation/state/keyproto"
|
||||
"github.com/martinjungblut/go-cryptsetup"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -93,53 +82,7 @@ func TestMapper(t *testing.T) {
|
||||
assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase")
|
||||
}
|
||||
|
||||
func TestKeyAPI(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||
testSecret := []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")
|
||||
|
||||
// get a free port on localhost to run the test on
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(err)
|
||||
apiAddr := listener.Addr().String()
|
||||
listener.Close()
|
||||
|
||||
api := keyservice.New(
|
||||
logger.NewTest(t),
|
||||
atls.NewFakeIssuer(oid.Dummy{}),
|
||||
&fakeMetadataAPI{},
|
||||
20*time.Second,
|
||||
time.Second,
|
||||
)
|
||||
|
||||
// send a key to the server
|
||||
go func() {
|
||||
// wait 2 seconds before sending the key
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
creds := atlscredentials.New(nil, nil)
|
||||
conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(creds))
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
|
||||
client := keyproto.NewAPIClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
_, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{
|
||||
StateDiskKey: testKey,
|
||||
MeasurementSecret: testSecret,
|
||||
})
|
||||
require.NoError(err)
|
||||
}()
|
||||
|
||||
key, measurementSecret, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr)
|
||||
assert.NoError(err)
|
||||
assert.Equal(testKey, key)
|
||||
assert.Equal(testSecret, measurementSecret)
|
||||
}
|
||||
|
||||
/*
|
||||
type fakeMetadataAPI struct{}
|
||||
|
||||
func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
|
||||
@ -152,3 +95,4 @@ func (f *fakeMetadataAPI) List(ctx context.Context) ([]metadata.InstanceMetadata
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
*/
|
258
disk-mapper/recoverproto/recover.pb.go
Normal file
258
disk-mapper/recoverproto/recover.pb.go
Normal file
@ -0,0 +1,258 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc v3.20.1
|
||||
// source: recover.proto
|
||||
|
||||
package recoverproto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type RecoverMessage struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Types that are assignable to Request:
|
||||
// *RecoverMessage_StateDiskKey
|
||||
// *RecoverMessage_MeasurementSecret
|
||||
Request isRecoverMessage_Request `protobuf_oneof:"request"`
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) Reset() {
|
||||
*x = RecoverMessage{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_recover_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RecoverMessage) ProtoMessage() {}
|
||||
|
||||
func (x *RecoverMessage) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_recover_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RecoverMessage.ProtoReflect.Descriptor instead.
|
||||
func (*RecoverMessage) Descriptor() ([]byte, []int) {
|
||||
return file_recover_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (m *RecoverMessage) GetRequest() isRecoverMessage_Request {
|
||||
if m != nil {
|
||||
return m.Request
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) GetStateDiskKey() []byte {
|
||||
if x, ok := x.GetRequest().(*RecoverMessage_StateDiskKey); ok {
|
||||
return x.StateDiskKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) GetMeasurementSecret() []byte {
|
||||
if x, ok := x.GetRequest().(*RecoverMessage_MeasurementSecret); ok {
|
||||
return x.MeasurementSecret
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isRecoverMessage_Request interface {
|
||||
isRecoverMessage_Request()
|
||||
}
|
||||
|
||||
type RecoverMessage_StateDiskKey struct {
|
||||
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3,oneof"`
|
||||
}
|
||||
|
||||
type RecoverMessage_MeasurementSecret struct {
|
||||
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*RecoverMessage_StateDiskKey) isRecoverMessage_Request() {}
|
||||
|
||||
func (*RecoverMessage_MeasurementSecret) isRecoverMessage_Request() {}
|
||||
|
||||
type RecoverResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RecoverResponse) Reset() {
|
||||
*x = RecoverResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_recover_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *RecoverResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RecoverResponse) ProtoMessage() {}
|
||||
|
||||
func (x *RecoverResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_recover_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RecoverResponse.ProtoReflect.Descriptor instead.
|
||||
func (*RecoverResponse) Descriptor() ([]byte, []int) {
|
||||
return file_recover_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *RecoverResponse) GetDiskUuid() string {
|
||||
if x != nil {
|
||||
return x.DiskUuid
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_recover_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_recover_proto_rawDesc = []byte{
|
||||
0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
|
||||
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a,
|
||||
0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
|
||||
0x26, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65,
|
||||
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65,
|
||||
0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75,
|
||||
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x22, 0x2e, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65,
|
||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x75,
|
||||
0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x6b, 0x55,
|
||||
0x75, 0x69, 0x64, 0x32, 0x53, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4c, 0x0a, 0x07, 0x52, 0x65,
|
||||
0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3f, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68,
|
||||
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73,
|
||||
0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x72, 0x2f, 0x72, 0x65, 0x63,
|
||||
0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_recover_proto_rawDescOnce sync.Once
|
||||
file_recover_proto_rawDescData = file_recover_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_recover_proto_rawDescGZIP() []byte {
|
||||
file_recover_proto_rawDescOnce.Do(func() {
|
||||
file_recover_proto_rawDescData = protoimpl.X.CompressGZIP(file_recover_proto_rawDescData)
|
||||
})
|
||||
return file_recover_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_recover_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_recover_proto_goTypes = []interface{}{
|
||||
(*RecoverMessage)(nil), // 0: recoverproto.RecoverMessage
|
||||
(*RecoverResponse)(nil), // 1: recoverproto.RecoverResponse
|
||||
}
|
||||
var file_recover_proto_depIdxs = []int32{
|
||||
0, // 0: recoverproto.API.Recover:input_type -> recoverproto.RecoverMessage
|
||||
1, // 1: recoverproto.API.Recover:output_type -> recoverproto.RecoverResponse
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_recover_proto_init() }
|
||||
func file_recover_proto_init() {
|
||||
if File_recover_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_recover_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*RecoverMessage); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_recover_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*RecoverResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_recover_proto_msgTypes[0].OneofWrappers = []interface{}{
|
||||
(*RecoverMessage_StateDiskKey)(nil),
|
||||
(*RecoverMessage_MeasurementSecret)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_recover_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_recover_proto_goTypes,
|
||||
DependencyIndexes: file_recover_proto_depIdxs,
|
||||
MessageInfos: file_recover_proto_msgTypes,
|
||||
}.Build()
|
||||
File_recover_proto = out.File
|
||||
file_recover_proto_rawDesc = nil
|
||||
file_recover_proto_goTypes = nil
|
||||
file_recover_proto_depIdxs = nil
|
||||
}
|
20
disk-mapper/recoverproto/recover.proto
Normal file
20
disk-mapper/recoverproto/recover.proto
Normal file
@ -0,0 +1,20 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package recoverproto;
|
||||
|
||||
option go_package = "github.com/edgelesssys/constellation/disk-mapper/recoverproto";
|
||||
|
||||
service API {
|
||||
rpc Recover(stream RecoverMessage) returns (stream RecoverResponse) {}
|
||||
}
|
||||
|
||||
message RecoverMessage {
|
||||
oneof request {
|
||||
bytes state_disk_key = 1;
|
||||
bytes measurement_secret = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message RecoverResponse {
|
||||
string disk_uuid = 1;
|
||||
}
|
@ -2,9 +2,9 @@
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v3.20.1
|
||||
// source: keyservice.proto
|
||||
// source: recover.proto
|
||||
|
||||
package keyproto
|
||||
package recoverproto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
@ -22,7 +22,7 @@ const _ = grpc.SupportPackageIsVersion7
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type APIClient interface {
|
||||
PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error)
|
||||
Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error)
|
||||
}
|
||||
|
||||
type aPIClient struct {
|
||||
@ -33,20 +33,42 @@ func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
|
||||
return &aPIClient{cc}
|
||||
}
|
||||
|
||||
func (c *aPIClient) PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) {
|
||||
out := new(PushStateDiskKeyResponse)
|
||||
err := c.cc.Invoke(ctx, "/keyproto.API/PushStateDiskKey", in, out, opts...)
|
||||
func (c *aPIClient) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &API_ServiceDesc.Streams[0], "/recoverproto.API/Recover", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
x := &aPIRecoverClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type API_RecoverClient interface {
|
||||
Send(*RecoverMessage) error
|
||||
Recv() (*RecoverResponse, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type aPIRecoverClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *aPIRecoverClient) Send(m *RecoverMessage) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *aPIRecoverClient) Recv() (*RecoverResponse, error) {
|
||||
m := new(RecoverResponse)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// APIServer is the server API for API service.
|
||||
// All implementations must embed UnimplementedAPIServer
|
||||
// for forward compatibility
|
||||
type APIServer interface {
|
||||
PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error)
|
||||
Recover(API_RecoverServer) error
|
||||
mustEmbedUnimplementedAPIServer()
|
||||
}
|
||||
|
||||
@ -54,8 +76,8 @@ type APIServer interface {
|
||||
type UnimplementedAPIServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedAPIServer) PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method PushStateDiskKey not implemented")
|
||||
func (UnimplementedAPIServer) Recover(API_RecoverServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Recover not implemented")
|
||||
}
|
||||
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
|
||||
|
||||
@ -70,36 +92,46 @@ func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
|
||||
s.RegisterService(&API_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _API_PushStateDiskKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(PushStateDiskKeyRequest)
|
||||
if err := dec(in); err != nil {
|
||||
func _API_Recover_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(APIServer).Recover(&aPIRecoverServer{stream})
|
||||
}
|
||||
|
||||
type API_RecoverServer interface {
|
||||
Send(*RecoverResponse) error
|
||||
Recv() (*RecoverMessage, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type aPIRecoverServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *aPIRecoverServer) Send(m *RecoverResponse) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
|
||||
m := new(RecoverMessage)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(APIServer).PushStateDiskKey(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/keyproto.API/PushStateDiskKey",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(APIServer).PushStateDiskKey(ctx, req.(*PushStateDiskKeyRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// API_ServiceDesc is the grpc.ServiceDesc for API service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var API_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "keyproto.API",
|
||||
ServiceName: "recoverproto.API",
|
||||
HandlerType: (*APIServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
MethodName: "PushStateDiskKey",
|
||||
Handler: _API_PushStateDiskKey_Handler,
|
||||
StreamName: "Recover",
|
||||
Handler: _API_Recover_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "keyservice.proto",
|
||||
Metadata: "recover.proto",
|
||||
}
|
@ -24,9 +24,9 @@ RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v${GEN_GO_VER} && \
|
||||
|
||||
# Generate code for every existing proto file
|
||||
|
||||
## disk-mapper keyservice api
|
||||
## disk-mapper recover api
|
||||
WORKDIR /disk-mapper
|
||||
COPY state/keyproto/*.proto /disk-mapper
|
||||
COPY disk-mapper/recoverproto/*.proto /disk-mapper
|
||||
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
|
||||
|
||||
## debugd service
|
||||
@ -54,7 +54,7 @@ COPY bootstrapper/initproto/*.proto /init
|
||||
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
|
||||
|
||||
FROM scratch as export
|
||||
COPY --from=build /disk-mapper/*.go state/keyproto/
|
||||
COPY --from=build /disk-mapper/*.go disk-mapper/recoverproto/
|
||||
COPY --from=build /service/*.go debugd/service/
|
||||
COPY --from=build /kms/*.go kms/kmsproto/
|
||||
COPY --from=build /joinservice/*.go joinservice/joinproto/
|
||||
|
@ -1,204 +0,0 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package keyservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/edgelesssys/constellation/joinservice/joinproto"
|
||||
"github.com/edgelesssys/constellation/state/keyproto"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/status"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
// KeyAPI is the interface called by control-plane or an admin during restart of a node.
|
||||
type KeyAPI struct {
|
||||
listenAddr string
|
||||
log *logger.Logger
|
||||
mux sync.Mutex
|
||||
metadata metadata.InstanceLister
|
||||
issuer QuoteIssuer
|
||||
key []byte
|
||||
measurementSecret []byte
|
||||
keyReceived chan struct{}
|
||||
|
||||
clock clock.WithTicker
|
||||
timeout time.Duration
|
||||
interval time.Duration
|
||||
|
||||
keyproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
// New initializes a KeyAPI with the given parameters.
|
||||
func New(log *logger.Logger, issuer QuoteIssuer, metadata metadata.InstanceLister, timeout time.Duration, interval time.Duration) *KeyAPI {
|
||||
return &KeyAPI{
|
||||
log: log,
|
||||
metadata: metadata,
|
||||
issuer: issuer,
|
||||
keyReceived: make(chan struct{}, 1),
|
||||
clock: clock.RealClock{},
|
||||
timeout: timeout,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// PushStateDiskKey is the rpc to push state disk decryption keys to a restarting node.
|
||||
func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
if len(a.key) != 0 {
|
||||
return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase")
|
||||
}
|
||||
if len(in.StateDiskKey) != crypto.StateDiskKeyLength {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", crypto.StateDiskKeyLength, len(in.StateDiskKey))
|
||||
}
|
||||
if len(in.MeasurementSecret) != crypto.RNGLengthDefault {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "received invalid measurement secret: expected length: %d, but got: %d", crypto.RNGLengthDefault, len(in.MeasurementSecret))
|
||||
}
|
||||
|
||||
a.key = in.StateDiskKey
|
||||
a.measurementSecret = in.MeasurementSecret
|
||||
a.keyReceived <- struct{}{}
|
||||
return &keyproto.PushStateDiskKeyResponse{}, nil
|
||||
}
|
||||
|
||||
// WaitForDecryptionKey notifies control-plane nodes to send a decryption key and waits until a key is received.
|
||||
func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) (diskKey, measurementSecret []byte, err error) {
|
||||
if uuid == "" {
|
||||
return nil, nil, errors.New("received no disk UUID")
|
||||
}
|
||||
a.listenAddr = listenAddr
|
||||
creds := atlscredentials.New(a.issuer, nil)
|
||||
server := grpc.NewServer(grpc.Creds(creds))
|
||||
keyproto.RegisterAPIServer(server, a)
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
a.log.Infof("Waiting for decryption key. Listening on: %s", listener.Addr().String())
|
||||
go server.Serve(listener)
|
||||
defer server.GracefulStop()
|
||||
|
||||
a.requestKeyLoop(uuid)
|
||||
return a.key, a.measurementSecret, nil
|
||||
}
|
||||
|
||||
// ResetKey resets a previously set key.
|
||||
func (a *KeyAPI) ResetKey() {
|
||||
a.key = nil
|
||||
}
|
||||
|
||||
// requestKeyLoop continuously requests decryption keys from all available control-plane nodes, until the KeyAPI receives a key.
|
||||
func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) {
|
||||
// we do not perform attestation, since the restarting node does not need to care about notifying the correct node
|
||||
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
|
||||
creds := atlscredentials.New(a.issuer, nil)
|
||||
|
||||
ticker := a.clock.NewTicker(a.interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
endpoints, err := a.getJoinServiceEndpoints()
|
||||
if err != nil {
|
||||
a.log.With(zap.Error(err)).Errorf("Failed to get JoinService endpoints")
|
||||
} else {
|
||||
a.log.Infof("Received list with JoinService endpoints: %v", endpoints)
|
||||
for _, endpoint := range endpoints {
|
||||
a.requestKey(endpoint, uuid, creds, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-a.keyReceived:
|
||||
// return if a key was received
|
||||
// a key can be send by
|
||||
// - a control-plane node, after the request rpc was received
|
||||
// - by a Constellation admin, at any time this loop is running on a node during boot
|
||||
return
|
||||
case <-ticker.C():
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *KeyAPI) getJoinServiceEndpoints() ([]string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
|
||||
defer cancel()
|
||||
return metadata.JoinServiceEndpoints(ctx, a.metadata)
|
||||
}
|
||||
|
||||
func (a *KeyAPI) requestKey(endpoint, uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) {
|
||||
opts = append(opts, grpc.WithTransportCredentials(credentials))
|
||||
|
||||
a.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
|
||||
rejoinTicket, err := a.requestRejoinTicket(endpoint, uuid, opts...)
|
||||
if err != nil {
|
||||
a.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Errorf("Failed to request rejoin ticket")
|
||||
return
|
||||
}
|
||||
|
||||
a.log.With(zap.String("endpoint", endpoint)).Infof("Pushing key to own server")
|
||||
if err := a.pushKeyToOwnServer(rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, opts...); err != nil {
|
||||
a.log.With(zap.Error(err), zap.String("endpoint", a.listenAddr)).Errorf("Failed to push key to own server")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (a *KeyAPI) requestRejoinTicket(endpoint, uuid string, opts ...grpc.DialOption) (*joinproto.IssueRejoinTicketResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(ctx, endpoint, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing gRPC: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
client := joinproto.NewAPIClient(conn)
|
||||
req := &joinproto.IssueRejoinTicketRequest{DiskUuid: uuid}
|
||||
return client.IssueRejoinTicket(ctx, req)
|
||||
}
|
||||
|
||||
func (a *KeyAPI) pushKeyToOwnServer(stateDiskKey, measurementSecret []byte, opts ...grpc.DialOption) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(ctx, a.listenAddr, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing gRPC: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
client := keyproto.NewAPIClient(conn)
|
||||
req := &keyproto.PushStateDiskKeyRequest{StateDiskKey: stateDiskKey, MeasurementSecret: measurementSecret}
|
||||
_, err = client.PushStateDiskKey(ctx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
// QuoteValidator validates quotes.
|
||||
type QuoteValidator interface {
|
||||
oid.Getter
|
||||
// Validate validates a quote and returns the user data on success.
|
||||
Validate(attDoc []byte, nonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// QuoteIssuer issues quotes.
|
||||
type QuoteIssuer interface {
|
||||
oid.Getter
|
||||
// Issue issues a quote for remote attestation for a given message
|
||||
Issue(userData []byte, nonce []byte) (quote []byte, err error)
|
||||
}
|
@ -1,264 +0,0 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package keyservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/edgelesssys/constellation/internal/role"
|
||||
"github.com/edgelesssys/constellation/joinservice/joinproto"
|
||||
"github.com/edgelesssys/constellation/state/keyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/goleak"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
testclock "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestRequestKeyLoop(t *testing.T) {
|
||||
clockstep := struct{}{}
|
||||
someErr := errors.New("failed")
|
||||
defaultInstance := metadata.InstanceMetadata{
|
||||
Name: "test-instance",
|
||||
ProviderID: "/test/provider",
|
||||
Role: role.ControlPlane,
|
||||
VPCIP: "192.0.2.1",
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
answers []any
|
||||
}{
|
||||
"success": {
|
||||
answers: []any{
|
||||
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
|
||||
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
|
||||
pushStateDiskKeyAnswer{},
|
||||
},
|
||||
},
|
||||
"recover metadata list error": {
|
||||
answers: []any{
|
||||
listAnswer{err: someErr},
|
||||
clockstep,
|
||||
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
|
||||
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
|
||||
pushStateDiskKeyAnswer{},
|
||||
},
|
||||
},
|
||||
"recover issue rejoin ticket error": {
|
||||
answers: []any{
|
||||
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
|
||||
issueRejoinTicketAnswer{err: someErr},
|
||||
clockstep,
|
||||
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
|
||||
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
|
||||
pushStateDiskKeyAnswer{},
|
||||
},
|
||||
},
|
||||
"recover push key error": {
|
||||
answers: []any{
|
||||
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
|
||||
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
|
||||
pushStateDiskKeyAnswer{err: someErr},
|
||||
clockstep,
|
||||
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
|
||||
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
|
||||
pushStateDiskKeyAnswer{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
metadataServer := newStubMetadataServer()
|
||||
joinServer := newStubJoinAPIServer()
|
||||
keyServer := newStubKeyAPIServer()
|
||||
|
||||
listener := bufconn.Listen(1024)
|
||||
defer listener.Close()
|
||||
creds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil)
|
||||
grpcServer := grpc.NewServer(grpc.Creds(creds))
|
||||
joinproto.RegisterAPIServer(grpcServer, joinServer)
|
||||
keyproto.RegisterAPIServer(grpcServer, keyServer)
|
||||
go grpcServer.Serve(listener)
|
||||
defer grpcServer.GracefulStop()
|
||||
|
||||
clock := testclock.NewFakeClock(time.Now())
|
||||
keyReceived := make(chan struct{}, 1)
|
||||
keyWaiter := &KeyAPI{
|
||||
listenAddr: "192.0.2.1:30090",
|
||||
log: logger.NewTest(t),
|
||||
metadata: metadataServer,
|
||||
keyReceived: keyReceived,
|
||||
clock: clock,
|
||||
timeout: 1 * time.Second,
|
||||
interval: 1 * time.Second,
|
||||
}
|
||||
grpcOpts := []grpc.DialOption{
|
||||
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
|
||||
return listener.DialContext(ctx)
|
||||
}),
|
||||
}
|
||||
|
||||
// Start the request loop under tests.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
keyWaiter.requestKeyLoop("1234", grpcOpts...)
|
||||
}()
|
||||
|
||||
// Play test case answers.
|
||||
for _, answ := range tc.answers {
|
||||
switch answ := answ.(type) {
|
||||
case listAnswer:
|
||||
metadataServer.listAnswerC <- answ
|
||||
case issueRejoinTicketAnswer:
|
||||
joinServer.issueRejoinTicketAnswerC <- answ
|
||||
case pushStateDiskKeyAnswer:
|
||||
keyServer.pushStateDiskKeyAnswerC <- answ
|
||||
default:
|
||||
clock.Step(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop the request loop.
|
||||
keyReceived <- struct{}{}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushStateDiskKey(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
testAPI *KeyAPI
|
||||
request *keyproto.PushStateDiskKeyRequest
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
|
||||
},
|
||||
"key already set": {
|
||||
testAPI: &KeyAPI{
|
||||
keyReceived: make(chan struct{}, 1),
|
||||
key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"), MeasurementSecret: []byte("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")},
|
||||
wantErr: true,
|
||||
},
|
||||
"incorrect size of pushed key": {
|
||||
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
|
||||
wantErr: true,
|
||||
},
|
||||
"incorrect size of measurement secret": {
|
||||
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), MeasurementSecret: []byte("BBBBBBBBBBBBBBBB")},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tc.testAPI.log = logger.NewTest(t)
|
||||
_, err := tc.testAPI.PushStateDiskKey(context.Background(), tc.request)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.request.StateDiskKey, tc.testAPI.key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetKey(t *testing.T) {
|
||||
api := New(logger.NewTest(t), nil, nil, time.Second, time.Millisecond)
|
||||
|
||||
api.key = []byte{0x1, 0x2, 0x3}
|
||||
api.ResetKey()
|
||||
assert.Nil(t, api.key)
|
||||
}
|
||||
|
||||
type stubMetadataServer struct {
|
||||
listAnswerC chan listAnswer
|
||||
}
|
||||
|
||||
func newStubMetadataServer() *stubMetadataServer {
|
||||
return &stubMetadataServer{
|
||||
listAnswerC: make(chan listAnswer),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubMetadataServer) List(context.Context) ([]metadata.InstanceMetadata, error) {
|
||||
answer := <-s.listAnswerC
|
||||
return answer.listResponse, answer.err
|
||||
}
|
||||
|
||||
type listAnswer struct {
|
||||
listResponse []metadata.InstanceMetadata
|
||||
err error
|
||||
}
|
||||
|
||||
type stubJoinAPIServer struct {
|
||||
issueRejoinTicketAnswerC chan issueRejoinTicketAnswer
|
||||
joinproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func newStubJoinAPIServer() *stubJoinAPIServer {
|
||||
return &stubJoinAPIServer{
|
||||
issueRejoinTicketAnswerC: make(chan issueRejoinTicketAnswer),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubJoinAPIServer) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest) (*joinproto.IssueRejoinTicketResponse, error) {
|
||||
answer := <-s.issueRejoinTicketAnswerC
|
||||
resp := &joinproto.IssueRejoinTicketResponse{
|
||||
StateDiskKey: answer.stateDiskKey,
|
||||
MeasurementSecret: answer.measurementSecret,
|
||||
}
|
||||
return resp, answer.err
|
||||
}
|
||||
|
||||
type issueRejoinTicketAnswer struct {
|
||||
stateDiskKey []byte
|
||||
measurementSecret []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type stubKeyAPIServer struct {
|
||||
pushStateDiskKeyAnswerC chan pushStateDiskKeyAnswer
|
||||
keyproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func newStubKeyAPIServer() *stubKeyAPIServer {
|
||||
return &stubKeyAPIServer{
|
||||
pushStateDiskKeyAnswerC: make(chan pushStateDiskKeyAnswer),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubKeyAPIServer) PushStateDiskKey(context.Context, *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
|
||||
answer := <-s.pushStateDiskKeyAnswerC
|
||||
return &keyproto.PushStateDiskKeyResponse{}, answer.err
|
||||
}
|
||||
|
||||
type pushStateDiskKeyAnswer struct {
|
||||
err error
|
||||
}
|
@ -1,219 +0,0 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc v3.20.1
|
||||
// source: keyservice.proto
|
||||
|
||||
package keyproto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type PushStateDiskKeyRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"`
|
||||
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3" json:"measurement_secret,omitempty"`
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) Reset() {
|
||||
*x = PushStateDiskKeyRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_keyservice_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*PushStateDiskKeyRequest) ProtoMessage() {}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_keyservice_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use PushStateDiskKeyRequest.ProtoReflect.Descriptor instead.
|
||||
func (*PushStateDiskKeyRequest) Descriptor() ([]byte, []int) {
|
||||
return file_keyservice_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte {
|
||||
if x != nil {
|
||||
return x.StateDiskKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) GetMeasurementSecret() []byte {
|
||||
if x != nil {
|
||||
return x.MeasurementSecret
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type PushStateDiskKeyResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyResponse) Reset() {
|
||||
*x = PushStateDiskKeyResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_keyservice_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*PushStateDiskKeyResponse) ProtoMessage() {}
|
||||
|
||||
func (x *PushStateDiskKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_keyservice_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use PushStateDiskKeyResponse.ProtoReflect.Descriptor instead.
|
||||
func (*PushStateDiskKeyResponse) Descriptor() ([]byte, []int) {
|
||||
return file_keyservice_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
var File_keyservice_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_keyservice_proto_rawDesc = []byte{
|
||||
0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x6e, 0x0a, 0x17,
|
||||
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
|
||||
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65,
|
||||
0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
|
||||
0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2d, 0x0a,
|
||||
0x12, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63,
|
||||
0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75,
|
||||
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x18,
|
||||
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12,
|
||||
0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b,
|
||||
0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x50,
|
||||
0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b,
|
||||
0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69,
|
||||
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73,
|
||||
0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76,
|
||||
0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_keyservice_proto_rawDescOnce sync.Once
|
||||
file_keyservice_proto_rawDescData = file_keyservice_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_keyservice_proto_rawDescGZIP() []byte {
|
||||
file_keyservice_proto_rawDescOnce.Do(func() {
|
||||
file_keyservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_keyservice_proto_rawDescData)
|
||||
})
|
||||
return file_keyservice_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_keyservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_keyservice_proto_goTypes = []interface{}{
|
||||
(*PushStateDiskKeyRequest)(nil), // 0: keyproto.PushStateDiskKeyRequest
|
||||
(*PushStateDiskKeyResponse)(nil), // 1: keyproto.PushStateDiskKeyResponse
|
||||
}
|
||||
var file_keyservice_proto_depIdxs = []int32{
|
||||
0, // 0: keyproto.API.PushStateDiskKey:input_type -> keyproto.PushStateDiskKeyRequest
|
||||
1, // 1: keyproto.API.PushStateDiskKey:output_type -> keyproto.PushStateDiskKeyResponse
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_keyservice_proto_init() }
|
||||
func file_keyservice_proto_init() {
|
||||
if File_keyservice_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_keyservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*PushStateDiskKeyRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_keyservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*PushStateDiskKeyResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_keyservice_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_keyservice_proto_goTypes,
|
||||
DependencyIndexes: file_keyservice_proto_depIdxs,
|
||||
MessageInfos: file_keyservice_proto_msgTypes,
|
||||
}.Build()
|
||||
File_keyservice_proto = out.File
|
||||
file_keyservice_proto_rawDesc = nil
|
||||
file_keyservice_proto_goTypes = nil
|
||||
file_keyservice_proto_depIdxs = nil
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package keyproto;
|
||||
|
||||
option go_package = "github.com/edgelesssys/constellation/state/keyservice/keyproto";
|
||||
|
||||
service API {
|
||||
rpc PushStateDiskKey(PushStateDiskKeyRequest) returns (PushStateDiskKeyResponse);
|
||||
}
|
||||
|
||||
message PushStateDiskKeyRequest {
|
||||
bytes state_disk_key = 1;
|
||||
bytes measurement_secret = 2;
|
||||
}
|
||||
|
||||
message PushStateDiskKeyResponse {
|
||||
}
|
Loading…
Reference in New Issue
Block a user