attestation: make AWS TPM check use the correct region

This commit is contained in:
Leonard Cohnen 2022-11-02 15:19:13 +01:00 committed by 3u13r
parent 37e8f5fc28
commit f199b08068
2 changed files with 8 additions and 8 deletions

View File

@ -24,7 +24,7 @@ import (
type Validator struct { type Validator struct {
oid.AWS oid.AWS
*vtpm.Validator *vtpm.Validator
getDescribeClient func(context.Context) (awsMetadataAPI, error) getDescribeClient func(context.Context, string) (awsMetadataAPI, error)
} }
// NewValidator create a new Validator structure and returns it. // NewValidator create a new Validator structure and returns it.
@ -62,14 +62,14 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
ctx := context.Background() ctx := context.Background()
idDocument := imds.InstanceIdentityDocument{} idDocument := imds.InstanceIdentityDocument{}
err := json.Unmarshal(attestation.UserData, &idDocument) err := json.Unmarshal(attestation.InstanceInfo, &idDocument)
if err != nil { if err != nil {
return err return err
} }
imageID := idDocument.ImageID imageID := idDocument.ImageID
client, err := v.getDescribeClient(ctx) client, err := v.getDescribeClient(ctx, idDocument.Region)
if err != nil { if err != nil {
return err return err
} }
@ -87,8 +87,8 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
return fmt.Errorf("iam image %s does not support TPM v2.0", imageID) return fmt.Errorf("iam image %s does not support TPM v2.0", imageID)
} }
func getEC2Client(ctx context.Context) (awsMetadataAPI, error) { func getEC2Client(ctx context.Context, region string) (awsMetadataAPI, error) {
client, err := config.LoadDefaultConfig(ctx, config.WithEC2IMDSRegion()) client, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -59,7 +59,7 @@ func TestTpmEnabled(t *testing.T) {
} }
userDataNoTPM, _ := json.Marshal(idDocNoTPM) userDataNoTPM, _ := json.Marshal(idDocNoTPM)
attDocNoTPM := vtpm.AttestationDocument{ attDocNoTPM := vtpm.AttestationDocument{
UserData: userDataNoTPM, InstanceInfo: userDataNoTPM,
} }
idDocTPM := imds.InstanceIdentityDocument{ idDocTPM := imds.InstanceIdentityDocument{
@ -67,7 +67,7 @@ func TestTpmEnabled(t *testing.T) {
} }
userDataTPM, _ := json.Marshal(idDocTPM) userDataTPM, _ := json.Marshal(idDocTPM)
attDocTPM := vtpm.AttestationDocument{ attDocTPM := vtpm.AttestationDocument{
UserData: userDataTPM, InstanceInfo: userDataTPM,
} }
testCases := map[string]struct { testCases := map[string]struct {
@ -103,7 +103,7 @@ func TestTpmEnabled(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
v := Validator{ v := Validator{
getDescribeClient: func(context.Context) (awsMetadataAPI, error) { getDescribeClient: func(context.Context, string) (awsMetadataAPI, error) {
return tc.awsAPI, nil return tc.awsAPI, nil
}, },
} }