AB#1903 Add grpc interface to push decryption keys

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-04-11 14:25:19 +02:00 committed by Daniel Weiße
parent 96d7029367
commit 152e3985f7
12 changed files with 1110 additions and 201 deletions

View File

@ -36,7 +36,7 @@ RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_o
## disk-mapper keyservice api ## disk-mapper keyservice api
WORKDIR /disk-mapper WORKDIR /disk-mapper
COPY state/keyservice/proto/*.proto /disk-mapper COPY state/keyservice/keyproto/*.proto /disk-mapper
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
## debugd service ## debugd service
@ -48,5 +48,5 @@ RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_o
FROM scratch as export FROM scratch as export
COPY --from=build /pubapi/*.go coordinator/pubapi/pubproto/ COPY --from=build /pubapi/*.go coordinator/pubapi/pubproto/
COPY --from=build /vpnapi/*.go coordinator/vpnapi/vpnproto/ COPY --from=build /vpnapi/*.go coordinator/vpnapi/vpnproto/
COPY --from=build /disk-mapper/*.go state/keyservice/proto/ COPY --from=build /disk-mapper/*.go state/keyservice/keyproto/
COPY --from=build /service/*.go debugd/service/ COPY --from=build /service/*.go debugd/service/

View File

@ -1,68 +1,103 @@
package main package main
import ( import (
"crypto/rand" "context"
"flag" "flag"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time"
"github.com/edgelesssys/constellation/coordinator/config" "github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
azurecloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/azure"
gcpcloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/gcp"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/internal/utils" "github.com/edgelesssys/constellation/internal/utils"
"github.com/edgelesssys/constellation/state/keyservice" "github.com/edgelesssys/constellation/state/keyservice"
"github.com/edgelesssys/constellation/state/mapper" "github.com/edgelesssys/constellation/state/mapper"
"github.com/edgelesssys/constellation/state/setup"
"github.com/spf13/afero"
) )
const ( const (
keyPath = "/run/cryptsetup-keys.d" gcpStateDiskPath = "/dev/disk/by-id/google-state-disk"
keyFile = "state.key" azureStateDiskPath = "/dev/disk/azure/scsi1/lun0"
fallBackPath = "/dev/disk/by-id/state-disk"
) )
var csp = flag.String("csp", "", "Cloud Service Provider the image is running on") var csp = flag.String("csp", "", "Cloud Service Provider the image is running on")
func main() { func main() {
flag.Parse() flag.Parse()
diskPath, err := mapper.GetDiskPath(*csp)
if err != nil { // set up metadata API and quote issuer for aTLS connections
utils.KernelPanic(err) var err error
var diskPathErr error
var diskPath string
var issuer core.QuoteIssuer
var metadata core.ProviderMetadata
switch strings.ToLower(*csp) {
case "azure":
diskPath, diskPathErr = filepath.EvalSymlinks(azureStateDiskPath)
metadata, err = azurecloud.NewMetadata(context.Background())
if err != nil {
utils.KernelPanic(err)
}
issuer = azure.NewIssuer()
case "gcp":
diskPath, diskPathErr = filepath.EvalSymlinks(gcpStateDiskPath)
issuer = gcp.NewIssuer()
gcpClient, err := gcpcloud.NewClient(context.Background())
if err != nil {
utils.KernelPanic(err)
}
metadata = gcpcloud.New(gcpClient)
default:
diskPath, err = filepath.EvalSymlinks(fallBackPath)
if err != nil {
utils.KernelPanic(err)
}
issuer = core.NewMockIssuer()
fmt.Fprintf(os.Stderr, "warning: csp %q is not supported, unable to automatically request decryption keys on reboot\n", *csp)
metadata = &core.ProviderMetadataFake{}
}
if diskPathErr != nil {
fmt.Fprintf(os.Stderr, "warning: no attached disk detected, trying to use boot-disk state partition as fallback")
diskPath, err = filepath.EvalSymlinks(fallBackPath)
if err != nil {
utils.KernelPanic(err)
}
} }
// initialize device mapper
mapper, err := mapper.New(diskPath) mapper, err := mapper.New(diskPath)
if err != nil { if err != nil {
utils.KernelPanic(err) utils.KernelPanic(err)
} }
defer mapper.Close() defer mapper.Close()
setupManger := setup.New(
*csp,
afero.Afero{Fs: afero.NewOsFs()},
keyservice.New(issuer, metadata, 20*time.Second), // try to request a key every 20 seconds
mapper,
setup.DiskMounter{},
vtpm.OpenVTPM,
)
// prepare the state disk
if mapper.IsLUKSDevice() { if mapper.IsLUKSDevice() {
uuid := mapper.DiskUUID() err = setupManger.PrepareExistingDisk()
_, err = keyservice.WaitForDecryptionKey(*csp, uuid)
} else { } else {
err = formatDisk(mapper) err = setupManger.PrepareNewDisk()
} }
if err != nil { if err != nil {
utils.KernelPanic(err) utils.KernelPanic(err)
} }
} }
func formatDisk(mapper *mapper.Mapper) error {
// generate and save temporary passphrase
if err := os.MkdirAll(keyPath, os.ModePerm); err != nil {
utils.KernelPanic(err)
}
passphrase := make([]byte, config.RNGLengthDefault)
if _, err := rand.Read(passphrase); err != nil {
utils.KernelPanic(err)
}
if err := os.WriteFile(filepath.Join(keyPath, keyFile), passphrase, 0o400); err != nil {
utils.KernelPanic(err)
}
if err := mapper.FormatDisk(string(passphrase)); err != nil {
utils.KernelPanic(err)
}
if err := mapper.MapDisk("state", string(passphrase)); err != nil {
utils.KernelPanic(err)
}
return nil
}

View File

@ -0,0 +1,208 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// source: keyservice.proto
package keyproto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type PushStateDiskKeyRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"`
}
func (x *PushStateDiskKeyRequest) Reset() {
*x = PushStateDiskKeyRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_keyservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PushStateDiskKeyRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PushStateDiskKeyRequest) ProtoMessage() {}
func (x *PushStateDiskKeyRequest) ProtoReflect() protoreflect.Message {
mi := &file_keyservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PushStateDiskKeyRequest.ProtoReflect.Descriptor instead.
func (*PushStateDiskKeyRequest) Descriptor() ([]byte, []int) {
return file_keyservice_proto_rawDescGZIP(), []int{0}
}
func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte {
if x != nil {
return x.StateDiskKey
}
return nil
}
type PushStateDiskKeyResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *PushStateDiskKeyResponse) Reset() {
*x = PushStateDiskKeyResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_keyservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PushStateDiskKeyResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PushStateDiskKeyResponse) ProtoMessage() {}
func (x *PushStateDiskKeyResponse) ProtoReflect() protoreflect.Message {
mi := &file_keyservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PushStateDiskKeyResponse.ProtoReflect.Descriptor instead.
func (*PushStateDiskKeyResponse) Descriptor() ([]byte, []int) {
return file_keyservice_proto_rawDescGZIP(), []int{1}
}
var File_keyservice_proto protoreflect.FileDescriptor
var file_keyservice_proto_rawDesc = []byte{
0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x3f, 0x0a, 0x17,
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65,
0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x22, 0x1a, 0x0a,
0x18, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65,
0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49,
0x12, 0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73,
0x6b, 0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b,
0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65,
0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74,
0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_keyservice_proto_rawDescOnce sync.Once
file_keyservice_proto_rawDescData = file_keyservice_proto_rawDesc
)
func file_keyservice_proto_rawDescGZIP() []byte {
file_keyservice_proto_rawDescOnce.Do(func() {
file_keyservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_keyservice_proto_rawDescData)
})
return file_keyservice_proto_rawDescData
}
var file_keyservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_keyservice_proto_goTypes = []interface{}{
(*PushStateDiskKeyRequest)(nil), // 0: keyproto.PushStateDiskKeyRequest
(*PushStateDiskKeyResponse)(nil), // 1: keyproto.PushStateDiskKeyResponse
}
var file_keyservice_proto_depIdxs = []int32{
0, // 0: keyproto.API.PushStateDiskKey:input_type -> keyproto.PushStateDiskKeyRequest
1, // 1: keyproto.API.PushStateDiskKey:output_type -> keyproto.PushStateDiskKeyResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_keyservice_proto_init() }
func file_keyservice_proto_init() {
if File_keyservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_keyservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PushStateDiskKeyRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_keyservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PushStateDiskKeyResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_keyservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_keyservice_proto_goTypes,
DependencyIndexes: file_keyservice_proto_depIdxs,
MessageInfos: file_keyservice_proto_msgTypes,
}.Build()
File_keyservice_proto = out.File
file_keyservice_proto_rawDesc = nil
file_keyservice_proto_goTypes = nil
file_keyservice_proto_depIdxs = nil
}

View File

@ -0,0 +1,16 @@
syntax = "proto3";
package keyproto;
option go_package = "github.com/edgelesssys/constellation/state/keyservice/keyproto";
service API {
rpc PushStateDiskKey(PushStateDiskKeyRequest) returns (PushStateDiskKeyResponse);
}
message PushStateDiskKeyRequest {
bytes state_disk_key = 1;
}
message PushStateDiskKeyResponse {
}

View File

@ -0,0 +1,101 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
package keyproto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// APIClient is the client API for API service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type APIClient interface {
PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error)
}
type aPIClient struct {
cc grpc.ClientConnInterface
}
func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
return &aPIClient{cc}
}
func (c *aPIClient) PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) {
out := new(PushStateDiskKeyResponse)
err := c.cc.Invoke(ctx, "/keyproto.API/PushStateDiskKey", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// APIServer is the server API for API service.
// All implementations must embed UnimplementedAPIServer
// for forward compatibility
type APIServer interface {
PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error)
mustEmbedUnimplementedAPIServer()
}
// UnimplementedAPIServer must be embedded to have forward compatible implementations.
type UnimplementedAPIServer struct {
}
func (UnimplementedAPIServer) PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method PushStateDiskKey not implemented")
}
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
// UnsafeAPIServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to APIServer will
// result in compilation errors.
type UnsafeAPIServer interface {
mustEmbedUnimplementedAPIServer()
}
func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
s.RegisterService(&API_ServiceDesc, srv)
}
func _API_PushStateDiskKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PushStateDiskKeyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(APIServer).PushStateDiskKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/keyproto.API/PushStateDiskKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(APIServer).PushStateDiskKey(ctx, req.(*PushStateDiskKeyRequest))
}
return interceptor(ctx, in, info, handler)
}
// API_ServiceDesc is the grpc.ServiceDesc for API service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var API_ServiceDesc = grpc.ServiceDesc{
ServiceName: "keyproto.API",
HandlerType: (*APIServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "PushStateDiskKey",
Handler: _API_PushStateDiskKey_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "keyservice.proto",
}

View File

@ -2,41 +2,97 @@ package keyservice
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "log"
"os" "net"
"strings"
"sync" "sync"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
azurecloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/azure" "github.com/edgelesssys/constellation/coordinator/config"
gcpcloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/gcp"
"github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
) )
// keyAPI is the interface called by the Coordinator or an admin during restart of a node. // KeyAPI is the interface called by the Coordinator or an admin during restart of a node.
type keyAPI struct { type KeyAPI struct {
metadata core.ProviderMetadata
mux sync.Mutex mux sync.Mutex
metadata core.ProviderMetadata
issuer core.QuoteIssuer
key []byte key []byte
keyReceived chan bool keyReceived chan struct{}
timeout time.Duration timeout time.Duration
keyproto.UnimplementedAPIServer
} }
func (a *keyAPI) waitForDecryptionKey() { // New initializes a KeyAPI with the given parameters.
// go server.Start() func New(issuer core.QuoteIssuer, metadata core.ProviderMetadata, timeout time.Duration) *KeyAPI {
// block until a key is pushed return &KeyAPI{
if <-a.keyReceived { metadata: metadata,
return issuer: issuer,
keyReceived: make(chan struct{}, 1),
timeout: timeout,
} }
} }
func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption) error { // PushStateDiskKeyRequest is the rpc to push state disk decryption keys to a restarting node.
func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
a.mux.Lock()
defer a.mux.Unlock()
if len(a.key) != 0 {
return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase")
}
if len(in.StateDiskKey) != config.RNGLengthDefault {
return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", config.RNGLengthDefault, len(in.StateDiskKey))
}
a.key = in.StateDiskKey
a.keyReceived <- struct{}{}
return &keyproto.PushStateDiskKeyResponse{}, nil
}
// WaitForDecryptionKey notifies the Coordinator to send a decryption key and waits until a key is received.
func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) {
if uuid == "" {
return nil, errors.New("received no disk UUID")
}
tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer)
if err != nil {
return nil, err
}
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
keyproto.RegisterAPIServer(server, a)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return nil, err
}
defer listener.Close()
log.Printf("Waiting for decryption key. Listening on: %s", listener.Addr().String())
go server.Serve(listener)
defer server.GracefulStop()
if err := a.requestKeyLoop(uuid); err != nil {
return nil, err
}
return a.key, nil
}
// ResetKey resets a previously set key.
func (a *KeyAPI) ResetKey() {
a.key = nil
}
// requestKeyLoop continuously requests decryption keys from all available Coordinators, until the KeyAPI receives a key.
func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error {
// we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator // we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start // if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
tlsClientConfig, err := atls.CreateUnverifiedClientTLSConfig() tlsClientConfig, err := atls.CreateUnverifiedClientTLSConfig()
@ -44,96 +100,44 @@ func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption)
return err return err
} }
// set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker
firstReq := make(chan struct{}, 1)
firstReq <- struct{}{}
ticker := time.NewTicker(a.timeout)
defer ticker.Stop()
for { for {
select { select {
// return if a key was received by any means // return if a key was received
// a key can be send by // a key can be send by
// - a Coordinator, after the request rpc was received // - a Coordinator, after the request rpc was received
// - by a Constellation admin, at any time this loop is running on a node during boot // - by a Constellation admin, at any time this loop is running on a node during boot
case <-a.keyReceived: case <-a.keyReceived:
return nil return nil
default: case <-ticker.C:
// list available Coordinators a.requestKey(uuid, tlsClientConfig, opts...)
endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata) case <-firstReq:
// notify the all available Coordinators to send a key to the node a.requestKey(uuid, tlsClientConfig, opts...)
// any errors encountered here will be ignored, and the calls retried after a timeout
for _, endpoint := range endpoints {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig)))...)
if err == nil {
client := pubproto.NewAPIClient(conn)
_, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid})
conn.Close()
}
cancel()
}
time.Sleep(a.timeout)
} }
} }
} }
// WaitForDecryptionKey notifies the Coordinator to send a decryption key and waits until a key is received. func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...grpc.DialOption) {
func WaitForDecryptionKey(csp, uuid string) ([]byte, error) { // list available Coordinators
if uuid == "" { endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata)
return nil, errors.New("received no disk UUID")
}
keyWaiter := &keyAPI{ log.Printf("Sending a key request to available Coordinators: %v", endpoints)
keyReceived: make(chan bool, 1), // notify all available Coordinators to send a key to the node
timeout: 20 * time.Second, // try to request a key every 20 seconds // any errors encountered here will be ignored, and the calls retried after a timeout
} for _, endpoint := range endpoints {
go keyWaiter.waitForDecryptionKey() ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig)))...)
switch strings.ToLower(csp) { if err == nil {
case "azure": client := pubproto.NewAPIClient(conn)
metadata, err := azurecloud.NewMetadata(context.Background()) _, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid})
if err != nil { conn.Close()
return nil, err
} }
keyWaiter.metadata = metadata
case "gcp": cancel()
gcpClient, err := gcpcloud.NewClient(context.Background())
if err != nil {
return nil, err
}
keyWaiter.metadata = gcpcloud.New(gcpClient)
default:
fmt.Fprintf(os.Stderr, "warning: csp %q is not supported, unable to automatically request decryption keys\n", csp)
keyWaiter.metadata = stubMetadata{}
} }
if err := keyWaiter.requestKeyFromCoordinator(uuid); err != nil {
return nil, err
}
return keyWaiter.key, nil
}
type stubMetadata struct {
listResponse []core.Instance
}
func (s stubMetadata) List(ctx context.Context) ([]core.Instance, error) {
return s.listResponse, nil
}
func (s stubMetadata) Self(ctx context.Context) (core.Instance, error) {
return core.Instance{}, nil
}
func (s stubMetadata) GetInstance(ctx context.Context, providerID string) (core.Instance, error) {
return core.Instance{}, nil
}
func (s stubMetadata) SignalRole(ctx context.Context, role role.Role) error {
return nil
}
func (s stubMetadata) SetVPNIP(ctx context.Context, vpnIP string) error {
return nil
}
func (s stubMetadata) Supported() bool {
return true
} }

View File

@ -2,22 +2,16 @@ package keyservice
import ( import (
"context" "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors" "errors"
"net" "net"
"testing" "testing"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/oid"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/util" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -25,7 +19,7 @@ import (
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
) )
func TestRequestLoop(t *testing.T) { func TestRequestKeyLoop(t *testing.T) {
defaultInstance := core.Instance{ defaultInstance := core.Instance{
Name: "test-instance", Name: "test-instance",
ProviderID: "/test/provider", ProviderID: "/test/provider",
@ -77,11 +71,11 @@ func TestRequestLoop(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
keyReceived := make(chan bool, 1) keyReceived := make(chan struct{}, 1)
listener := bufconn.Listen(1) listener := bufconn.Listen(1)
defer listener.Close() defer listener.Close()
tlsConfig, err := stubTLSConfig() tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer())
require.NoError(err) require.NoError(err)
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(s, tc.server) pubproto.RegisterAPIServer(s, tc.server)
@ -90,7 +84,7 @@ func TestRequestLoop(t *testing.T) {
go func() { require.NoError(s.Serve(listener)) }() go func() { require.NoError(s.Serve(listener)) }()
} }
keyWaiter := &keyAPI{ keyWaiter := &KeyAPI{
metadata: stubMetadata{listResponse: tc.listResponse}, metadata: stubMetadata{listResponse: tc.listResponse},
keyReceived: keyReceived, keyReceived: keyReceived,
timeout: 500 * time.Millisecond, timeout: 500 * time.Millisecond,
@ -99,10 +93,10 @@ func TestRequestLoop(t *testing.T) {
// notify the API a key was received after 1 second // notify the API a key was received after 1 second
go func() { go func() {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
keyReceived <- true keyReceived <- struct{}{}
}() }()
err = keyWaiter.requestKeyFromCoordinator( err = keyWaiter.requestKeyLoop(
"1234", "1234",
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.DialContext(ctx) return listener.DialContext(ctx)
@ -115,6 +109,54 @@ func TestRequestLoop(t *testing.T) {
} }
} }
func TestPushStateDiskKey(t *testing.T) {
testCases := map[string]struct {
testAPI *KeyAPI
request *keyproto.PushStateDiskKeyRequest
errExpected bool
}{
"success": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")},
},
"key already set": {
testAPI: &KeyAPI{
keyReceived: make(chan struct{}, 1),
key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
errExpected: true,
},
"incorrect size of pushed key": {
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA")},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
_, err := tc.testAPI.PushStateDiskKey(context.Background(), tc.request)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.request.StateDiskKey, tc.testAPI.key)
}
})
}
}
func TestResetKey(t *testing.T) {
api := New(nil, nil, time.Second)
api.key = []byte{0x1, 0x2, 0x3}
api.ResetKey()
assert.Nil(t, api.key)
}
type stubAPIServer struct { type stubAPIServer struct {
requestStateDiskKeyResp *pubproto.RequestStateDiskKeyResponse requestStateDiskKeyResp *pubproto.RequestStateDiskKeyResponse
requestStateDiskKeyErr error requestStateDiskKeyErr error
@ -149,30 +191,30 @@ func (s *stubAPIServer) RequestStateDiskKey(ctx context.Context, in *pubproto.Re
return s.requestStateDiskKeyResp, s.requestStateDiskKeyErr return s.requestStateDiskKeyResp, s.requestStateDiskKeyErr
} }
func stubTLSConfig() (*tls.Config, error) { type stubMetadata struct {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) listResponse []core.Instance
if err != nil { }
return nil, err
} func (s stubMetadata) List(ctx context.Context) ([]core.Instance, error) {
getCertificate := func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { return s.listResponse, nil
serialNumber, err := util.GenerateCertificateSerialNumber() }
if err != nil {
return nil, err func (s stubMetadata) Self(ctx context.Context) (core.Instance, error) {
} return core.Instance{}, nil
now := time.Now() }
template := &x509.Certificate{
SerialNumber: serialNumber, func (s stubMetadata) GetInstance(ctx context.Context, providerID string) (core.Instance, error) {
Subject: pkix.Name{CommonName: "Constellation"}, return core.Instance{}, nil
NotBefore: now.Add(-2 * time.Hour), }
NotAfter: now.Add(2 * time.Hour),
ExtraExtensions: []pkix.Extension{{Id: oid.Dummy{}.OID(), Value: []byte{0x1, 0x2, 0x3}}}, func (s stubMetadata) SignalRole(ctx context.Context, role role.Role) error {
} return nil
cert, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) }
if err != nil {
return nil, err func (s stubMetadata) SetVPNIP(ctx context.Context, vpnIP string) error {
} return nil
}
return &tls.Certificate{Certificate: [][]byte{cert}, PrivateKey: priv}, nil
} func (s stubMetadata) Supported() bool {
return &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}, nil return true
} }

View File

@ -3,17 +3,10 @@ package mapper
import ( import (
"errors" "errors"
"fmt" "fmt"
"path/filepath"
cryptsetup "github.com/martinjungblut/go-cryptsetup" cryptsetup "github.com/martinjungblut/go-cryptsetup"
) )
const (
gcpStateDiskPath = "/dev/disk/by-id/google-state-disk"
azureStateDiskPath = "/dev/disk/azure/scsi1/lun0"
fallBackPath = "/dev/disk/by-id/state-disk"
)
// Mapper handles actions for formating and mapping crypt devices. // Mapper handles actions for formating and mapping crypt devices.
type Mapper struct { type Mapper struct {
device cryptDevice device cryptDevice
@ -89,28 +82,3 @@ func (m *Mapper) MapDisk(target, passphrase string) error {
func (m *Mapper) UnmapDisk(target string) error { func (m *Mapper) UnmapDisk(target string) error {
return m.device.Deactivate(target) return m.device.Deactivate(target)
} }
// GetDiskPath returns the device path of the data disk by cloud provider.
//
// For GCP a symlink to the disk is expected at /dev/disk/by-id/google-state-disk
// For Azure a symlink to the disk is expected at /dev/disk/azure/scsi1/lun0
// If no symlink can be found at the given path, or if no known cloud provider is supplied,
// we instead return the device path of the os-disk stateful partition at /dev/disk/by-partlabel/stateful.
func GetDiskPath(csp string) (string, error) {
var diskPath string
var err error
switch csp {
case "gcp":
diskPath, err = filepath.EvalSymlinks(gcpStateDiskPath)
case "azure":
diskPath, err = filepath.EvalSymlinks(azureStateDiskPath)
default:
diskPath = fallBackPath
}
if err != nil {
return filepath.EvalSymlinks(fallBackPath)
}
return diskPath, nil
}

45
state/setup/interface.go Normal file
View File

@ -0,0 +1,45 @@
package setup
import (
"io/fs"
"os"
"syscall"
)
// Mounter is an interface for mount and unmount operations.
type Mounter interface {
Mount(source string, target string, fstype string, flags uintptr, data string) error
Unmount(target string, flags int) error
MkdirAll(path string, perm fs.FileMode) error
}
// DeviceMapper is an interface for device mapping operations.
type DeviceMapper interface {
DiskUUID() string
FormatDisk(passphrase string) error
MapDisk(target string, passphrase string) error
}
// KeyWaiter is an interface to request and wait for disk decryption keys.
type KeyWaiter interface {
WaitForDecryptionKey(uuid, addr string) ([]byte, error)
ResetKey()
}
// DiskMounter uses the syscall package to mount disks.
type DiskMounter struct{}
// Mount performs a mount syscall.
func (m DiskMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error {
return syscall.Mount(source, target, fstype, flags, data)
}
// Unmount performs an unmount syscall.
func (m DiskMounter) Unmount(target string, flags int) error {
return syscall.Unmount(target, flags)
}
// MkdirAll uses os.MkdirAll to create the directory.
func (m DiskMounter) MkdirAll(path string, perm fs.FileMode) error {
return os.MkdirAll(path, perm)
}

125
state/setup/setup.go Normal file
View File

@ -0,0 +1,125 @@
package setup
import (
"crypto/rand"
"errors"
"log"
"net"
"os"
"path/filepath"
"syscall"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/config"
"github.com/edgelesssys/constellation/coordinator/nodestate"
"github.com/spf13/afero"
)
const (
RecoveryPort = "9000"
keyPath = "/run/cryptsetup-keys.d"
keyFile = "state.key"
stateDiskMappedName = "state"
stateDiskMountPath = "/var/run/state"
stateInfoPath = stateDiskMountPath + "/constellation/node_state.json"
)
// SetupManager handles formating, mapping, mounting and unmounting of state disks.
type SetupManager struct {
csp string
fs afero.Afero
keyWaiter KeyWaiter
mapper DeviceMapper
mounter Mounter
openTPM vtpm.TPMOpenFunc
}
// New initializes a SetupManager with the given parameters.
func New(csp string, fs afero.Afero, keyWaiter KeyWaiter, mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc) *SetupManager {
return &SetupManager{
csp: csp,
fs: fs,
keyWaiter: keyWaiter,
mapper: mapper,
mounter: mounter,
openTPM: openTPM,
}
}
// PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk.
// Once the disk is mapped, the function taints the node as initialized by updating it's PCRs.
func (s *SetupManager) PrepareExistingDisk() error {
log.Println("Preparing existing state disk")
uuid := s.mapper.DiskUUID()
getKey:
passphrase, err := s.keyWaiter.WaitForDecryptionKey(uuid, net.JoinHostPort("0.0.0.0", RecoveryPort))
if err != nil {
return err
}
if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil {
// retry key fetching if disk mapping fails
s.keyWaiter.ResetKey()
goto getKey
}
if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil {
return err
}
// we do not care about cleaning up the mount point on error, since any errors returned here should result in a kernel panic in the main function
if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil {
return err
}
ownerID, clusterID, err := s.readInitSecrets(stateInfoPath)
if err != nil {
return err
}
// taint the node as initialized
if err := vtpm.MarkNodeAsInitialized(s.openTPM, ownerID, clusterID); err != nil {
return err
}
return s.mounter.Unmount(stateDiskMountPath, 0)
}
// PrepareNewDisk prepares an instances state disk by formatting the disk as a LUKS device using a random passphrase.
func (s *SetupManager) PrepareNewDisk() error {
log.Println("Preparing new state disk")
// generate and save temporary passphrase
if err := s.fs.MkdirAll(keyPath, os.ModePerm); err != nil {
return err
}
passphrase := make([]byte, config.RNGLengthDefault)
if _, err := rand.Read(passphrase); err != nil {
return err
}
if err := s.fs.WriteFile(filepath.Join(keyPath, keyFile), passphrase, 0o400); err != nil {
return err
}
if err := s.mapper.FormatDisk(string(passphrase)); err != nil {
return err
}
return s.mapper.MapDisk(stateDiskMappedName, string(passphrase))
}
func (s *SetupManager) readInitSecrets(path string) ([]byte, []byte, error) {
handler := file.NewHandler(s.fs)
var state nodestate.NodeState
if err := handler.ReadJSON(path, &state); err != nil {
return nil, nil, err
}
if len(state.ClusterID) == 0 || len(state.OwnerID) == 0 {
return nil, nil, errors.New("missing state information to retaint node")
}
return state.OwnerID, state.ClusterID, nil
}

317
state/setup/setup_test.go Normal file
View File

@ -0,0 +1,317 @@
package setup
import (
"errors"
"io"
"io/fs"
"path/filepath"
"testing"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/config"
"github.com/edgelesssys/constellation/coordinator/nodestate"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrepareExistingDisk(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct {
fs afero.Afero
keyWaiter *stubKeyWaiter
mapper *stubMapper
mounter *stubMounter
openTPM vtpm.TPMOpenFunc
missingState bool
errExpected bool
}{
"success": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
openTPM: vtpm.OpenNOPTPM,
},
"WaitForDecryptionKey fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{waitErr: someErr},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
openTPM: vtpm.OpenNOPTPM,
errExpected: true,
},
"MapDisk fails causes a repeat": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{
uuid: "test",
mapDiskErr: someErr,
mapDiskRepeatedCalls: 2,
},
mounter: &stubMounter{},
openTPM: vtpm.OpenNOPTPM,
errExpected: false,
},
"MkdirAll fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{mkdirAllErr: someErr},
openTPM: vtpm.OpenNOPTPM,
errExpected: true,
},
"Mount fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{mountErr: someErr},
openTPM: vtpm.OpenNOPTPM,
errExpected: true,
},
"Unmount fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{unmountErr: someErr},
openTPM: vtpm.OpenNOPTPM,
errExpected: true,
},
"MarkNodeAsInitialized fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{unmountErr: someErr},
openTPM: failOpener,
errExpected: true,
},
"no state file": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
keyWaiter: &stubKeyWaiter{},
mapper: &stubMapper{uuid: "test"},
mounter: &stubMounter{},
openTPM: vtpm.OpenNOPTPM,
missingState: true,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
if !tc.missingState {
handler := file.NewHandler(tc.fs)
require.NoError(t, handler.WriteJSON(stateInfoPath, nodestate.NodeState{OwnerID: []byte("ownerID"), ClusterID: []byte("clusterID")}, file.OptMkdirAll))
}
setupManager := New("test", tc.fs, tc.keyWaiter, tc.mapper, tc.mounter, tc.openTPM)
err := setupManager.PrepareExistingDisk()
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
assert.True(tc.mapper.mapDiskCalled)
assert.True(tc.mounter.mountCalled)
assert.True(tc.mounter.unmountCalled)
assert.False(tc.mapper.formatDiskCalled)
}
})
}
}
func failOpener() (io.ReadWriteCloser, error) {
return nil, errors.New("error")
}
func TestPrepareNewDisk(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct {
fs afero.Afero
mapper *stubMapper
errExpected bool
}{
"success": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
mapper: &stubMapper{uuid: "test"},
},
"creating directory fails": {
fs: afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())},
mapper: &stubMapper{},
errExpected: true,
},
"FormatDisk fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
mapper: &stubMapper{
uuid: "test",
formatDiskErr: someErr,
},
errExpected: true,
},
"MapDisk fails": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
mapper: &stubMapper{
uuid: "test",
mapDiskErr: someErr,
mapDiskRepeatedCalls: 1,
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
setupManager := New("test", tc.fs, nil, tc.mapper, nil, nil)
err := setupManager.PrepareNewDisk()
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.True(tc.mapper.formatDiskCalled)
assert.True(tc.mapper.mapDiskCalled)
data, err := tc.fs.ReadFile(filepath.Join(keyPath, keyFile))
require.NoError(t, err)
assert.Len(data, config.RNGLengthDefault)
}
})
}
}
func TestReadInitSecrets(t *testing.T) {
testCases := map[string]struct {
fs afero.Afero
ownerID string
clusterID string
writeFile bool
errExpected bool
}{
"success": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
ownerID: "ownerID",
clusterID: "clusterID",
writeFile: true,
},
"no state file": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
errExpected: true,
},
"missing ownerID": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
clusterID: "clusterID",
writeFile: true,
errExpected: true,
},
"missing clusterID": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
ownerID: "ownerID",
writeFile: true,
errExpected: true,
},
"no IDs": {
fs: afero.Afero{Fs: afero.NewMemMapFs()},
writeFile: true,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
if tc.writeFile {
handler := file.NewHandler(tc.fs)
state := nodestate.NodeState{ClusterID: []byte(tc.clusterID), OwnerID: []byte(tc.ownerID)}
require.NoError(handler.WriteJSON("/tmp/test-state.json", state, file.OptMkdirAll))
}
setupManager := New("test", tc.fs, nil, nil, nil, nil)
ownerID, clusterID, err := setupManager.readInitSecrets("/tmp/test-state.json")
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal([]byte(tc.ownerID), ownerID)
assert.Equal([]byte(tc.clusterID), clusterID)
}
})
}
}
type stubMapper struct {
formatDiskCalled bool
formatDiskErr error
mapDiskRepeatedCalls int
mapDiskCalled bool
mapDiskErr error
uuid string
}
func (s *stubMapper) DiskUUID() string {
return s.uuid
}
func (s *stubMapper) FormatDisk(passphrase string) error {
s.formatDiskCalled = true
return s.formatDiskErr
}
func (s *stubMapper) MapDisk(target string, passphrase string) error {
if s.mapDiskRepeatedCalls == 0 {
s.mapDiskErr = nil
}
s.mapDiskRepeatedCalls--
s.mapDiskCalled = true
return s.mapDiskErr
}
type stubMounter struct {
mountCalled bool
mountErr error
unmountCalled bool
unmountErr error
mkdirAllErr error
}
func (s *stubMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error {
s.mountCalled = true
return s.mountErr
}
func (s *stubMounter) Unmount(target string, flags int) error {
s.unmountCalled = true
return s.unmountErr
}
func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
return s.mkdirAllErr
}
type stubKeyWaiter struct {
receivedUUID string
decryptionKey []byte
waitErr error
waitCalled bool
}
func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, error) {
if s.waitCalled {
return nil, errors.New("wait called before key was reset")
}
s.waitCalled = true
s.receivedUUID = uuid
return s.decryptionKey, s.waitErr
}
func (s *stubKeyWaiter) ResetKey() {
s.waitCalled = false
}

View File

@ -3,15 +3,24 @@
package integration package integration
import ( import (
"context"
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"testing" "testing"
"time"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/state/keyservice"
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/edgelesssys/constellation/state/mapper" "github.com/edgelesssys/constellation/state/mapper"
"github.com/martinjungblut/go-cryptsetup" "github.com/martinjungblut/go-cryptsetup"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
) )
const ( const (
@ -62,3 +71,42 @@ func TestMapper(t *testing.T) {
// Try to map disk with incorrect passphrase // Try to map disk with incorrect passphrase
assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase") assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase")
} }
func TestKeyAPI(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
// get a free port on localhost to run the test on
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
apiAddr := listener.Addr().String()
listener.Close()
api := keyservice.New(core.NewMockIssuer(), &core.ProviderMetadataFake{}, 20*time.Second)
// send a key to the server
go func() {
// wait 2 seconds before sending the key
time.Sleep(2 * time.Second)
clientCfg, err := atls.CreateUnverifiedClientTLSConfig()
require.NoError(err)
conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(credentials.NewTLS(clientCfg)))
require.NoError(err)
defer conn.Close()
client := keyproto.NewAPIClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
_, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{
StateDiskKey: testKey,
})
require.NoError(err)
}()
key, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr)
assert.NoError(err)
assert.Equal(testKey, key)
}