mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-12 16:09:39 -05:00
AB#2305 Fix missing atls verifier in init call (#352)
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
aee3f2afa2
commit
8f5f84deb5
@ -18,14 +18,14 @@ import (
|
|||||||
|
|
||||||
const warningStr = "Warning: not verifying the Constellation cluster's %s measurements\n"
|
const warningStr = "Warning: not verifying the Constellation cluster's %s measurements\n"
|
||||||
|
|
||||||
type Validators struct {
|
type Validator struct {
|
||||||
provider cloudprovider.Provider
|
provider cloudprovider.Provider
|
||||||
pcrs map[uint32][]byte
|
pcrs map[uint32][]byte
|
||||||
validators []atls.Validator
|
validator atls.Validator
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewValidators(provider cloudprovider.Provider, config *config.Config) (*Validators, error) {
|
func NewValidator(provider cloudprovider.Provider, config *config.Config) (*Validator, error) {
|
||||||
v := Validators{}
|
v := Validator{}
|
||||||
if provider == cloudprovider.Unknown {
|
if provider == cloudprovider.Unknown {
|
||||||
return nil, errors.New("unknown cloud provider")
|
return nil, errors.New("unknown cloud provider")
|
||||||
}
|
}
|
||||||
@ -36,7 +36,7 @@ func NewValidators(provider cloudprovider.Provider, config *config.Config) (*Val
|
|||||||
return &v, nil
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Validators) UpdateInitPCRs(ownerID, clusterID string) error {
|
func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error {
|
||||||
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
|
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -48,7 +48,7 @@ func (v *Validators) UpdateInitPCRs(ownerID, clusterID string) error {
|
|||||||
// When adding, the input is first decoded from base64.
|
// When adding, the input is first decoded from base64.
|
||||||
// We then calculate the expected PCR by hashing the input using SHA256,
|
// We then calculate the expected PCR by hashing the input using SHA256,
|
||||||
// appending expected PCR for initialization, and then hashing once more.
|
// appending expected PCR for initialization, and then hashing once more.
|
||||||
func (v *Validators) updatePCR(pcrIndex uint32, encoded string) error {
|
func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error {
|
||||||
if encoded == "" {
|
if encoded == "" {
|
||||||
delete(v.pcrs, pcrIndex)
|
delete(v.pcrs, pcrIndex)
|
||||||
return nil
|
return nil
|
||||||
@ -65,7 +65,7 @@ func (v *Validators) updatePCR(pcrIndex uint32, encoded string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Validators) setPCRs(config *config.Config) error {
|
func (v *Validator) setPCRs(config *config.Config) error {
|
||||||
switch v.provider {
|
switch v.provider {
|
||||||
case cloudprovider.GCP:
|
case cloudprovider.GCP:
|
||||||
gcpPCRs := config.Provider.GCP.Measurements
|
gcpPCRs := config.Provider.GCP.Measurements
|
||||||
@ -89,33 +89,32 @@ func (v *Validators) setPCRs(config *config.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// V returns validators as list of atls.Validator.
|
// V returns the validator as atls.Validator.
|
||||||
func (v *Validators) V() []atls.Validator {
|
func (v *Validator) V() atls.Validator {
|
||||||
v.updateValidators()
|
v.updateValidator()
|
||||||
return v.validators
|
return v.validator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Validators) updateValidators() {
|
// PCRS returns the validator's PCR map.
|
||||||
|
func (v *Validator) PCRS() map[uint32][]byte {
|
||||||
|
return v.pcrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Validator) updateValidator() {
|
||||||
switch v.provider {
|
switch v.provider {
|
||||||
case cloudprovider.GCP:
|
case cloudprovider.GCP:
|
||||||
v.validators = []atls.Validator{
|
v.validator = gcp.NewValidator(v.pcrs)
|
||||||
gcp.NewValidator(v.pcrs),
|
|
||||||
}
|
|
||||||
case cloudprovider.Azure:
|
case cloudprovider.Azure:
|
||||||
v.validators = []atls.Validator{
|
v.validator = azure.NewValidator(v.pcrs)
|
||||||
azure.NewValidator(v.pcrs),
|
|
||||||
}
|
|
||||||
case cloudprovider.QEMU:
|
case cloudprovider.QEMU:
|
||||||
v.validators = []atls.Validator{
|
v.validator = qemu.NewValidator(v.pcrs)
|
||||||
qemu.NewValidator(v.pcrs),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warnings returns warnings for the specifc PCR values that are not verified.
|
// Warnings returns warnings for the specifc PCR values that are not verified.
|
||||||
//
|
//
|
||||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||||
func (v *Validators) Warnings() string {
|
func (v *Validator) Warnings() string {
|
||||||
sb := &strings.Builder{}
|
sb := &strings.Builder{}
|
||||||
|
|
||||||
if v.pcrs[0] == nil || v.pcrs[1] == nil {
|
if v.pcrs[0] == nil || v.pcrs[1] == nil {
|
||||||
@ -141,11 +140,11 @@ func (v *Validators) Warnings() string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified.
|
// WarningsIncludeInit returns warnings for the specific PCR values that are not verified.
|
||||||
// Warnings regarding the initialization are included.
|
// Warnings regarding the initialization are included.
|
||||||
//
|
//
|
||||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||||
func (v *Validators) WarningsIncludeInit() string {
|
func (v *Validator) WarningsIncludeInit() string {
|
||||||
warnings := v.Warnings()
|
warnings := v.Warnings()
|
||||||
if v.pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || v.pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
if v.pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || v.pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
||||||
warnings = warnings + fmt.Sprintf(warningStr, "initialization status")
|
warnings = warnings + fmt.Sprintf(warningStr, "initialization status")
|
||||||
@ -154,7 +153,7 @@ func (v *Validators) WarningsIncludeInit() string {
|
|||||||
return warnings
|
return warnings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
|
func (v *Validator) checkPCRs(pcrs map[uint32][]byte) error {
|
||||||
if len(pcrs) == 0 {
|
if len(pcrs) == 0 {
|
||||||
return errors.New("no PCR values provided")
|
return errors.New("no PCR values provided")
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewValidators(t *testing.T) {
|
func TestNewValidator(t *testing.T) {
|
||||||
zero := []byte("00000000000000000000000000000000")
|
zero := []byte("00000000000000000000000000000000")
|
||||||
one := []byte("11111111111111111111111111111111")
|
one := []byte("11111111111111111111111111111111")
|
||||||
testPCRs := map[uint32][]byte{
|
testPCRs := map[uint32][]byte{
|
||||||
@ -80,7 +80,7 @@ func TestNewValidators(t *testing.T) {
|
|||||||
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
|
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
|
||||||
}
|
}
|
||||||
|
|
||||||
validators, err := NewValidators(tc.provider, conf)
|
validators, err := NewValidator(tc.provider, conf)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
@ -93,7 +93,7 @@ func TestNewValidators(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidatorsWarnings(t *testing.T) {
|
func TestValidatorWarnings(t *testing.T) {
|
||||||
zero := []byte("00000000000000000000000000000000")
|
zero := []byte("00000000000000000000000000000000")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
@ -233,7 +233,7 @@ func TestValidatorsWarnings(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
|
|
||||||
validators := Validators{pcrs: tc.pcrs}
|
validators := Validator{pcrs: tc.pcrs}
|
||||||
|
|
||||||
warnings := validators.Warnings()
|
warnings := validators.Warnings()
|
||||||
warningsInclueInit := validators.WarningsIncludeInit()
|
warningsInclueInit := validators.WarningsIncludeInit()
|
||||||
@ -259,7 +259,7 @@ func TestValidatorsWarnings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidatorsV(t *testing.T) {
|
func TestValidatorV(t *testing.T) {
|
||||||
zero := []byte("00000000000000000000000000000000")
|
zero := []byte("00000000000000000000000000000000")
|
||||||
newTestPCRs := func() map[uint32][]byte {
|
newTestPCRs := func() map[uint32][]byte {
|
||||||
return map[uint32][]byte{
|
return map[uint32][]byte{
|
||||||
@ -282,28 +282,22 @@ func TestValidatorsV(t *testing.T) {
|
|||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
provider cloudprovider.Provider
|
provider cloudprovider.Provider
|
||||||
pcrs map[uint32][]byte
|
pcrs map[uint32][]byte
|
||||||
wantVs []atls.Validator
|
wantVs atls.Validator
|
||||||
}{
|
}{
|
||||||
"gcp": {
|
"gcp": {
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
pcrs: newTestPCRs(),
|
pcrs: newTestPCRs(),
|
||||||
wantVs: []atls.Validator{
|
wantVs: gcp.NewValidator(newTestPCRs()),
|
||||||
gcp.NewValidator(newTestPCRs()),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"azure": {
|
"azure": {
|
||||||
provider: cloudprovider.Azure,
|
provider: cloudprovider.Azure,
|
||||||
pcrs: newTestPCRs(),
|
pcrs: newTestPCRs(),
|
||||||
wantVs: []atls.Validator{
|
wantVs: azure.NewValidator(newTestPCRs()),
|
||||||
azure.NewValidator(newTestPCRs()),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"qemu": {
|
"qemu": {
|
||||||
provider: cloudprovider.QEMU,
|
provider: cloudprovider.QEMU,
|
||||||
pcrs: newTestPCRs(),
|
pcrs: newTestPCRs(),
|
||||||
wantVs: []atls.Validator{
|
wantVs: qemu.NewValidator(newTestPCRs()),
|
||||||
qemu.NewValidator(newTestPCRs()),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,19 +305,16 @@ func TestValidatorsV(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
|
|
||||||
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs}
|
validators := &Validator{provider: tc.provider, pcrs: tc.pcrs}
|
||||||
|
|
||||||
resultValidators := validators.V()
|
resultValidator := validators.V()
|
||||||
|
|
||||||
assert.Equal(len(tc.wantVs), len(resultValidators))
|
assert.Equal(tc.wantVs.OID(), resultValidator.OID())
|
||||||
for i, resValidator := range resultValidators {
|
|
||||||
assert.Equal(tc.wantVs[i].OID(), resValidator.OID())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidatorsUpdateInitPCRs(t *testing.T) {
|
func TestValidatorUpdateInitPCRs(t *testing.T) {
|
||||||
zero := []byte("00000000000000000000000000000000")
|
zero := []byte("00000000000000000000000000000000")
|
||||||
one := []byte("11111111111111111111111111111111")
|
one := []byte("11111111111111111111111111111111")
|
||||||
one64 := base64.StdEncoding.EncodeToString(one)
|
one64 := base64.StdEncoding.EncodeToString(one)
|
||||||
@ -402,7 +393,7 @@ func TestValidatorsUpdateInitPCRs(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
|
|
||||||
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs}
|
validators := &Validator{provider: tc.provider, pcrs: tc.pcrs}
|
||||||
|
|
||||||
err := validators.UpdateInitPCRs(tc.ownerID, tc.clusterID)
|
err := validators.UpdateInitPCRs(tc.ownerID, tc.clusterID)
|
||||||
|
|
||||||
@ -515,7 +506,7 @@ func TestUpdatePCR(t *testing.T) {
|
|||||||
pcrs[k] = v
|
pcrs[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
validators := &Validators{
|
validators := &Validator{
|
||||||
provider: cloudprovider.GCP,
|
provider: cloudprovider.GCP,
|
||||||
pcrs: pcrs,
|
pcrs: pcrs,
|
||||||
}
|
}
|
||||||
|
@ -52,13 +52,15 @@ func NewInitCmd() *cobra.Command {
|
|||||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||||
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
||||||
dialer := dialer.New(nil, nil, &net.Dialer{})
|
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
|
||||||
return initialize(cmd, dialer, serviceAccountCreator, fileHandler)
|
return dialer.New(nil, validator.V(), &net.Dialer{})
|
||||||
|
}
|
||||||
|
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize initializes a Constellation.
|
// initialize initializes a Constellation.
|
||||||
func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator serviceAccountCreator,
|
func initialize(cmd *cobra.Command, newDialer func(*cloudcmd.Validator) *dialer.Dialer,
|
||||||
fileHandler file.Handler,
|
serviceAccCreator serviceAccountCreator, fileHandler file.Handler,
|
||||||
) error {
|
) error {
|
||||||
flags, err := evalFlagArgs(cmd, fileHandler)
|
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -88,11 +90,11 @@ func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator service
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
validators, err := cloudcmd.NewValidators(provider, config)
|
validator, err := cloudcmd.NewValidator(provider, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cmd.Print(validators.WarningsIncludeInit())
|
cmd.Print(validator.WarningsIncludeInit())
|
||||||
|
|
||||||
cmd.Println("Creating service account ...")
|
cmd.Println("Creating service account ...")
|
||||||
serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config)
|
serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config)
|
||||||
@ -103,7 +105,7 @@ func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator service
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
controlPlanes, workers, err := getScalingGroupsFromState(stat, config)
|
_, workers, err := getScalingGroupsFromState(stat, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -125,12 +127,12 @@ func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator service
|
|||||||
KubernetesVersion: config.KubernetesVersion,
|
KubernetesVersion: config.KubernetesVersion,
|
||||||
SshUserKeys: ssh.ToProtoSlice(sshUsers),
|
SshUserKeys: ssh.ToProtoSlice(sshUsers),
|
||||||
}
|
}
|
||||||
resp, err := initCall(cmd.Context(), dialer, stat.BootstrapperHost, req)
|
resp, err := initCall(cmd.Context(), newDialer(validator), stat.BootstrapperHost, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeOutput(resp, controlPlanes.PublicIPs()[0], cmd.OutOrStdout(), fileHandler)
|
return writeOutput(resp, stat.BootstrapperHost, cmd.OutOrStdout(), fileHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitResponse, error) {
|
func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitResponse, error) {
|
||||||
|
@ -13,12 +13,16 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/bootstrapper/initproto"
|
"github.com/edgelesssys/constellation/bootstrapper/initproto"
|
||||||
|
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
|
||||||
|
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||||
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
||||||
|
"github.com/edgelesssys/constellation/internal/config"
|
||||||
"github.com/edgelesssys/constellation/internal/constants"
|
"github.com/edgelesssys/constellation/internal/constants"
|
||||||
"github.com/edgelesssys/constellation/internal/file"
|
"github.com/edgelesssys/constellation/internal/file"
|
||||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||||
|
"github.com/edgelesssys/constellation/internal/oid"
|
||||||
"github.com/edgelesssys/constellation/internal/state"
|
"github.com/edgelesssys/constellation/internal/state"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@ -131,7 +135,9 @@ func TestInitialize(t *testing.T) {
|
|||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
netDialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
dialer := dialer.New(nil, nil, netDialer)
|
newDialer := func(*cloudcmd.Validator) *dialer.Dialer {
|
||||||
|
return dialer.New(nil, nil, netDialer)
|
||||||
|
}
|
||||||
serverCreds := atlscredentials.New(nil, nil)
|
serverCreds := atlscredentials.New(nil, nil)
|
||||||
initServer := grpc.NewServer(grpc.Creds(serverCreds))
|
initServer := grpc.NewServer(grpc.Creds(serverCreds))
|
||||||
initproto.RegisterAPIServer(initServer, tc.initServerAPI)
|
initproto.RegisterAPIServer(initServer, tc.initServerAPI)
|
||||||
@ -145,7 +151,7 @@ func TestInitialize(t *testing.T) {
|
|||||||
cmd.SetOut(&out)
|
cmd.SetOut(&out)
|
||||||
var errOut bytes.Buffer
|
var errOut bytes.Buffer
|
||||||
cmd.SetErr(&errOut)
|
cmd.SetErr(&errOut)
|
||||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
cmd.Flags().String("config", "", "") // register persistent flag manually
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
fileHandler := file.NewHandler(fs)
|
fileHandler := file.NewHandler(fs)
|
||||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||||
@ -156,7 +162,7 @@ func TestInitialize(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
cmd.SetContext(ctx)
|
cmd.SetContext(ctx)
|
||||||
|
|
||||||
err := initialize(cmd, dialer, &tc.serviceAccountCreator, fileHandler)
|
err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
@ -364,6 +370,122 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAttestation(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
initServerAPI := &stubInitServer{initResp: &initproto.InitResponse{
|
||||||
|
Kubeconfig: []byte("kubeconfig"),
|
||||||
|
OwnerId: []byte("ownerID"),
|
||||||
|
ClusterId: []byte("clusterID"),
|
||||||
|
}}
|
||||||
|
existingState := state.ConstellationState{
|
||||||
|
CloudProvider: "QEMU",
|
||||||
|
BootstrapperHost: "192.0.2.1",
|
||||||
|
QEMUWorkers: cloudtypes.Instances{
|
||||||
|
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
|
},
|
||||||
|
QEMUControlPlane: cloudtypes.Instances{
|
||||||
|
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
newDialer := func(v *cloudcmd.Validator) *dialer.Dialer {
|
||||||
|
validator := &testValidator{
|
||||||
|
Getter: oid.QEMU{},
|
||||||
|
pcrs: v.PCRS(),
|
||||||
|
}
|
||||||
|
return dialer.New(nil, validator, netDialer)
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer := &testIssuer{
|
||||||
|
Getter: oid.QEMU{},
|
||||||
|
pcrs: map[uint32][]byte{
|
||||||
|
0: []byte("ffffffffffffffffffffffffffffffff"),
|
||||||
|
1: []byte("ffffffffffffffffffffffffffffffff"),
|
||||||
|
2: []byte("ffffffffffffffffffffffffffffffff"),
|
||||||
|
3: []byte("ffffffffffffffffffffffffffffffff"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
serverCreds := atlscredentials.New(issuer, nil)
|
||||||
|
initServer := grpc.NewServer(grpc.Creds(serverCreds))
|
||||||
|
initproto.RegisterAPIServer(initServer, initServerAPI)
|
||||||
|
port := strconv.Itoa(constants.BootstrapperPort)
|
||||||
|
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
|
||||||
|
go initServer.Serve(listener)
|
||||||
|
defer initServer.GracefulStop()
|
||||||
|
|
||||||
|
cmd := NewInitCmd()
|
||||||
|
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
|
||||||
|
var out bytes.Buffer
|
||||||
|
cmd.SetOut(&out)
|
||||||
|
var errOut bytes.Buffer
|
||||||
|
cmd.SetErr(&errOut)
|
||||||
|
|
||||||
|
fs := afero.NewMemMapFs()
|
||||||
|
fileHandler := file.NewHandler(fs)
|
||||||
|
require.NoError(fileHandler.WriteJSON(constants.StateFilename, existingState, file.OptNone))
|
||||||
|
|
||||||
|
cfg := config.Default()
|
||||||
|
cfg.RemoveProviderExcept(cloudprovider.QEMU)
|
||||||
|
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000")
|
||||||
|
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111")
|
||||||
|
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222")
|
||||||
|
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333")
|
||||||
|
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
cmd.SetContext(ctx)
|
||||||
|
|
||||||
|
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler)
|
||||||
|
assert.Error(err)
|
||||||
|
// make sure the error is actually a TLS handshake error
|
||||||
|
assert.Contains(err.Error(), "transport: authentication handshake failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
type testValidator struct {
|
||||||
|
oid.Getter
|
||||||
|
pcrs map[uint32][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||||||
|
var attestation struct {
|
||||||
|
UserData []byte
|
||||||
|
PCRs map[uint32][]byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(attDoc, &attestation); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, pcr := range v.pcrs {
|
||||||
|
if !bytes.Equal(attestation.PCRs[k], pcr) {
|
||||||
|
return nil, errors.New("invalid PCR value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return attestation.UserData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testIssuer struct {
|
||||||
|
oid.Getter
|
||||||
|
pcrs map[uint32][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
|
||||||
|
return json.Marshal(
|
||||||
|
struct {
|
||||||
|
UserData []byte
|
||||||
|
PCRs map[uint32][]byte
|
||||||
|
}{
|
||||||
|
UserData: userData,
|
||||||
|
PCRs: i.pcrs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
type stubInitServer struct {
|
type stubInitServer struct {
|
||||||
initResp *initproto.InitResponse
|
initResp *initproto.InitResponse
|
||||||
initErr error
|
initErr error
|
||||||
|
@ -68,7 +68,7 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recove
|
|||||||
return fmt.Errorf("reading and validating config: %w", err)
|
return fmt.Errorf("reading and validating config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
validators, err := cloudcmd.NewValidators(provider, config)
|
validators, err := cloudcmd.NewValidator(provider, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type recoveryClient interface {
|
type recoveryClient interface {
|
||||||
Connect(endpoint string, validators []atls.Validator) error
|
Connect(endpoint string, validators atls.Validator) error
|
||||||
PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error
|
PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ type stubRecoveryClient struct {
|
|||||||
pushStateDiskKeyKey []byte
|
pushStateDiskKeyKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
|
func (c *stubRecoveryClient) Connect(_ string, _ atls.Validator) error {
|
||||||
c.conn = true
|
c.conn = true
|
||||||
return c.connectErr
|
return c.connectErr
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ func verify(
|
|||||||
return fmt.Errorf("reading and validating config: %w", err)
|
return fmt.Errorf("reading and validating config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
validators, err := cloudcmd.NewValidators(provider, config)
|
validators, err := cloudcmd.NewValidator(provider, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -90,7 +90,7 @@ func verify(
|
|||||||
Nonce: nonce,
|
Nonce: nonce,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
},
|
},
|
||||||
validators.V()[0],
|
validators.V(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -21,8 +21,8 @@ type KeyClient struct {
|
|||||||
// The connection must be closed using Close(). If connect is
|
// The connection must be closed using Close(). If connect is
|
||||||
// called on a client that already has a connection, the old
|
// called on a client that already has a connection, the old
|
||||||
// connection is closed.
|
// connection is closed.
|
||||||
func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error {
|
func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
|
||||||
creds := atlscredentials.New(nil, validators)
|
creds := atlscredentials.New(nil, []atls.Validator{validators})
|
||||||
|
|
||||||
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
|
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user