attestation: make AWS TPM check use the correct region

This commit is contained in:
Leonard Cohnen 2022-11-02 15:19:13 +01:00 committed by 3u13r
parent 37e8f5fc28
commit f199b08068
2 changed files with 8 additions and 8 deletions

View File

@ -24,7 +24,7 @@ import (
type Validator struct {
oid.AWS
*vtpm.Validator
getDescribeClient func(context.Context) (awsMetadataAPI, error)
getDescribeClient func(context.Context, string) (awsMetadataAPI, error)
}
// NewValidator create a new Validator structure and returns it.
@ -62,14 +62,14 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
ctx := context.Background()
idDocument := imds.InstanceIdentityDocument{}
err := json.Unmarshal(attestation.UserData, &idDocument)
err := json.Unmarshal(attestation.InstanceInfo, &idDocument)
if err != nil {
return err
}
imageID := idDocument.ImageID
client, err := v.getDescribeClient(ctx)
client, err := v.getDescribeClient(ctx, idDocument.Region)
if err != nil {
return err
}
@ -87,8 +87,8 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
return fmt.Errorf("iam image %s does not support TPM v2.0", imageID)
}
func getEC2Client(ctx context.Context) (awsMetadataAPI, error) {
client, err := config.LoadDefaultConfig(ctx, config.WithEC2IMDSRegion())
func getEC2Client(ctx context.Context, region string) (awsMetadataAPI, error) {
client, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, err
}

View File

@ -59,7 +59,7 @@ func TestTpmEnabled(t *testing.T) {
}
userDataNoTPM, _ := json.Marshal(idDocNoTPM)
attDocNoTPM := vtpm.AttestationDocument{
UserData: userDataNoTPM,
InstanceInfo: userDataNoTPM,
}
idDocTPM := imds.InstanceIdentityDocument{
@ -67,7 +67,7 @@ func TestTpmEnabled(t *testing.T) {
}
userDataTPM, _ := json.Marshal(idDocTPM)
attDocTPM := vtpm.AttestationDocument{
UserData: userDataTPM,
InstanceInfo: userDataTPM,
}
testCases := map[string]struct {
@ -103,7 +103,7 @@ func TestTpmEnabled(t *testing.T) {
assert := assert.New(t)
v := Validator{
getDescribeClient: func(context.Context) (awsMetadataAPI, error) {
getDescribeClient: func(context.Context, string) (awsMetadataAPI, error) {
return tc.awsAPI, nil
},
}