constellation/internal/cloud/azure/imds_test.go

181 lines
4.5 KiB
Go
Raw Normal View History

package azure
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/test/bufconn"
)
// TestIMDSClient verifies that imdsClient reports the provider ID, resource
// group and UID contained in a metadata response, and that it returns an error
// for each value that is absent from the response or when the server replies
// with an invalid body.
func TestIMDSClient(t *testing.T) {
	uidTags := []metadataTag{{Name: "uid", Value: "uid"}}

	// Fully populated response: every lookup is expected to succeed.
	response := metadataResponse{
		Compute: metadataResponseCompute{
			ResourceID:    "resource-id",
			ResourceGroup: "resource-group",
			Tags:          uidTags,
		},
	}
	// Each of the following responses omits exactly one value.
	responseWithoutID := metadataResponse{
		Compute: metadataResponseCompute{
			ResourceGroup: "resource-group",
			Tags:          uidTags,
		},
	}
	responseWithoutGroup := metadataResponse{
		Compute: metadataResponseCompute{
			ResourceID: "resource-id",
			Tags:       uidTags,
		},
	}
	responseWithoutUID := metadataResponse{
		Compute: metadataResponseCompute{
			ResourceID:    "resource-id",
			ResourceGroup: "resource-group",
		},
	}

	testCases := map[string]struct {
		server               httpBufconnServer
		wantProviderIDErr    bool
		wantProviderID       string
		wantResourceGroupErr bool
		wantResourceGroup    string
		wantUIDErr           bool
		wantUID              string
	}{
		"metadata response parsed": {
			server:            newHTTPBufconnServerWithMetadataResponse(response),
			wantProviderID:    "resource-id",
			wantResourceGroup: "resource-group",
			wantUID:           "uid",
		},
		"metadata response without resource ID": {
			server:            newHTTPBufconnServerWithMetadataResponse(responseWithoutID),
			wantProviderIDErr: true,
			wantResourceGroup: "resource-group",
			wantUID:           "uid",
		},
		"metadata response without UID tag": {
			server:            newHTTPBufconnServerWithMetadataResponse(responseWithoutUID),
			wantProviderID:    "resource-id",
			wantResourceGroup: "resource-group",
			wantUIDErr:        true,
		},
		"metadata response without resource group": {
			server:               newHTTPBufconnServerWithMetadataResponse(responseWithoutGroup),
			wantProviderID:       "resource-id",
			wantResourceGroupErr: true,
			wantUID:              "uid",
		},
		"invalid imds response detected": {
			server: newHTTPBufconnServer(func(writer http.ResponseWriter, request *http.Request) {
				fmt.Fprintln(writer, "invalid-result")
			}),
			wantProviderIDErr:    true,
			wantResourceGroupErr: true,
			wantUIDErr:           true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			defer tc.server.Close()

			// Route every connection to the in-memory test server. Only the
			// context-aware dial hooks are configured: http.Transport prefers
			// them over the deprecated Dial and DialTLS fields, so setting
			// those as well would have no effect.
			hClient := http.Client{
				Transport: &http.Transport{
					DialContext:    tc.server.DialContext,
					DialTLSContext: tc.server.DialContext,
				},
			}
			iClient := imdsClient{client: &hClient}

			ctx := context.Background()

			id, err := iClient.ProviderID(ctx)
			if tc.wantProviderIDErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
				assert.Equal(tc.wantProviderID, id)
			}

			group, err := iClient.ResourceGroup(ctx)
			if tc.wantResourceGroupErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
				assert.Equal(tc.wantResourceGroup, group)
			}

			uid, err := iClient.UID(ctx)
			if tc.wantUIDErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
				assert.Equal(tc.wantUID, uid)
			}
		})
	}
}
// httpBufconnServer bundles an httptest.Server with the bufconn.Listener it
// serves on, so tests can dial the server through the listener.
//
// Both embedded types are accessed explicitly by the methods of this type:
// s.Listener selects the embedded *bufconn.Listener, and Close is defined on
// httpBufconnServer itself to shut down both parts.
type httpBufconnServer struct {
// Server is the HTTP test server handling the requests.
*httptest.Server
// Listener is the bufconn listener the server accepts connections from.
*bufconn.Listener
}
// DialContext opens a connection to the server's bufconn listener. The network
// and address arguments are ignored; they are part of the signature so the
// method can be plugged into an http.Transport dial hook.
func (s *httpBufconnServer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) {
	return s.Listener.DialContext(ctx)
}
// Dial opens a connection to the server's bufconn listener without a
// caller-supplied context. The network and address arguments are ignored.
//
// It goes through the listener's DialContext with a background context rather
// than the listener's deprecated Dial method.
func (s *httpBufconnServer) Dial(network, addr string) (net.Conn, error) {
	return s.Listener.DialContext(context.Background())
}
// Close shuts down the HTTP test server first and then closes the bufconn
// listener.
func (s *httpBufconnServer) Close() {
	s.Server.Close()
	// Close has no return value, so the listener's close error is discarded.
	_ = s.Listener.Close()
}
// newHTTPBufconnServer starts an HTTP test server that answers requests with
// handlerFunc and accepts connections from a bufconn listener instead of a
// network socket. The caller must call Close on the returned server.
func newHTTPBufconnServer(handlerFunc http.HandlerFunc) httpBufconnServer {
	// Buffer size handed to the bufconn listener.
	const bufSize = 1024

	server := httptest.NewUnstartedServer(handlerFunc)
	// NewUnstartedServer already opened a loopback TCP listener. Close it
	// before replacing it, otherwise every test server leaks a socket.
	// Best effort: a failure here does not affect the bufconn listener.
	_ = server.Listener.Close()

	listener := bufconn.Listen(bufSize)
	server.Listener = listener
	server.Start()

	return httpBufconnServer{
		Server:   server,
		Listener: listener,
	}
}
// newHTTPBufconnServerWithMetadataResponse starts a test server that replies
// with res encoded as JSON to well-formed metadata requests.
//
// A request is well-formed when it targets host 169.254.169.254, carries the
// header "Metadata: True" and has the query parameters format=json and
// api-version=imdsAPIVersion. Any other request is answered with status 500
// and the body "error".
func newHTTPBufconnServerWithMetadataResponse(res metadataResponse) httpBufconnServer {
	handler := func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()
		wellFormed := r.Host == "169.254.169.254" &&
			r.Header.Get("Metadata") == "True" &&
			query.Get("format") == "json" &&
			query.Get("api-version") == imdsAPIVersion

		if !wellFormed {
			w.WriteHeader(http.StatusInternalServerError)
			if _, err := w.Write([]byte("error")); err != nil {
				panic(err)
			}
			return
		}

		payload, err := json.Marshal(res)
		if err != nil {
			panic(err)
		}
		if _, err := w.Write(payload); err != nil {
			panic(err)
		}
	}

	return newHTTPBufconnServer(handler)
}