diff --git a/proto/Dockerfile.gen-proto b/proto/Dockerfile.gen-proto index e70f5ae9f..b1f68cbc7 100644 --- a/proto/Dockerfile.gen-proto +++ b/proto/Dockerfile.gen-proto @@ -36,7 +36,7 @@ RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_o ## disk-mapper keyservice api WORKDIR /disk-mapper -COPY state/keyservice/proto/*.proto /disk-mapper +COPY state/keyservice/keyproto/*.proto /disk-mapper RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative *.proto ## debugd service @@ -48,5 +48,5 @@ RUN protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_o FROM scratch as export COPY --from=build /pubapi/*.go coordinator/pubapi/pubproto/ COPY --from=build /vpnapi/*.go coordinator/vpnapi/vpnproto/ -COPY --from=build /disk-mapper/*.go state/keyservice/proto/ +COPY --from=build /disk-mapper/*.go state/keyservice/keyproto/ COPY --from=build /service/*.go debugd/service/ diff --git a/state/cmd/main.go b/state/cmd/main.go index 3ee4b995c..4eff002eb 100644 --- a/state/cmd/main.go +++ b/state/cmd/main.go @@ -1,68 +1,103 @@ package main import ( - "crypto/rand" + "context" "flag" + "fmt" "os" "path/filepath" + "strings" + "time" - "github.com/edgelesssys/constellation/coordinator/config" + "github.com/edgelesssys/constellation/coordinator/attestation/azure" + "github.com/edgelesssys/constellation/coordinator/attestation/gcp" + "github.com/edgelesssys/constellation/coordinator/attestation/vtpm" + azurecloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/azure" + gcpcloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/gcp" + "github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/internal/utils" "github.com/edgelesssys/constellation/state/keyservice" "github.com/edgelesssys/constellation/state/mapper" + "github.com/edgelesssys/constellation/state/setup" + "github.com/spf13/afero" ) const ( - keyPath = "/run/cryptsetup-keys.d" - keyFile = "state.key" + gcpStateDiskPath = "/dev/disk/by-id/google-state-disk" + azureStateDiskPath = "/dev/disk/azure/scsi1/lun0" + fallBackPath = "/dev/disk/by-id/state-disk" ) var csp = flag.String("csp", "", "Cloud Service Provider the image is running on") func main() { flag.Parse() - diskPath, err := mapper.GetDiskPath(*csp) - if err != nil { - utils.KernelPanic(err) + + // set up metadata API and quote issuer for aTLS connections + var err error + var diskPathErr error + var diskPath string + var issuer core.QuoteIssuer + var metadata core.ProviderMetadata + switch strings.ToLower(*csp) { + case "azure": + diskPath, diskPathErr = filepath.EvalSymlinks(azureStateDiskPath) + metadata, err = azurecloud.NewMetadata(context.Background()) + if err != nil { + utils.KernelPanic(err) + } + issuer = azure.NewIssuer() + + case "gcp": + diskPath, diskPathErr = filepath.EvalSymlinks(gcpStateDiskPath) + issuer = gcp.NewIssuer() + gcpClient, err := gcpcloud.NewClient(context.Background()) + if err != nil { + utils.KernelPanic(err) + } + metadata = gcpcloud.New(gcpClient) + + default: + diskPath, err = filepath.EvalSymlinks(fallBackPath) + if err != nil { + utils.KernelPanic(err) + } + issuer = core.NewMockIssuer() + fmt.Fprintf(os.Stderr, "warning: csp %q is not supported, unable to automatically request decryption keys on reboot\n", *csp) + metadata = &core.ProviderMetadataFake{} + } + if diskPathErr != nil { + fmt.Fprintf(os.Stderr, "warning: no attached disk detected, trying to use boot-disk state partition as fallback") + diskPath, err = filepath.EvalSymlinks(fallBackPath) + if err != nil { + utils.KernelPanic(err) + } } + // initialize device mapper mapper, err := mapper.New(diskPath) if err != nil { utils.KernelPanic(err) } defer mapper.Close() + setupManger := setup.New( + *csp, + afero.Afero{Fs: afero.NewOsFs()}, + keyservice.New(issuer, metadata, 20*time.Second), // try to request a key every 20 seconds + mapper, + setup.DiskMounter{}, + vtpm.OpenVTPM, + ) + + // prepare the state disk if mapper.IsLUKSDevice() { - uuid := mapper.DiskUUID() - _, err = keyservice.WaitForDecryptionKey(*csp, uuid) + err = setupManger.PrepareExistingDisk() } else { - err = formatDisk(mapper) + err = setupManger.PrepareNewDisk() } if err != nil { utils.KernelPanic(err) } } - -func formatDisk(mapper *mapper.Mapper) error { - // generate and save temporary passphrase - if err := os.MkdirAll(keyPath, os.ModePerm); err != nil { - utils.KernelPanic(err) - } - passphrase := make([]byte, config.RNGLengthDefault) - if _, err := rand.Read(passphrase); err != nil { - utils.KernelPanic(err) - } - if err := os.WriteFile(filepath.Join(keyPath, keyFile), passphrase, 0o400); err != nil { - utils.KernelPanic(err) - } - - if err := mapper.FormatDisk(string(passphrase)); err != nil { - utils.KernelPanic(err) - } - if err := mapper.MapDisk("state", string(passphrase)); err != nil { - utils.KernelPanic(err) - } - - return nil -} diff --git a/state/keyservice/keyproto/keyservice.pb.go b/state/keyservice/keyproto/keyservice.pb.go new file mode 100644 index 000000000..20e7f1f0b --- /dev/null +++ b/state/keyservice/keyproto/keyservice.pb.go @@ -0,0 +1,208 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.17.3 +// source: keyservice.proto + +package keyproto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type PushStateDiskKeyRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"` +} + +func (x *PushStateDiskKeyRequest) Reset() { + *x = PushStateDiskKeyRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_keyservice_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PushStateDiskKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PushStateDiskKeyRequest) ProtoMessage() {} + +func (x *PushStateDiskKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_keyservice_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PushStateDiskKeyRequest.ProtoReflect.Descriptor instead. +func (*PushStateDiskKeyRequest) Descriptor() ([]byte, []int) { + return file_keyservice_proto_rawDescGZIP(), []int{0} +} + +func (x *PushStateDiskKeyRequest) GetStateDiskKey() []byte { + if x != nil { + return x.StateDiskKey + } + return nil +} + +type PushStateDiskKeyResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *PushStateDiskKeyResponse) Reset() { + *x = PushStateDiskKeyResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_keyservice_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PushStateDiskKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PushStateDiskKeyResponse) ProtoMessage() {} + +func (x *PushStateDiskKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_keyservice_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PushStateDiskKeyResponse.ProtoReflect.Descriptor instead. +func (*PushStateDiskKeyResponse) Descriptor() ([]byte, []int) { + return file_keyservice_proto_rawDescGZIP(), []int{1} +} + +var File_keyservice_proto protoreflect.FileDescriptor + +var file_keyservice_proto_rawDesc = []byte{ + 0x0a, 0x10, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x12, 0x08, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x3f, 0x0a, 0x17, + 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, + 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x22, 0x1a, 0x0a, + 0x18, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, + 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x60, 0x0a, 0x03, 0x41, 0x50, 0x49, + 0x12, 0x59, 0x0a, 0x10, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, + 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x61, 0x74, 0x65, 0x44, 0x69, 0x73, 0x6b, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x40, 0x5a, 0x3e, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, + 0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x2f, 0x6b, 0x65, 0x79, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_keyservice_proto_rawDescOnce sync.Once + file_keyservice_proto_rawDescData = file_keyservice_proto_rawDesc +) + +func file_keyservice_proto_rawDescGZIP() []byte { + file_keyservice_proto_rawDescOnce.Do(func() { + file_keyservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_keyservice_proto_rawDescData) + }) + return file_keyservice_proto_rawDescData +} + +var file_keyservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_keyservice_proto_goTypes = []interface{}{ + (*PushStateDiskKeyRequest)(nil), // 0: keyproto.PushStateDiskKeyRequest + (*PushStateDiskKeyResponse)(nil), // 1: keyproto.PushStateDiskKeyResponse +} +var file_keyservice_proto_depIdxs = []int32{ + 0, // 0: keyproto.API.PushStateDiskKey:input_type -> keyproto.PushStateDiskKeyRequest + 1, // 1: keyproto.API.PushStateDiskKey:output_type -> keyproto.PushStateDiskKeyResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_keyservice_proto_init() } +func file_keyservice_proto_init() { + if File_keyservice_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_keyservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PushStateDiskKeyRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_keyservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PushStateDiskKeyResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_keyservice_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_keyservice_proto_goTypes, + DependencyIndexes: file_keyservice_proto_depIdxs, + MessageInfos: file_keyservice_proto_msgTypes, + }.Build() + File_keyservice_proto = out.File + file_keyservice_proto_rawDesc = nil + file_keyservice_proto_goTypes = nil + file_keyservice_proto_depIdxs = nil +} diff --git a/state/keyservice/keyproto/keyservice.proto b/state/keyservice/keyproto/keyservice.proto new file mode 100644 index 000000000..45befba71 --- /dev/null +++ b/state/keyservice/keyproto/keyservice.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package keyproto; + +option go_package = "github.com/edgelesssys/constellation/state/keyservice/keyproto"; + +service API { + rpc PushStateDiskKey(PushStateDiskKeyRequest) returns (PushStateDiskKeyResponse); +} + +message PushStateDiskKeyRequest { + bytes state_disk_key = 1; +} + +message PushStateDiskKeyResponse { +} diff --git a/state/keyservice/keyproto/keyservice_grpc.pb.go b/state/keyservice/keyproto/keyservice_grpc.pb.go new file mode 100644 index 000000000..242f4e8ab --- /dev/null +++ b/state/keyservice/keyproto/keyservice_grpc.pb.go @@ -0,0 +1,101 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package keyproto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// APIClient is the client API for API service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type APIClient interface { + PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) +} + +type aPIClient struct { + cc grpc.ClientConnInterface +} + +func NewAPIClient(cc grpc.ClientConnInterface) APIClient { + return &aPIClient{cc} +} + +func (c *aPIClient) PushStateDiskKey(ctx context.Context, in *PushStateDiskKeyRequest, opts ...grpc.CallOption) (*PushStateDiskKeyResponse, error) { + out := new(PushStateDiskKeyResponse) + err := c.cc.Invoke(ctx, "/keyproto.API/PushStateDiskKey", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// APIServer is the server API for API service. +// All implementations must embed UnimplementedAPIServer +// for forward compatibility +type APIServer interface { + PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) + mustEmbedUnimplementedAPIServer() +} + +// UnimplementedAPIServer must be embedded to have forward compatible implementations. +type UnimplementedAPIServer struct { +} + +func (UnimplementedAPIServer) PushStateDiskKey(context.Context, *PushStateDiskKeyRequest) (*PushStateDiskKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method PushStateDiskKey not implemented") +} +func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {} + +// UnsafeAPIServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to APIServer will +// result in compilation errors. +type UnsafeAPIServer interface { + mustEmbedUnimplementedAPIServer() +} + +func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) { + s.RegisterService(&API_ServiceDesc, srv) +} + +func _API_PushStateDiskKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PushStateDiskKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(APIServer).PushStateDiskKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/keyproto.API/PushStateDiskKey", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(APIServer).PushStateDiskKey(ctx, req.(*PushStateDiskKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// API_ServiceDesc is the grpc.ServiceDesc for API service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var API_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "keyproto.API", + HandlerType: (*APIServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "PushStateDiskKey", + Handler: _API_PushStateDiskKey_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "keyservice.proto", +} diff --git a/state/keyservice/keyservice.go b/state/keyservice/keyservice.go index 8241970f9..b095666f9 100644 --- a/state/keyservice/keyservice.go +++ b/state/keyservice/keyservice.go @@ -2,41 +2,97 @@ package keyservice import ( "context" + "crypto/tls" "errors" - "fmt" - "os" - "strings" + "log" + "net" "sync" "time" "github.com/edgelesssys/constellation/coordinator/atls" - azurecloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/azure" - gcpcloud "github.com/edgelesssys/constellation/coordinator/cloudprovider/gcp" + "github.com/edgelesssys/constellation/coordinator/config" "github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" - "github.com/edgelesssys/constellation/coordinator/role" + "github.com/edgelesssys/constellation/state/keyservice/keyproto" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/status" ) -// keyAPI is the interface called by the Coordinator or an admin during restart of a node. -type keyAPI struct { - metadata core.ProviderMetadata +// KeyAPI is the interface called by the Coordinator or an admin during restart of a node. +type KeyAPI struct { mux sync.Mutex + metadata core.ProviderMetadata + issuer core.QuoteIssuer key []byte - keyReceived chan bool + keyReceived chan struct{} timeout time.Duration + keyproto.UnimplementedAPIServer } -func (a *keyAPI) waitForDecryptionKey() { - // go server.Start() - // block until a key is pushed - if <-a.keyReceived { - return +// New initializes a KeyAPI with the given parameters. +func New(issuer core.QuoteIssuer, metadata core.ProviderMetadata, timeout time.Duration) *KeyAPI { + return &KeyAPI{ + metadata: metadata, + issuer: issuer, + keyReceived: make(chan struct{}, 1), + timeout: timeout, } } -func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption) error { +// PushStateDiskKeyRequest is the rpc to push state disk decryption keys to a restarting node. +func (a *KeyAPI) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) { + a.mux.Lock() + defer a.mux.Unlock() + if len(a.key) != 0 { + return nil, status.Error(codes.FailedPrecondition, "node already received a passphrase") + } + if len(in.StateDiskKey) != config.RNGLengthDefault { + return nil, status.Errorf(codes.InvalidArgument, "received invalid passphrase: expected length: %d, but got: %d", config.RNGLengthDefault, len(in.StateDiskKey)) + } + + a.key = in.StateDiskKey + a.keyReceived <- struct{}{} + return &keyproto.PushStateDiskKeyResponse{}, nil +} + +// WaitForDecryptionKey notifies the Coordinator to send a decryption key and waits until a key is received. +func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) { + if uuid == "" { + return nil, errors.New("received no disk UUID") + } + + tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer) + if err != nil { + return nil, err + } + server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + keyproto.RegisterAPIServer(server, a) + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, err + } + defer listener.Close() + + log.Printf("Waiting for decryption key. Listening on: %s", listener.Addr().String()) + go server.Serve(listener) + defer server.GracefulStop() + + if err := a.requestKeyLoop(uuid); err != nil { + return nil, err + } + + return a.key, nil +} + +// ResetKey resets a previously set key. +func (a *KeyAPI) ResetKey() { + a.key = nil +} + +// requestKeyLoop continuously requests decryption keys from all available Coordinators, until the KeyAPI receives a key. +func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error { // we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator // if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start tlsClientConfig, err := atls.CreateUnverifiedClientTLSConfig() @@ -44,96 +100,44 @@ func (a *keyAPI) requestKeyFromCoordinator(uuid string, opts ...grpc.DialOption) return err } + // set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker + firstReq := make(chan struct{}, 1) + firstReq <- struct{}{} + + ticker := time.NewTicker(a.timeout) + defer ticker.Stop() for { select { - // return if a key was received by any means + // return if a key was received // a key can be send by // - a Coordinator, after the request rpc was received // - by a Constellation admin, at any time this loop is running on a node during boot case <-a.keyReceived: return nil - default: - // list available Coordinators - endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata) - // notify the all available Coordinators to send a key to the node - // any errors encountered here will be ignored, and the calls retried after a timeout - for _, endpoint := range endpoints { - ctx, cancel := context.WithTimeout(context.Background(), a.timeout) - conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig)))...) - if err == nil { - client := pubproto.NewAPIClient(conn) - _, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid}) - conn.Close() - } - - cancel() - } - time.Sleep(a.timeout) + case <-ticker.C: + a.requestKey(uuid, tlsClientConfig, opts...) + case <-firstReq: + a.requestKey(uuid, tlsClientConfig, opts...) } } } -// WaitForDecryptionKey notifies the Coordinator to send a decryption key and waits until a key is received. -func WaitForDecryptionKey(csp, uuid string) ([]byte, error) { - if uuid == "" { - return nil, errors.New("received no disk UUID") - } +func (a *KeyAPI) requestKey(uuid string, tlsClientConfig *tls.Config, opts ...grpc.DialOption) { + // list available Coordinators + endpoints, _ := core.CoordinatorEndpoints(context.Background(), a.metadata) - keyWaiter := &keyAPI{ - keyReceived: make(chan bool, 1), - timeout: 20 * time.Second, // try to request a key every 20 seconds - } - go keyWaiter.waitForDecryptionKey() - - switch strings.ToLower(csp) { - case "azure": - metadata, err := azurecloud.NewMetadata(context.Background()) - if err != nil { - return nil, err + log.Printf("Sending a key request to available Coordinators: %v", endpoints) + // notify all available Coordinators to send a key to the node + // any errors encountered here will be ignored, and the calls retried after a timeout + for _, endpoint := range endpoints { + ctx, cancel := context.WithTimeout(context.Background(), a.timeout) + conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConfig)))...) + if err == nil { + client := pubproto.NewAPIClient(conn) + _, _ = client.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{DiskUuid: uuid}) + conn.Close() } - keyWaiter.metadata = metadata - case "gcp": - gcpClient, err := gcpcloud.NewClient(context.Background()) - if err != nil { - return nil, err - } - keyWaiter.metadata = gcpcloud.New(gcpClient) - default: - fmt.Fprintf(os.Stderr, "warning: csp %q is not supported, unable to automatically request decryption keys\n", csp) - keyWaiter.metadata = stubMetadata{} + + cancel() } - - if err := keyWaiter.requestKeyFromCoordinator(uuid); err != nil { - return nil, err - } - - return keyWaiter.key, nil -} - -type stubMetadata struct { - listResponse []core.Instance -} - -func (s stubMetadata) List(ctx context.Context) ([]core.Instance, error) { - return s.listResponse, nil -} - -func (s stubMetadata) Self(ctx context.Context) (core.Instance, error) { - return core.Instance{}, nil -} - -func (s stubMetadata) GetInstance(ctx context.Context, providerID string) (core.Instance, error) { - return core.Instance{}, nil -} - -func (s stubMetadata) SignalRole(ctx context.Context, role role.Role) error { - return nil -} - -func (s stubMetadata) SetVPNIP(ctx context.Context, vpnIP string) error { - return nil -} - -func (s stubMetadata) Supported() bool { - return true } diff --git a/state/keyservice/keyservice_test.go b/state/keyservice/keyservice_test.go index 6b901262b..1b3035778 100644 --- a/state/keyservice/keyservice_test.go +++ b/state/keyservice/keyservice_test.go @@ -2,22 +2,16 @@ package keyservice import ( "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" "errors" "net" "testing" "time" + "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/core" - "github.com/edgelesssys/constellation/coordinator/oid" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" - "github.com/edgelesssys/constellation/coordinator/util" + "github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -25,7 +19,7 @@ import ( "google.golang.org/grpc/test/bufconn" ) -func TestRequestLoop(t *testing.T) { +func TestRequestKeyLoop(t *testing.T) { defaultInstance := core.Instance{ Name: "test-instance", ProviderID: "/test/provider", @@ -77,11 +71,11 @@ func TestRequestLoop(t *testing.T) { assert := assert.New(t) require := require.New(t) - keyReceived := make(chan bool, 1) + keyReceived := make(chan struct{}, 1) listener := bufconn.Listen(1) defer listener.Close() - tlsConfig, err := stubTLSConfig() + tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer()) require.NoError(err) s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubproto.RegisterAPIServer(s, tc.server) @@ -90,7 +84,7 @@ func TestRequestLoop(t *testing.T) { go func() { require.NoError(s.Serve(listener)) }() } - keyWaiter := &keyAPI{ + keyWaiter := &KeyAPI{ metadata: stubMetadata{listResponse: tc.listResponse}, keyReceived: keyReceived, timeout: 500 * time.Millisecond, @@ -99,10 +93,10 @@ func TestRequestLoop(t *testing.T) { // notify the API a key was received after 1 second go func() { time.Sleep(1 * time.Second) - keyReceived <- true + keyReceived <- struct{}{} }() - err = keyWaiter.requestKeyFromCoordinator( + err = keyWaiter.requestKeyLoop( "1234", grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { return listener.DialContext(ctx) @@ -115,6 +109,54 @@ func TestRequestLoop(t *testing.T) { } } +func TestPushStateDiskKey(t *testing.T) { + testCases := map[string]struct { + testAPI *KeyAPI + request *keyproto.PushStateDiskKeyRequest + errExpected bool + }{ + "success": { + testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)}, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}, + }, + "key already set": { + testAPI: &KeyAPI{ + keyReceived: make(chan struct{}, 1), + key: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + }, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")}, + errExpected: true, + }, + "incorrect size of pushed key": { + testAPI: &KeyAPI{keyReceived: make(chan struct{}, 1)}, + request: &keyproto.PushStateDiskKeyRequest{StateDiskKey: []byte("AAAAAAAAAAAAAAAA")}, + errExpected: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + _, err := tc.testAPI.PushStateDiskKey(context.Background(), tc.request) + if tc.errExpected { + assert.Error(err) + } else { + assert.NoError(err) + assert.Equal(tc.request.StateDiskKey, tc.testAPI.key) + } + }) + } +} + +func TestResetKey(t *testing.T) { + api := New(nil, nil, time.Second) + + api.key = []byte{0x1, 0x2, 0x3} + api.ResetKey() + assert.Nil(t, api.key) +} + type stubAPIServer struct { requestStateDiskKeyResp *pubproto.RequestStateDiskKeyResponse requestStateDiskKeyErr error @@ -149,30 +191,30 @@ func (s *stubAPIServer) RequestStateDiskKey(ctx context.Context, in *pubproto.Re return s.requestStateDiskKeyResp, s.requestStateDiskKeyErr } -func stubTLSConfig() (*tls.Config, error) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, err - } - getCertificate := func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - serialNumber, err := util.GenerateCertificateSerialNumber() - if err != nil { - return nil, err - } - now := time.Now() - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{CommonName: "Constellation"}, - NotBefore: now.Add(-2 * time.Hour), - NotAfter: now.Add(2 * time.Hour), - ExtraExtensions: []pkix.Extension{{Id: oid.Dummy{}.OID(), Value: []byte{0x1, 0x2, 0x3}}}, - } - cert, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) - if err != nil { - return nil, err - } - - return &tls.Certificate{Certificate: [][]byte{cert}, PrivateKey: priv}, nil - } - return &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}, nil +type stubMetadata struct { + listResponse []core.Instance +} + +func (s stubMetadata) List(ctx context.Context) ([]core.Instance, error) { + return s.listResponse, nil +} + +func (s stubMetadata) Self(ctx context.Context) (core.Instance, error) { + return core.Instance{}, nil +} + +func (s stubMetadata) GetInstance(ctx context.Context, providerID string) (core.Instance, error) { + return core.Instance{}, nil +} + +func (s stubMetadata) SignalRole(ctx context.Context, role role.Role) error { + return nil +} + +func (s stubMetadata) SetVPNIP(ctx context.Context, vpnIP string) error { + return nil +} + +func (s stubMetadata) Supported() bool { + return true } diff --git a/state/mapper/mapper.go b/state/mapper/mapper.go index cb84fb5ef..eaf45c49c 100644 --- a/state/mapper/mapper.go +++ b/state/mapper/mapper.go @@ -3,17 +3,10 @@ package mapper import ( "errors" "fmt" - "path/filepath" cryptsetup "github.com/martinjungblut/go-cryptsetup" ) -const ( - gcpStateDiskPath = "/dev/disk/by-id/google-state-disk" - azureStateDiskPath = "/dev/disk/azure/scsi1/lun0" - fallBackPath = "/dev/disk/by-id/state-disk" -) - // Mapper handles actions for formating and mapping crypt devices. type Mapper struct { device cryptDevice @@ -89,28 +82,3 @@ func (m *Mapper) MapDisk(target, passphrase string) error { func (m *Mapper) UnmapDisk(target string) error { return m.device.Deactivate(target) } - -// GetDiskPath returns the device path of the data disk by cloud provider. -// -// For GCP a symlink to the disk is expected at /dev/disk/by-id/google-state-disk -// For Azure a symlink to the disk is expected at /dev/disk/azure/scsi1/lun0 -// If no symlink can be found at the given path, or if no known cloud provider is supplied, -// we instead return the device path of the os-disk stateful partition at /dev/disk/by-partlabel/stateful. -func GetDiskPath(csp string) (string, error) { - var diskPath string - var err error - - switch csp { - case "gcp": - diskPath, err = filepath.EvalSymlinks(gcpStateDiskPath) - case "azure": - diskPath, err = filepath.EvalSymlinks(azureStateDiskPath) - default: - diskPath = fallBackPath - } - - if err != nil { - return filepath.EvalSymlinks(fallBackPath) - } - return diskPath, nil -} diff --git a/state/setup/interface.go b/state/setup/interface.go new file mode 100644 index 000000000..d88838dd0 --- /dev/null +++ b/state/setup/interface.go @@ -0,0 +1,45 @@ +package setup + +import ( + "io/fs" + "os" + "syscall" +) + +// Mounter is an interface for mount and unmount operations. +type Mounter interface { + Mount(source string, target string, fstype string, flags uintptr, data string) error + Unmount(target string, flags int) error + MkdirAll(path string, perm fs.FileMode) error +} + +// DeviceMapper is an interface for device mapping operations. +type DeviceMapper interface { + DiskUUID() string + FormatDisk(passphrase string) error + MapDisk(target string, passphrase string) error +} + +// KeyWaiter is an interface to request and wait for disk decryption keys. +type KeyWaiter interface { + WaitForDecryptionKey(uuid, addr string) ([]byte, error) + ResetKey() +} + +// DiskMounter uses the syscall package to mount disks. +type DiskMounter struct{} + +// Mount performs a mount syscall. +func (m DiskMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error { + return syscall.Mount(source, target, fstype, flags, data) +} + +// Unmount performs an unmount syscall. +func (m DiskMounter) Unmount(target string, flags int) error { + return syscall.Unmount(target, flags) +} + +// MkdirAll uses os.MkdirAll to create the directory. +func (m DiskMounter) MkdirAll(path string, perm fs.FileMode) error { + return os.MkdirAll(path, perm) +} diff --git a/state/setup/setup.go b/state/setup/setup.go new file mode 100644 index 000000000..92c574d2e --- /dev/null +++ b/state/setup/setup.go @@ -0,0 +1,125 @@ +package setup + +import ( + "crypto/rand" + "errors" + "log" + "net" + "os" + "path/filepath" + "syscall" + + "github.com/edgelesssys/constellation/cli/file" + "github.com/edgelesssys/constellation/coordinator/attestation/vtpm" + "github.com/edgelesssys/constellation/coordinator/config" + "github.com/edgelesssys/constellation/coordinator/nodestate" + "github.com/spf13/afero" +) + +const ( + RecoveryPort = "9000" + keyPath = "/run/cryptsetup-keys.d" + keyFile = "state.key" + stateDiskMappedName = "state" + stateDiskMountPath = "/var/run/state" + stateInfoPath = stateDiskMountPath + "/constellation/node_state.json" +) + +// SetupManager handles formating, mapping, mounting and unmounting of state disks. +type SetupManager struct { + csp string + fs afero.Afero + keyWaiter KeyWaiter + mapper DeviceMapper + mounter Mounter + openTPM vtpm.TPMOpenFunc +} + +// New initializes a SetupManager with the given parameters. +func New(csp string, fs afero.Afero, keyWaiter KeyWaiter, mapper DeviceMapper, mounter Mounter, openTPM vtpm.TPMOpenFunc) *SetupManager { + return &SetupManager{ + csp: csp, + fs: fs, + keyWaiter: keyWaiter, + mapper: mapper, + mounter: mounter, + openTPM: openTPM, + } +} + +// PrepareExistingDisk requests and waits for a decryption key to remap the encrypted state disk. +// Once the disk is mapped, the function taints the node as initialized by updating it's PCRs. +func (s *SetupManager) PrepareExistingDisk() error { + log.Println("Preparing existing state disk") + uuid := s.mapper.DiskUUID() + +getKey: + passphrase, err := s.keyWaiter.WaitForDecryptionKey(uuid, net.JoinHostPort("0.0.0.0", RecoveryPort)) + if err != nil { + return err + } + + if err := s.mapper.MapDisk(stateDiskMappedName, string(passphrase)); err != nil { + // retry key fetching if disk mapping fails + s.keyWaiter.ResetKey() + goto getKey + } + + if err := s.mounter.MkdirAll(stateDiskMountPath, os.ModePerm); err != nil { + return err + } + // we do not care about cleaning up the mount point on error, since any errors returned here should result in a kernel panic in the main function + if err := s.mounter.Mount(filepath.Join("/dev/mapper/", stateDiskMappedName), stateDiskMountPath, "ext4", syscall.MS_RDONLY, ""); err != nil { + return err + } + + ownerID, clusterID, err := s.readInitSecrets(stateInfoPath) + if err != nil { + return err + } + + // taint the node as initialized + if err := vtpm.MarkNodeAsInitialized(s.openTPM, ownerID, clusterID); err != nil { + return err + } + + return s.mounter.Unmount(stateDiskMountPath, 0) +} + +// PrepareNewDisk prepares an instances state disk by formatting the disk as a LUKS device using a random passphrase. +func (s *SetupManager) PrepareNewDisk() error { + log.Println("Preparing new state disk") + + // generate and save temporary passphrase + if err := s.fs.MkdirAll(keyPath, os.ModePerm); err != nil { + return err + } + + passphrase := make([]byte, config.RNGLengthDefault) + if _, err := rand.Read(passphrase); err != nil { + return err + } + if err := s.fs.WriteFile(filepath.Join(keyPath, keyFile), passphrase, 0o400); err != nil { + return err + } + + if err := s.mapper.FormatDisk(string(passphrase)); err != nil { + return err + } + + return s.mapper.MapDisk(stateDiskMappedName, string(passphrase)) +} + +func (s *SetupManager) readInitSecrets(path string) ([]byte, []byte, error) { + handler := file.NewHandler(s.fs) + var state nodestate.NodeState + if err := handler.ReadJSON(path, &state); err != nil { + return nil, nil, err + } + + if len(state.ClusterID) == 0 || len(state.OwnerID) == 0 { + return nil, nil, errors.New("missing state information to retaint node") + } + + return state.OwnerID, state.ClusterID, nil +} diff --git a/state/setup/setup_test.go b/state/setup/setup_test.go new file mode 100644 index 000000000..a342c7d8a --- /dev/null +++ b/state/setup/setup_test.go @@ -0,0 +1,317 @@ +package setup + +import ( + "errors" + "io" + "io/fs" + "path/filepath" + "testing" + + "github.com/edgelesssys/constellation/cli/file" + "github.com/edgelesssys/constellation/coordinator/attestation/vtpm" + "github.com/edgelesssys/constellation/coordinator/config" + "github.com/edgelesssys/constellation/coordinator/nodestate" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrepareExistingDisk(t *testing.T) { + someErr := errors.New("error") + + testCases := map[string]struct { + fs afero.Afero + keyWaiter *stubKeyWaiter + mapper *stubMapper + mounter *stubMounter + openTPM vtpm.TPMOpenFunc + missingState bool + errExpected bool + }{ + "success": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{}, + openTPM: vtpm.OpenNOPTPM, + }, + "WaitForDecryptionKey fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{waitErr: someErr}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{}, + openTPM: vtpm.OpenNOPTPM, + errExpected: true, + }, + "MapDisk fails causes a repeat": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{ + uuid: "test", + mapDiskErr: someErr, + mapDiskRepeatedCalls: 2, + }, + mounter: &stubMounter{}, + openTPM: vtpm.OpenNOPTPM, + errExpected: false, + }, + "MkdirAll fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{mkdirAllErr: someErr}, + openTPM: vtpm.OpenNOPTPM, + errExpected: true, + }, + "Mount fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{mountErr: someErr}, + openTPM: vtpm.OpenNOPTPM, + errExpected: true, + }, + "Unmount fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{unmountErr: someErr}, + openTPM: vtpm.OpenNOPTPM, + errExpected: true, + }, + "MarkNodeAsInitialized fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{unmountErr: someErr}, + openTPM: failOpener, + errExpected: true, + }, + "no state file": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + keyWaiter: &stubKeyWaiter{}, + mapper: &stubMapper{uuid: "test"}, + mounter: &stubMounter{}, + openTPM: vtpm.OpenNOPTPM, + missingState: true, + errExpected: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + if !tc.missingState { + handler := file.NewHandler(tc.fs) + require.NoError(t, handler.WriteJSON(stateInfoPath, nodestate.NodeState{OwnerID: []byte("ownerID"), ClusterID: []byte("clusterID")}, file.OptMkdirAll)) + } + + setupManager := New("test", tc.fs, tc.keyWaiter, tc.mapper, tc.mounter, tc.openTPM) + + err := setupManager.PrepareExistingDisk() + if tc.errExpected { + assert.Error(err) + } else { + assert.NoError(err) + assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID) + assert.True(tc.mapper.mapDiskCalled) + assert.True(tc.mounter.mountCalled) + assert.True(tc.mounter.unmountCalled) + assert.False(tc.mapper.formatDiskCalled) + } + }) + } +} + +func failOpener() (io.ReadWriteCloser, error) { + return nil, errors.New("error") +} + +func TestPrepareNewDisk(t *testing.T) { + someErr := errors.New("error") + testCases := map[string]struct { + fs afero.Afero + mapper *stubMapper + errExpected bool + }{ + "success": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + mapper: &stubMapper{uuid: "test"}, + }, + "creating directory fails": { + fs: afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())}, + mapper: &stubMapper{}, + errExpected: true, + }, + "FormatDisk fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + mapper: &stubMapper{ + uuid: "test", + formatDiskErr: someErr, + }, + errExpected: true, + }, + "MapDisk fails": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + mapper: &stubMapper{ + uuid: "test", + mapDiskErr: someErr, + mapDiskRepeatedCalls: 1, + }, + errExpected: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + setupManager := New("test", tc.fs, nil, tc.mapper, nil, nil) + + err := setupManager.PrepareNewDisk() + if tc.errExpected { + assert.Error(err) + } else { + assert.NoError(err) + assert.True(tc.mapper.formatDiskCalled) + assert.True(tc.mapper.mapDiskCalled) + + data, err := tc.fs.ReadFile(filepath.Join(keyPath, keyFile)) + require.NoError(t, err) + assert.Len(data, config.RNGLengthDefault) + } + }) + } +} + +func TestReadInitSecrets(t *testing.T) { + testCases := map[string]struct { + fs afero.Afero + ownerID string + clusterID string + writeFile bool + errExpected bool + }{ + "success": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + ownerID: "ownerID", + clusterID: "clusterID", + writeFile: true, + }, + "no state file": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + errExpected: true, + }, + "missing ownerID": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + clusterID: "clusterID", + writeFile: true, + errExpected: true, + }, + "missing clusterID": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + ownerID: "ownerID", + writeFile: true, + errExpected: true, + }, + "no IDs": { + fs: afero.Afero{Fs: afero.NewMemMapFs()}, + writeFile: true, + errExpected: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + if tc.writeFile { + handler := file.NewHandler(tc.fs) + state := nodestate.NodeState{ClusterID: []byte(tc.clusterID), OwnerID: []byte(tc.ownerID)} + require.NoError(handler.WriteJSON("/tmp/test-state.json", state, file.OptMkdirAll)) + } + + setupManager := New("test", tc.fs, nil, nil, nil, nil) + + ownerID, clusterID, err := setupManager.readInitSecrets("/tmp/test-state.json") + if tc.errExpected { + assert.Error(err) + } else { + assert.NoError(err) + assert.Equal([]byte(tc.ownerID), ownerID) + assert.Equal([]byte(tc.clusterID), clusterID) + } + }) + } +} + +type stubMapper struct { + formatDiskCalled bool + formatDiskErr error + mapDiskRepeatedCalls int + mapDiskCalled bool + mapDiskErr error + uuid string +} + +func (s *stubMapper) DiskUUID() string { + return s.uuid +} + +func (s *stubMapper) FormatDisk(passphrase string) error { + s.formatDiskCalled = true + return s.formatDiskErr +} + +func (s *stubMapper) MapDisk(target string, passphrase string) error { + if s.mapDiskRepeatedCalls == 0 { + s.mapDiskErr = nil + } + s.mapDiskRepeatedCalls-- + s.mapDiskCalled = true + return s.mapDiskErr +} + +type stubMounter struct { + mountCalled bool + mountErr error + unmountCalled bool + unmountErr error + mkdirAllErr error +} + +func (s *stubMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error { + s.mountCalled = true + return s.mountErr +} + +func (s *stubMounter) Unmount(target string, flags int) error { + s.unmountCalled = true + return s.unmountErr +} + +func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error { + return s.mkdirAllErr +} + +type stubKeyWaiter struct { + receivedUUID string + decryptionKey []byte + waitErr error + waitCalled bool +} + +func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, error) { + if s.waitCalled { + return nil, errors.New("wait called before key was reset") + } + s.waitCalled = true + s.receivedUUID = uuid + return s.decryptionKey, s.waitErr +} + +func (s *stubKeyWaiter) ResetKey() { + s.waitCalled = false +} diff --git a/state/test/integration_test.go b/state/test/integration_test.go index 3b630b693..9061e8ff1 100644 --- a/state/test/integration_test.go +++ b/state/test/integration_test.go @@ -3,15 +3,24 @@ package integration import ( + "context" "fmt" + "net" "os" "os/exec" "testing" + "time" + "github.com/edgelesssys/constellation/coordinator/atls" + "github.com/edgelesssys/constellation/coordinator/core" + "github.com/edgelesssys/constellation/state/keyservice" + "github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/mapper" "github.com/martinjungblut/go-cryptsetup" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) const ( @@ -62,3 +71,42 @@ func TestMapper(t *testing.T) { // Try to map disk with incorrect passphrase assert.Error(mapper.MapDisk(mappedDevice, "invalid-passphrase"), "was able to map disk with incorrect passphrase") } + +func TestKeyAPI(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + + // get a free port on localhost to run the test on + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + apiAddr := listener.Addr().String() + listener.Close() + + api := keyservice.New(core.NewMockIssuer(), &core.ProviderMetadataFake{}, 20*time.Second) + + // send a key to the server + go func() { + // wait 2 seconds before sending the key + time.Sleep(2 * time.Second) + + clientCfg, err := atls.CreateUnverifiedClientTLSConfig() + require.NoError(err) + conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(credentials.NewTLS(clientCfg))) + require.NoError(err) + defer conn.Close() + + client := keyproto.NewAPIClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + _, err = client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{ + StateDiskKey: testKey, + }) + require.NoError(err) + }() + + key, err := api.WaitForDecryptionKey("12345678-1234-1234-1234-123456789ABC", apiAddr) + assert.NoError(err) + assert.Equal(testKey, key) +}