gcp: pass context to metadata functions (#3228)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2024-07-03 14:41:29 +02:00 committed by GitHub
parent 7b6c3a710e
commit 20269ab46e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 23 deletions

View File

@ -91,14 +91,14 @@ type fakeMetadataClient struct {
zoneErr error
}
func (c fakeMetadataClient) ProjectID() (string, error) {
func (c fakeMetadataClient) ProjectID(_ context.Context) (string, error) {
return c.projectIDString, c.projecIDErr
}
func (c fakeMetadataClient) InstanceName() (string, error) {
func (c fakeMetadataClient) InstanceName(_ context.Context) (string, error) {
return c.instanceNameString, c.instanceNameErr
}
func (c fakeMetadataClient) Zone() (string, error) {
func (c fakeMetadataClient) Zone(_ context.Context) (string, error) {
return c.zoneString, c.zoneErr
}

View File

@ -21,17 +21,17 @@ func GCEInstanceInfo(client gcpMetadataClient) func(context.Context, io.ReadWrit
// Ideally we would want to use the endorsement public key certificate
// However, this is not available on GCE instances
// Workaround: Provide ShieldedVM instance info
// The attestating party can request the VMs signing key using Google's API
return func(context.Context, io.ReadWriteCloser, []byte) ([]byte, error) {
projectID, err := client.ProjectID()
// The attesting party can request the VMs signing key using Google's API
return func(ctx context.Context, _ io.ReadWriteCloser, _ []byte) ([]byte, error) {
projectID, err := client.ProjectID(ctx)
if err != nil {
return nil, errors.New("unable to fetch projectID")
}
zone, err := client.Zone()
zone, err := client.Zone(ctx)
if err != nil {
return nil, errors.New("unable to fetch zone")
}
instanceName, err := client.InstanceName()
instanceName, err := client.InstanceName(ctx)
if err != nil {
return nil, errors.New("unable to fetch instance name")
}
@ -45,25 +45,25 @@ func GCEInstanceInfo(client gcpMetadataClient) func(context.Context, io.ReadWrit
}
type gcpMetadataClient interface {
ProjectID() (string, error)
InstanceName() (string, error)
Zone() (string, error)
ProjectID(context.Context) (string, error)
InstanceName(context.Context) (string, error)
Zone(context.Context) (string, error)
}
// A MetadataClient fetches metadata from the GCE Metadata API.
type MetadataClient struct{}
// ProjectID returns the project ID of the GCE instance.
func (c MetadataClient) ProjectID() (string, error) {
return metadata.ProjectIDWithContext(context.Background())
func (c MetadataClient) ProjectID(ctx context.Context) (string, error) {
return metadata.ProjectIDWithContext(ctx)
}
// InstanceName returns the instance name of the GCE instance.
func (c MetadataClient) InstanceName() (string, error) {
return metadata.InstanceNameWithContext(context.Background())
func (c MetadataClient) InstanceName(ctx context.Context) (string, error) {
return metadata.InstanceNameWithContext(ctx)
}
// Zone returns the zone the GCE instance is located in.
func (c MetadataClient) Zone() (string, error) {
return metadata.ZoneWithContext(context.Background())
func (c MetadataClient) Zone(ctx context.Context) (string, error) {
return metadata.ZoneWithContext(ctx)
}

View File

@ -57,7 +57,7 @@ func getAttestationKey(tpm io.ReadWriter) (*tpmclient.Key, error) {
// getInstanceInfo generates an extended SNP report, i.e. the report and any loaded certificates.
// Report generation is triggered by sending ioctl syscalls to the SNP guest device, the AMD PSP generates the report.
// The returned bytes will be written into the attestation document.
func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte) ([]byte, error) {
func getInstanceInfo(ctx context.Context, _ io.ReadWriteCloser, extraData []byte) ([]byte, error) {
if len(extraData) > 64 {
return nil, fmt.Errorf("extra data too long: %d, should be 64 bytes at most", len(extraData))
}
@ -74,7 +74,7 @@ func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte)
return nil, fmt.Errorf("parsing vcek: %w", err)
}
gceInstanceInfo, err := gceInstanceInfo()
gceInstanceInfo, err := gceInstanceInfo(ctx)
if err != nil {
return nil, fmt.Errorf("getting GCE instance info: %w", err)
}
@ -93,20 +93,20 @@ func getInstanceInfo(_ context.Context, _ io.ReadWriteCloser, extraData []byte)
}
// gceInstanceInfo returns the instance info for a GCE instance from the metadata API.
func gceInstanceInfo() (*attest.GCEInstanceInfo, error) {
func gceInstanceInfo(ctx context.Context) (*attest.GCEInstanceInfo, error) {
c := gcp.MetadataClient{}
instanceName, err := c.InstanceName()
instanceName, err := c.InstanceName(ctx)
if err != nil {
return nil, fmt.Errorf("getting instance name: %w", err)
}
projectID, err := c.ProjectID()
projectID, err := c.ProjectID(ctx)
if err != nil {
return nil, fmt.Errorf("getting project ID: %w", err)
}
zone, err := c.Zone()
zone, err := c.Zone(ctx)
if err != nil {
return nil, fmt.Errorf("getting zone: %w", err)
}