mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-11 23:49:30 -05:00
AB#1903 Add grpc interface to push decryption keys
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
96d7029367
commit
152e3985f7
@ -36,7 +36,7 @@ RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_o
|
||||
|
||||
## disk-mapper keyservice api
|
||||
WORKDIR /disk-mapper
|
||||
COPY state/keyservice/proto/*.proto /disk-mapper
|
||||
COPY state/keyservice/keyproto/*.proto /disk-mapper
|
||||
RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto
|
||||
|
||||
## debugd service
|
||||
@ -48,5 +48,5 @@ RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_o
|
||||
FROM scratch as export
|
||||
COPY --from=build /pubapi/*.go coordinator/pubapi/pubproto/
|
||||
COPY --from=build /vpnapi/*.go coordinator/vpnapi/vpnproto/
|
||||
COPY --from=build /disk-mapper/*.go state/keyservice/proto/
|
||||
COPY --from=build /disk-mapper/*.go state/keyservice/keyproto/
|
||||
COPY --from=build /service/*.go debugd/service/
|
||||
|
@ -1,68 +1,103 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/config"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
azurecloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/azure"
|
||||
gcpcloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/internal/utils"
|
||||
"github.com/edgelesssys/constellation/state/keyservice"
|
||||
"github.com/edgelesssys/constellation/state/mapper"
|
||||
"github.com/edgelesssys/constellation/state/setup"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
const (
|
||||
keyPath = "/run/cryptsetup-keys.d"
|
||||
keyFile = "state.key"
|
||||
gcpStateDiskPath = "/dev/disk/by-id/google-state-disk"
|
||||
azureStateDiskPath = "/dev/disk/azure/scsi1/lun0"
|
||||
fallBackPath = "/dev/disk/by-id/state-disk"
|
||||
)
|
||||
|
||||
var csp = flag.String("csp", "", "Cloud Service Provider the image is running on")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
diskPath, err := mapper.GetDiskPath(*csp)
|
||||
|
||||
// set up metadata API and quote issuer for aTLS connections
|
||||
var err error
|
||||
var diskPathErr error
|
||||
var diskPath string
|
||||
var issuer core.QuoteIssuer
|
||||
var metadata core.ProviderMetadata
|
||||
switch strings.ToLower(*csp) {
|
||||
case "azure":
|
||||
diskPath, diskPathErr = filepath.EvalSymlinks(azureStateDiskPath)
|
||||
metadata, err = azurecloud.NewMetadata(context.Background())
|
||||
if err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
issuer = azure.NewIssuer()
|
||||
|
||||
case "gcp":
|
||||
diskPath, diskPathErr = filepath.EvalSymlinks(gcpStateDiskPath)
|
||||
issuer = gcp.NewIssuer()
|
||||
gcpClient, err := gcpcloud.NewClient(context.Background())
|
||||
if err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
metadata = gcpcloud.New(gcpClient)
|
||||
|
||||
default:
|
||||
diskPath, err = filepath.EvalSymlinks(fallBackPath)
|
||||
if err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
issuer = core.NewMockIssuer()
|
||||
fmt.Fprintf(os.Stderr, "warning: csp %q is not supported, unable to automatically request decryption keys on reboot\n", *csp)
|
||||
metadata = &core.ProviderMetadataFake{}
|
||||
}
|
||||
if diskPathErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: no attached disk detected, trying to use boot-disk state partition as fallback")
|
||||
diskPath, err = filepath.EvalSymlinks(fallBackPath)
|
||||
if err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize device mapper
|
||||
mapper, err := mapper.New(diskPath)
|
||||
if err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
defer mapper.Close()
|
||||
|
||||
setupManger := setup.New(
|
||||
*csp,
|
||||
afero.Afero{Fs: afero.NewOsFs()},
|
||||
keyservice.New(issuer, metadata, 20*time.Second), // try to request a key every 20 seconds
|
||||
mapper,
|
||||
setup.DiskMounter{},
|
||||
vtpm.OpenVTPM,
|
||||
)
|
||||
|
||||
// prepare the state disk
|
||||
if mapper.IsLUKSDevice() {
|
||||
uuid := mapper.DiskUUID()
|
||||
_, err = keyservice.WaitForDecryptionKey(*csp, uuid)
|
||||
err = setupManger.PrepareExistingDisk()
|
||||
} else {
|
||||
err = formatDisk(mapper)
|
||||
err = setupManger.PrepareNewDisk()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func formatDisk(mapper *mapper.Mapper) error {
|
||||
// generate and save temporary passphrase
|
||||
if err := os.MkdirAll(keyPath, os.ModePerm); err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
passphrase := make([]byte, config.RNGLengthDefault)
|
||||
if _, err := rand.Read(passphrase); err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(keyPath, keyFile), passphrase, 0o400); err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
|
||||
if err := mapper.FormatDisk(string(passphrase)); err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
if err := mapper.MapDisk("state", string(passphrase)); err != nil {
|
||||
utils.KernelPanic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
208
state/keyservice/keyproto/keyservice.pb.go
Normal file
208
state/keyservice/keyproto/keyservice.pb.go
Normal file
@ -0,0 +1,208 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// source: keyservice.proto
|
||||
|
||||
package keyproto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type PushStateDiskKeyRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"`
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) Reset() {
|
||||
*x = PushStateDiskKeyRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_keyservice_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*PushStateDiskKeyRequest) ProtoMessage() {}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_keyservice_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use PushStateDiskKeyRequest.ProtoReflect.Descriptor instead.
|
||||
func (*PushStateDiskKeyRequest) Descriptor() ([]byte, []int) {
|
||||
return file_keyservice_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte {
|
||||
if x != nil {
|
||||
return x.StateDiskKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type PushStateDiskKeyResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyResponse) Reset() {
|
||||
*x = PushStateDiskKeyResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_keyservice_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *PushStateDiskKeyResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*PushStateDiskKeyResponse) ProtoMessage() {}
|
||||
|
||||
func (x *PushStateDiskKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_keyservice_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use PushStateDiskKeyResponse.ProtoReflect.Descriptor instead.
|
||||
func (*PushStateDiskKeyResponse) Descriptor() ([]byte, []int) {
|
||||
return file_keyservice_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
var File_keyservice_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_keyservice_proto_rawDesc = []byte{
|
||||
0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x3f, 0x0a, 0x17,
|
||||
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
|
||||
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65,
|
||||
0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
|
||||
0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x22, 0x1a, 0x0a,
|
||||
0x18, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65,
|
||||
0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49,
|
||||
0x12, 0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73,
|
||||
0x6b, 0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
|
||||
0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79,
|
||||
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b,
|
||||
0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65,
|
||||
0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74,
|
||||
0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72,
|
||||
0x76, 0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_keyservice_proto_rawDescOnce sync.Once
|
||||
file_keyservice_proto_rawDescData = file_keyservice_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_keyservice_proto_rawDescGZIP() []byte {
|
||||
file_keyservice_proto_rawDescOnce.Do(func() {
|
||||
file_keyservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_keyservice_proto_rawDescData)
|
||||
})
|
||||
return file_keyservice_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_keyservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_keyservice_proto_goTypes = []interface{}{
|
||||
(*PushStateDiskKeyRequest)(nil), // 0: keyproto.PushStateDiskKeyRequest
|
||||
(*PushStateDiskKeyResponse)(nil), // 1: keyproto.PushStateDiskKeyResponse
|
||||
}
|
||||
var file_keyservice_proto_depIdxs = []int32{
|
||||
0, // 0: keyproto.API.PushStateDiskKey:input_type -> keyproto.PushStateDiskKeyRequest
|
||||
1, // 1: keyproto.API.PushStateDiskKey:output_type -> keyproto.PushStateDiskKeyResponse
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_keyservice_proto_init() }
|
||||
func file_keyservice_proto_init() {
|
||||
if File_keyservice_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_keyservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*PushStateDiskKeyRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_keyservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*PushStateDiskKeyResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_keyservice_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_keyservice_proto_goTypes,
|
||||
DependencyIndexes: file_keyservice_proto_depIdxs,
|
||||
MessageInfos: file_keyservice_proto_msgTypes,
|
||||
}.Build()
|
||||
File_keyservice_proto = out.File
|
||||
file_keyservice_proto_rawDesc = nil
|
||||
file_keyservice_proto_goTypes = nil
|
||||
file_keyservice_proto_depIdxs = nil
|
||||
}
|
16
state/keyservice/keyproto/keyservice.proto
Normal file
16
state/keyservice/keyproto/keyservice.proto
Normal file
@ -0,0 +1,16 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package keyproto;
|
||||
|
||||
option go_package = "github.com/edgelesssys/constellation/state/keyservice/keyproto";
|
||||
|
||||
service API {
|
||||
rpc PushStateDiskKey(PushStateDiskKeyRequest) returns (PushStateDiskKeyResponse);
|
||||
}
|
||||
|
||||
message PushStateDiskKeyRequest {
|
||||
bytes state_disk_key = 1;
|
||||
}
|
||||
|
||||
message PushStateDiskKeyResponse {
|
||||
}
|
101
state/keyservice/keyproto/keyservice_grpc.pb.go
Normal file
101
state/keyservice/keyproto/keyservice_grpc.pb.go
Normal file
@ -0,0 +1,101 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
|
||||
package keyproto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// APIClient is the client API for API service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type APIClient interface {
|
||||
PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error)
|
||||
}
|
||||
|
||||
type aPIClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
|
||||
return &aPIClient{cc}
|
||||
}
|
||||
|
||||
func (c *aPIClient) PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) {
|
||||
out := new(PushStateDiskKeyResponse)
|
||||
err := c.cc.Invoke(ctx, "/keyproto.API/PushStateDiskKey", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// APIServer is the server API for API service.
|
||||
// All implementations must embed UnimplementedAPIServer
|
||||
// for forward compatibility
|
||||
type APIServer interface {
|
||||
PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error)
|
||||
mustEmbedUnimplementedAPIServer()
|
||||
}
|
||||
|
||||
// UnimplementedAPIServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedAPIServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedAPIServer) PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method PushStateDiskKey not implemented")
|
||||
}
|
||||
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
|
||||
|
||||
// UnsafeAPIServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to APIServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeAPIServer interface {
|
||||
mustEmbedUnimplementedAPIServer()
|
||||
}
|
||||
|
||||
func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
|
||||
s.RegisterService(&API_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _API_PushStateDiskKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(PushStateDiskKeyRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(APIServer).PushStateDiskKey(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/keyproto.API/PushStateDiskKey",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(APIServer).PushStateDiskKey(ctx, req.(*PushStateDiskKeyRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// API_ServiceDesc is the grpc.ServiceDesc for API service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var API_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "keyproto.API",
|
||||
HandlerType: (*APIServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "PushStateDiskKey",
|
||||
Handler: _API_PushStateDiskKey_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "keyservice.proto",
|
||||
}
|
@ -2,41 +2,97 @@ package keyservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
azurecloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/azure"
|
||||
gcpcloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/config"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// keyAPI is the interface called by the Coordinator or an admin during restart of a node.
|
||||
type keyAPI struct {
|
||||
metadata core.ProviderMetadata
|
||||
// KeyAPI is the interface called by the Coordinator or an admin during restart of a node.
|
||||
type KeyAPI struct {
|
||||
mux sync.Mutex
|
||||
metadata core.ProviderMetadata
|
||||
issuer core.QuoteIssuer
|
||||
key []byte
|
||||
keyReceived chan bool
|
||||
keyReceived chan struct{}
|
||||
timeout time.Duration
|
||||
keyproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func (a *keyAPI) waitForDecryptionKey() {
|
||||
// go server.Start()
|
||||
// block until a key is pushed
|
||||
if <-a.keyReceived {
|
||||
return
|
||||
// New initializes a KeyAPI with the given parameters.
|
||||
func New(issuer core.QuoteIssuer, metadata core.ProviderMetadata, timeout time.Duration) *KeyAPI {
|
||||
return &KeyAPI{
|
||||
metadata: metadata,
|
||||
issuer: issuer,
|
||||
keyReceived: make(chan struct{}, 1),
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption) error {
|
||||
// PushStateDiskKeyRequest is the rpc to push state disk decryption keys to a restarting node.
|
||||
func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
if len(a.key) != 0 {
|
||||
return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase")
|
||||
}
|
||||
if len(in.StateDiskKey) != config.RNGLengthDefault {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", config.RNGLengthDefault, len(in.StateDiskKey))
|
||||
}
|
||||
|
||||
a.key = in.StateDiskKey
|
||||
a.keyReceived <- struct{}{}
|
||||
return &keyproto.PushStateDiskKeyResponse{}, nil
|
||||
}
|
||||
|
||||
// WaitForDecryptionKey notifies the Coordinator to send a decryption key and waits until a key is received.
|
||||
func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) {
|
||||
if uuid == "" {
|
||||
return nil, errors.New("received no disk UUID")
|
||||
}
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
keyproto.RegisterAPIServer(server, a)
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
log.Printf("Waiting for decryption key. Listening on: %s", listener.Addr().String())
|
||||
go server.Serve(listener)
|
||||
defer server.GracefulStop()
|
||||
|
||||
if err := a.requestKeyLoop(uuid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return a.key, nil
|
||||
}
|
||||
|
||||
// ResetKey resets a previously set key.
|
||||
func (a *KeyAPI) ResetKey() {
|
||||
a.key = nil
|
||||
}
|
||||
|
||||
// requestKeyLoop continuously requests decryption keys from all available Coordinators, until the KeyAPI receives a key.
|
||||
func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error {
|
||||
// we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator
|
||||
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
|
||||
tlsClientConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
||||
@ -44,18 +100,34 @@ func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption)
|
||||
return err
|
||||
}
|
||||
|
||||
// set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker
|
||||
firstReq := make(chan struct{}, 1)
|
||||
firstReq <- struct{}{}
|
||||
|
||||
ticker := time.NewTicker(a.timeout)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
// return if a key was received by any means
|
||||
// return if a key was received
|
||||
// a key can be send by
|
||||
// - a Coordinator, after the request rpc was received
|
||||
// - by a Constellation admin, at any time this loop is running on a node during boot
|
||||
case <-a.keyReceived:
|
||||
return nil
|
||||
default:
|
||||
case <-ticker.C:
|
||||
a.requestKey(uuid, tlsClientConfig, opts...)
|
||||
case <-firstReq:
|
||||
a.requestKey(uuid, tlsClientConfig, opts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...grpc.DialOption) {
|
||||
// list available Coordinators
|
||||
endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata)
|
||||
// notify the all available Coordinators to send a key to the node
|
||||
|
||||
log.Printf("Sending a key request to available Coordinators: %v", endpoints)
|
||||
// notify all available Coordinators to send a key to the node
|
||||
// any errors encountered here will be ignored, and the calls retried after a timeout
|
||||
for _, endpoint := range endpoints {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
|
||||
@ -68,72 +140,4 @@ func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption)
|
||||
|
||||
cancel()
|
||||
}
|
||||
time.Sleep(a.timeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForDecryptionKey notifies the Coordinator to send a decryption key and waits until a key is received.
|
||||
func WaitForDecryptionKey(csp, uuid string) ([]byte, error) {
|
||||
if uuid == "" {
|
||||
return nil, errors.New("received no disk UUID")
|
||||
}
|
||||
|
||||
keyWaiter := &keyAPI{
|
||||
keyReceived: make(chan bool, 1),
|
||||
timeout: 20 * time.Second, // try to request a key every 20 seconds
|
||||
}
|
||||
go keyWaiter.waitForDecryptionKey()
|
||||
|
||||
switch strings.ToLower(csp) {
|
||||
case "azure":
|
||||
metadata, err := azurecloud.NewMetadata(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyWaiter.metadata = metadata
|
||||
case "gcp":
|
||||
gcpClient, err := gcpcloud.NewClient(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyWaiter.metadata = gcpcloud.New(gcpClient)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "warning: csp %q is not supported, unable to automatically request decryption keys\n", csp)
|
||||
keyWaiter.metadata = stubMetadata{}
|
||||
}
|
||||
|
||||
if err := keyWaiter.requestKeyFromCoordinator(uuid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keyWaiter.key, nil
|
||||
}
|
||||
|
||||
type stubMetadata struct {
|
||||
listResponse []core.Instance
|
||||
}
|
||||
|
||||
func (s stubMetadata) List(ctx context.Context) ([]core.Instance, error) {
|
||||
return s.listResponse, nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) Self(ctx context.Context) (core.Instance, error) {
|
||||
return core.Instance{}, nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) GetInstance(ctx context.Context, providerID string) (core.Instance, error) {
|
||||
return core.Instance{}, nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) SignalRole(ctx context.Context, role role.Role) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) SetVPNIP(ctx context.Context, vpnIP string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) Supported() bool {
|
||||
return true
|
||||
}
|
||||
|
@ -2,22 +2,16 @@ package keyservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/coordinator/oid"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
@ -25,7 +19,7 @@ import (
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
func TestRequestLoop(t *testing.T) {
|
||||
func TestRequestKeyLoop(t *testing.T) {
|
||||
defaultInstance := core.Instance{
|
||||
Name: "test-instance",
|
||||
ProviderID: "/test/provider",
|
||||
@ -77,11 +71,11 @@ func TestRequestLoop(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
keyReceived := make(chan bool, 1)
|
||||
keyReceived := make(chan struct{}, 1)
|
||||
listener := bufconn.Listen(1)
|
||||
defer listener.Close()
|
||||
|
||||
tlsConfig, err := stubTLSConfig()
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer())
|
||||
require.NoError(err)
|
||||
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
pubproto.RegisterAPIServer(s, tc.server)
|
||||
@ -90,7 +84,7 @@ func TestRequestLoop(t *testing.T) {
|
||||
go func() { require.NoError(s.Serve(listener)) }()
|
||||
}
|
||||
|
||||
keyWaiter := &keyAPI{
|
||||
keyWaiter := &KeyAPI{
|
||||
metadata: stubMetadata{listResponse: tc.listResponse},
|
||||
keyReceived: keyReceived,
|
||||
timeout: 500 * time.Millisecond,
|
||||
@ -99,10 +93,10 @@ func TestRequestLoop(t *testing.T) {
|
||||
// notify the API a key was received after 1 second
|
||||
go func() {
|
||||
time.Sleep(1 * time.Second)
|
||||
keyReceived <- true
|
||||
keyReceived <- struct{}{}
|
||||
}()
|
||||
|
||||
err = keyWaiter.requestKeyFromCoordinator(
|
||||
err = keyWaiter.requestKeyLoop(
|
||||
"1234",
|
||||
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
|
||||
return listener.DialContext(ctx)
|
||||
@ -115,6 +109,54 @@ func TestRequestLoop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushStateDiskKey(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
testAPI *KeyAPI
|
||||
request *keyproto.PushStateDiskKeyRequest
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")},
|
||||
},
|
||||
"key already set": {
|
||||
testAPI: &KeyAPI{
|
||||
keyReceived: make(chan struct{}, 1),
|
||||
key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")},
|
||||
errExpected: true,
|
||||
},
|
||||
"incorrect size of pushed key": {
|
||||
testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)},
|
||||
request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA")},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
_, err := tc.testAPI.PushStateDiskKey(context.Background(), tc.request)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.request.StateDiskKey, tc.testAPI.key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetKey(t *testing.T) {
|
||||
api := New(nil, nil, time.Second)
|
||||
|
||||
api.key = []byte{0x1, 0x2, 0x3}
|
||||
api.ResetKey()
|
||||
assert.Nil(t, api.key)
|
||||
}
|
||||
|
||||
type stubAPIServer struct {
|
||||
requestStateDiskKeyResp *pubproto.RequestStateDiskKeyResponse
|
||||
requestStateDiskKeyErr error
|
||||
@ -149,30 +191,30 @@ func (s *stubAPIServer) RequestStateDiskKey(ctx context.Context, in *pubproto.Re
|
||||
return s.requestStateDiskKeyResp, s.requestStateDiskKeyErr
|
||||
}
|
||||
|
||||
func stubTLSConfig() (*tls.Config, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getCertificate := func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
serialNumber, err := util.GenerateCertificateSerialNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{CommonName: "Constellation"},
|
||||
NotBefore: now.Add(-2 * time.Hour),
|
||||
NotAfter: now.Add(2 * time.Hour),
|
||||
ExtraExtensions: []pkix.Extension{{Id: oid.Dummy{}.OID(), Value: []byte{0x1, 0x2, 0x3}}},
|
||||
}
|
||||
cert, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Certificate{Certificate: [][]byte{cert}, PrivateKey: priv}, nil
|
||||
}
|
||||
return &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}, nil
|
||||
type stubMetadata struct {
|
||||
listResponse []core.Instance
|
||||
}
|
||||
|
||||
func (s stubMetadata) List(ctx context.Context) ([]core.Instance, error) {
|
||||
return s.listResponse, nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) Self(ctx context.Context) (core.Instance, error) {
|
||||
return core.Instance{}, nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) GetInstance(ctx context.Context, providerID string) (core.Instance, error) {
|
||||
return core.Instance{}, nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) SignalRole(ctx context.Context, role role.Role) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) SetVPNIP(ctx context.Context, vpnIP string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s stubMetadata) Supported() bool {
|
||||
return true
|
||||
}
|
||||
|
@ -3,17 +3,10 @@ package mapper
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
cryptsetup "github.com/martinjungblut/go-cryptsetup"
|
||||
)
|
||||
|
||||
const (
|
||||
gcpStateDiskPath = "/dev/disk/by-id/google-state-disk"
|
||||
azureStateDiskPath = "/dev/disk/azure/scsi1/lun0"
|
||||
fallBackPath = "/dev/disk/by-id/state-disk"
|
||||
)
|
||||
|
||||
// Mapper handles actions for formating and mapping crypt devices.
|
||||
type Mapper struct {
|
||||
device cryptDevice
|
||||
@ -89,28 +82,3 @@ func (m *Mapper) MapDisk(target, passphrase string) error {
|
||||
func (m *Mapper) UnmapDisk(target string) error {
|
||||
return m.device.Deactivate(target)
|
||||
}
|
||||
|
||||
// GetDiskPath returns the device path of the data disk by cloud provider.
|
||||
//
|
||||
// For GCP a symlink to the disk is expected at /dev/disk/by-id/google-state-disk
|
||||
// For Azure a symlink to the disk is expected at /dev/disk/azure/scsi1/lun0
|
||||
// If no symlink can be found at the given path, or if no known cloud provider is supplied,
|
||||
// we instead return the device path of the os-disk stateful partition at /dev/disk/by-partlabel/stateful.
|
||||
func GetDiskPath(csp string) (string, error) {
|
||||
var diskPath string
|
||||
var err error
|
||||
|
||||
switch csp {
|
||||
case "gcp":
|
||||
diskPath, err = filepath.EvalSymlinks(gcpStateDiskPath)
|
||||
case "azure":
|
||||
diskPath, err = filepath.EvalSymlinks(azureStateDiskPath)
|
||||
default:
|
||||
diskPath = fallBackPath
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return filepath.EvalSymlinks(fallBackPath)
|
||||
}
|
||||
return diskPath, nil
|
||||
}
|
||||
|
45
state/setup/interface.go
Normal file
45
state/setup/interface.go
Normal file
@ -0,0 +1,45 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Mounter is an interface for mount and unmount operations.
|
||||
type Mounter interface {
|
||||
Mount(source string, target string, fstype string, flags uintptr, data string) error
|
||||
Unmount(target string, flags int) error
|
||||
MkdirAll(path string, perm fs.FileMode) error
|
||||
}
|
||||
|
||||
// DeviceMapper is an interface for device mapping operations.
|
||||
type DeviceMapper interface {
|
||||
DiskUUID() string
|
||||
FormatDisk(passphrase string) error
|
||||
MapDisk(target string, passphrase string) error
|
||||
}
|
||||
|
||||
// KeyWaiter is an interface to request and wait for disk decryption keys.
|
||||
type KeyWaiter interface {
|
||||
WaitForDecryptionKey(uuid, addr string) ([]byte, error)
|
||||
ResetKey()
|
||||
}
|
||||
|
||||
// DiskMounter uses the syscall package to mount disks.
|
||||
type DiskMounter struct{}
|
||||
|
||||
// Mount performs a mount syscall.
|
||||
func (m DiskMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error {
|
||||
return syscall.Mount(source, target, fstype, flags, data)
|
||||
}
|
||||
|
||||
// Unmount performs an unmount syscall.
|
||||
func (m DiskMounter) Unmount(target string, flags int) error {
|
||||
return syscall.Unmount(target, flags)
|
||||
}
|
||||
|
||||
// MkdirAll uses os.MkdirAll to create the directory.
|
||||
func (m DiskMounter) MkdirAll(path string, perm fs.FileMode) error {
|
||||
return os.MkdirAll(path, perm)
|
||||
}
|
125
state/setup/setup.go
Normal file
125
state/setup/setup.go
Normal file
@ -0,0 +1,125 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/config"
|
||||
"github.com/edgelesssys/constellation/coordinator/nodestate"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
const (
|
||||
RecoveryPort = "9000"
|
||||
keyPath = "/run/cryptsetup-keys.d"
|
||||
keyFile = "state.key"
|
||||
stateDiskMappedName = "state"
|
||||
stateDiskMountPath = "/var/run/state"
|
||||
stateInfoPath = stateDiskMountPath + "/constellation/node_state.json"
|
||||
)
|
||||
|
||||
// SetupManager handles formating, mapping, mounting and unmounting of state disks.
|
||||
type SetupManager struct {
|
||||
csp string
|
||||
fs afero.Afero
|
||||
keyWaiter KeyWaiter
|
||||
mapper DeviceMapper
|
||||
mounter Mounter
|
||||
openTPM vtpm.TPMOpenFunc
|
||||
}
|
||||
|
||||
// New initializes a SetupManager with the given parameters.
|
||||
func New(csp string, fs afero.Afero, keyWaiter KeyWaiter, mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc) *SetupManager {
|
||||
return &SetupManager{
|
||||
csp: csp,
|
||||
fs: fs,
|
||||
keyWaiter: keyWaiter,
|
||||
mapper: mapper,
|
||||
mounter: mounter,
|
||||
openTPM: openTPM,
|
||||
}
|
||||
}
|
||||
|
||||
// PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk.
|
||||
// Once the disk is mapped, the function taints the node as initialized by updating it's PCRs.
|
||||
func (s *SetupManager) PrepareExistingDisk() error {
|
||||
log.Println("Preparing existing state disk")
|
||||
uuid := s.mapper.DiskUUID()
|
||||
|
||||
getKey:
|
||||
passphrase, err := s.keyWaiter.WaitForDecryptionKey(uuid, net.JoinHostPort("0.0.0.0", RecoveryPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil {
|
||||
// retry key fetching if disk mapping fails
|
||||
s.keyWaiter.ResetKey()
|
||||
goto getKey
|
||||
}
|
||||
|
||||
if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
// we do not care about cleaning up the mount point on error, since any errors returned here should result in a kernel panic in the main function
|
||||
if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ownerID, clusterID, err := s.readInitSecrets(stateInfoPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// taint the node as initialized
|
||||
if err := vtpm.MarkNodeAsInitialized(s.openTPM, ownerID, clusterID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.mounter.Unmount(stateDiskMountPath, 0)
|
||||
}
|
||||
|
||||
// PrepareNewDisk prepares an instances state disk by formatting the disk as a LUKS device using a random passphrase.
|
||||
func (s *SetupManager) PrepareNewDisk() error {
|
||||
log.Println("Preparing new state disk")
|
||||
|
||||
// generate and save temporary passphrase
|
||||
if err := s.fs.MkdirAll(keyPath, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
passphrase := make([]byte, config.RNGLengthDefault)
|
||||
if _, err := rand.Read(passphrase); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.fs.WriteFile(filepath.Join(keyPath, keyFile), passphrase, 0o400); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.mapper.FormatDisk(string(passphrase)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.mapper.MapDisk(stateDiskMappedName, string(passphrase))
|
||||
}
|
||||
|
||||
func (s *SetupManager) readInitSecrets(path string) ([]byte, []byte, error) {
|
||||
handler := file.NewHandler(s.fs)
|
||||
var state nodestate.NodeState
|
||||
if err := handler.ReadJSON(path, &state); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(state.ClusterID) == 0 || len(state.OwnerID) == 0 {
|
||||
return nil, nil, errors.New("missing state information to retaint node")
|
||||
}
|
||||
|
||||
return state.OwnerID, state.ClusterID, nil
|
||||
}
|
317
state/setup/setup_test.go
Normal file
317
state/setup/setup_test.go
Normal file
@ -0,0 +1,317 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/config"
|
||||
"github.com/edgelesssys/constellation/coordinator/nodestate"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrepareExistingDisk(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
fs afero.Afero
|
||||
keyWaiter *stubKeyWaiter
|
||||
mapper *stubMapper
|
||||
mounter *stubMounter
|
||||
openTPM vtpm.TPMOpenFunc
|
||||
missingState bool
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
},
|
||||
"WaitForDecryptionKey fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{waitErr: someErr},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
errExpected: true,
|
||||
},
|
||||
"MapDisk fails causes a repeat": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{
|
||||
uuid: "test",
|
||||
mapDiskErr: someErr,
|
||||
mapDiskRepeatedCalls: 2,
|
||||
},
|
||||
mounter: &stubMounter{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
errExpected: false,
|
||||
},
|
||||
"MkdirAll fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{mkdirAllErr: someErr},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
errExpected: true,
|
||||
},
|
||||
"Mount fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{mountErr: someErr},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
errExpected: true,
|
||||
},
|
||||
"Unmount fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{unmountErr: someErr},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
errExpected: true,
|
||||
},
|
||||
"MarkNodeAsInitialized fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{unmountErr: someErr},
|
||||
openTPM: failOpener,
|
||||
errExpected: true,
|
||||
},
|
||||
"no state file": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
keyWaiter: &stubKeyWaiter{},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
mounter: &stubMounter{},
|
||||
openTPM: vtpm.OpenNOPTPM,
|
||||
missingState: true,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if !tc.missingState {
|
||||
handler := file.NewHandler(tc.fs)
|
||||
require.NoError(t, handler.WriteJSON(stateInfoPath, nodestate.NodeState{OwnerID: []byte("ownerID"), ClusterID: []byte("clusterID")}, file.OptMkdirAll))
|
||||
}
|
||||
|
||||
setupManager := New("test", tc.fs, tc.keyWaiter, tc.mapper, tc.mounter, tc.openTPM)
|
||||
|
||||
err := setupManager.PrepareExistingDisk()
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
|
||||
assert.True(tc.mapper.mapDiskCalled)
|
||||
assert.True(tc.mounter.mountCalled)
|
||||
assert.True(tc.mounter.unmountCalled)
|
||||
assert.False(tc.mapper.formatDiskCalled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func failOpener() (io.ReadWriteCloser, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
|
||||
func TestPrepareNewDisk(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
testCases := map[string]struct {
|
||||
fs afero.Afero
|
||||
mapper *stubMapper
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
mapper: &stubMapper{uuid: "test"},
|
||||
},
|
||||
"creating directory fails": {
|
||||
fs: afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())},
|
||||
mapper: &stubMapper{},
|
||||
errExpected: true,
|
||||
},
|
||||
"FormatDisk fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
mapper: &stubMapper{
|
||||
uuid: "test",
|
||||
formatDiskErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"MapDisk fails": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
mapper: &stubMapper{
|
||||
uuid: "test",
|
||||
mapDiskErr: someErr,
|
||||
mapDiskRepeatedCalls: 1,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
setupManager := New("test", tc.fs, nil, tc.mapper, nil, nil)
|
||||
|
||||
err := setupManager.PrepareNewDisk()
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.True(tc.mapper.formatDiskCalled)
|
||||
assert.True(tc.mapper.mapDiskCalled)
|
||||
|
||||
data, err := tc.fs.ReadFile(filepath.Join(keyPath, keyFile))
|
||||
require.NoError(t, err)
|
||||
assert.Len(data, config.RNGLengthDefault)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadInitSecrets(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
fs afero.Afero
|
||||
ownerID string
|
||||
clusterID string
|
||||
writeFile bool
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
ownerID: "ownerID",
|
||||
clusterID: "clusterID",
|
||||
writeFile: true,
|
||||
},
|
||||
"no state file": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing ownerID": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
clusterID: "clusterID",
|
||||
writeFile: true,
|
||||
errExpected: true,
|
||||
},
|
||||
"missing clusterID": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
ownerID: "ownerID",
|
||||
writeFile: true,
|
||||
errExpected: true,
|
||||
},
|
||||
"no IDs": {
|
||||
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
writeFile: true,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
if tc.writeFile {
|
||||
handler := file.NewHandler(tc.fs)
|
||||
state := nodestate.NodeState{ClusterID: []byte(tc.clusterID), OwnerID: []byte(tc.ownerID)}
|
||||
require.NoError(handler.WriteJSON("/tmp/test-state.json", state, file.OptMkdirAll))
|
||||
}
|
||||
|
||||
setupManager := New("test", tc.fs, nil, nil, nil, nil)
|
||||
|
||||
ownerID, clusterID, err := setupManager.readInitSecrets("/tmp/test-state.json")
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal([]byte(tc.ownerID), ownerID)
|
||||
assert.Equal([]byte(tc.clusterID), clusterID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubMapper struct {
|
||||
formatDiskCalled bool
|
||||
formatDiskErr error
|
||||
mapDiskRepeatedCalls int
|
||||
mapDiskCalled bool
|
||||
mapDiskErr error
|
||||
uuid string
|
||||
}
|
||||
|
||||
func (s *stubMapper) DiskUUID() string {
|
||||
return s.uuid
|
||||
}
|
||||
|
||||
func (s *stubMapper) FormatDisk(passphrase string) error {
|
||||
s.formatDiskCalled = true
|
||||
return s.formatDiskErr
|
||||
}
|
||||
|
||||
func (s *stubMapper) MapDisk(target string, passphrase string) error {
|
||||
if s.mapDiskRepeatedCalls == 0 {
|
||||
s.mapDiskErr = nil
|
||||
}
|
||||
s.mapDiskRepeatedCalls--
|
||||
s.mapDiskCalled = true
|
||||
return s.mapDiskErr
|
||||
}
|
||||
|
||||
type stubMounter struct {
|
||||
mountCalled bool
|
||||
mountErr error
|
||||
unmountCalled bool
|
||||
unmountErr error
|
||||
mkdirAllErr error
|
||||
}
|
||||
|
||||
func (s *stubMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error {
|
||||
s.mountCalled = true
|
||||
return s.mountErr
|
||||
}
|
||||
|
||||
func (s *stubMounter) Unmount(target string, flags int) error {
|
||||
s.unmountCalled = true
|
||||
return s.unmountErr
|
||||
}
|
||||
|
||||
func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
|
||||
return s.mkdirAllErr
|
||||
}
|
||||
|
||||
type stubKeyWaiter struct {
|
||||
receivedUUID string
|
||||
decryptionKey []byte
|
||||
waitErr error
|
||||
waitCalled bool
|
||||
}
|
||||
|
||||
func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, error) {
|
||||
if s.waitCalled {
|
||||
return nil, errors.New("wait called before key was reset")
|
||||
}
|
||||
s.waitCalled = true
|
||||
s.receivedUUID = uuid
|
||||
return s.decryptionKey, s.waitErr
|
||||
}
|
||||
|
||||
func (s *stubKeyWaiter) ResetKey() {
|
||||
s.waitCalled = false
|
||||
}
|
@ -3,15 +3,24 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/state/keyservice"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
"github.com/edgelesssys/constellation/state/mapper"
|
||||
"github.com/martinjungblut/go-cryptsetup"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -62,3 +71,42 @@ func TestMapper(t *testing.T) {
|
||||
// Try to map disk with incorrect passphrase
|
||||
assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase")
|
||||
}
|
||||
|
||||
func TestKeyAPI(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||
|
||||
// get a free port on localhost to run the test on
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(err)
|
||||
apiAddr := listener.Addr().String()
|
||||
listener.Close()
|
||||
|
||||
api := keyservice.New(core.NewMockIssuer(), &core.ProviderMetadataFake{}, 20*time.Second)
|
||||
|
||||
// send a key to the server
|
||||
go func() {
|
||||
// wait 2 seconds before sending the key
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
clientCfg, err := atls.CreateUnverifiedClientTLSConfig()
|
||||
require.NoError(err)
|
||||
conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(credentials.NewTLS(clientCfg)))
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
|
||||
client := keyproto.NewAPIClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
_, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{
|
||||
StateDiskKey: testKey,
|
||||
})
|
||||
require.NoError(err)
|
||||
}()
|
||||
|
||||
key, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr)
|
||||
assert.NoError(err)
|
||||
assert.Equal(testKey, key)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user