versionsapi: new fetcher implementation

Signed-off-by: Paul Meyer <49727155+katexochen@users.noreply.github.com>
This commit is contained in:
Paul Meyer 2022-12-29 17:24:08 +01:00
parent 3f00f89d55
commit 9dbe6033f2
2 changed files with 328 additions and 0 deletions

View File

@ -0,0 +1,110 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package fetcher
import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/edgelesssys/constellation/v2/internal/versionsapi"
)
// Fetcher fetches versions API resources.
//
// The fetcher is used to get information from the versions API without having to
// authenticate with AWS. It is the interface that should be used in user-facing
// application code most of the time.
type Fetcher struct {
httpc httpc
}
// NewFetcher returns a new Fetcher.
func NewFetcher() *Fetcher {
return &Fetcher{
httpc: http.DefaultClient,
}
}
// FetchVersionList fetches the given version list from the versions API.
func (f *Fetcher) FetchVersionList(ctx context.Context, list versionsapi.List) (versionsapi.List, error) {
return fetch(ctx, f.httpc, list)
}
// FetchVersionLatest fetches the latest version from the versions API.
func (f *Fetcher) FetchVersionLatest(ctx context.Context, latest versionsapi.Latest) (versionsapi.Latest, error) {
return fetch(ctx, f.httpc, latest)
}
// FetchImageInfo fetches the given image info from the versions API.
func (f *Fetcher) FetchImageInfo(ctx context.Context, imageInfo versionsapi.ImageInfo) (versionsapi.ImageInfo, error) {
return fetch(ctx, f.httpc, imageInfo)
}
type apiObject interface {
ValidateRequest() error
Validate() error
URL() (string, error)
}
func fetch[T apiObject](ctx context.Context, c httpc, obj T) (T, error) {
if err := obj.ValidateRequest(); err != nil {
return *new(T), fmt.Errorf("validating request for %T: %w", obj, err)
}
url, err := obj.URL()
if err != nil {
return *new(T), fmt.Errorf("getting URL for %T: %w", obj, err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return *new(T), fmt.Errorf("creating request for %T: %w", obj, err)
}
resp, err := c.Do(req)
if err != nil {
return *new(T), fmt.Errorf("sending request for %T: %w", obj, err)
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
case http.StatusNotFound:
return *new(T), &NotFoundError{fmt.Errorf("requesting resource at %s returned status code 404", url)}
default:
return *new(T), fmt.Errorf("unexpected status code %d while requesting resource", resp.StatusCode)
}
var newObj T
if err := json.NewDecoder(resp.Body).Decode(&newObj); err != nil {
return *new(T), fmt.Errorf("decoding %T: %w", obj, err)
}
if newObj.Validate() != nil {
return *new(T), fmt.Errorf("received invalid %T: %w", newObj, newObj.Validate())
}
return newObj, nil
}
// NotFoundError is an error that is returned when a resource is not found.
type NotFoundError struct {
err error
}
func (e *NotFoundError) Error() string {
return fmt.Sprintf("the requested resource was not found: %s", e.err.Error())
}
func (e *NotFoundError) Unwrap() error {
return e.err
}
type httpc interface {
Do(req *http.Request) (*http.Response, error)
}

View File

@ -0,0 +1,218 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package fetcher
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/edgelesssys/constellation/v2/internal/versionsapi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestFetchVersionList(t *testing.T) {
require := require.New(t)
majorList := func() *versionsapi.List {
return &versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMajor,
Base: "v1",
Kind: versionsapi.VersionKindImage,
Versions: []string{"v1.0", "v1.1", "v1.2"},
}
}
minorList := func() *versionsapi.List {
return &versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMinor,
Base: "v1.1",
Kind: versionsapi.VersionKindImage,
Versions: []string{"v1.1.0", "v1.1.1", "v1.1.2"},
}
}
majorListJSON, err := json.Marshal(majorList())
require.NoError(err)
minorListJSON, err := json.Marshal(minorList())
require.NoError(err)
inconsistentList := majorList()
inconsistentList.Base = "v2"
inconsistentListJSON, err := json.Marshal(inconsistentList)
require.NoError(err)
testCases := map[string]struct {
list versionsapi.List
serverPath string
serverResp *http.Response
wantList versionsapi.List
wantErr bool
}{
"major list fetched": {
list: versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMajor,
Base: "v1",
Kind: versionsapi.VersionKindImage,
},
serverPath: "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v1/image.json",
serverResp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(majorListJSON)),
},
wantList: *majorList(),
},
"minor list fetched": {
list: versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMinor,
Base: "v1.1",
Kind: versionsapi.VersionKindImage,
},
serverPath: "/constellation/v1/ref/test-ref/stream/nightly/versions/minor/v1.1/image.json",
serverResp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(minorListJSON)),
},
wantList: *minorList(),
},
"list does not exist": {
list: versionsapi.List{
Ref: "another-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMajor,
Base: "v1",
Kind: versionsapi.VersionKindImage,
},
wantErr: true,
},
"invalid list requested": {
list: versionsapi.List{
Ref: "",
Stream: "unknown",
Granularity: versionsapi.GranularityMajor,
Base: "v1",
Kind: versionsapi.VersionKindImage,
},
wantErr: true,
},
"unexpected error code": {
list: versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMajor,
Base: "v1",
Kind: versionsapi.VersionKindImage,
},
serverPath: "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v1/image.json",
serverResp: &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("Internal Server Error")),
},
wantErr: true,
},
"invalid json returned": {
list: versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMajor,
Base: "v1",
Kind: versionsapi.VersionKindImage,
},
serverPath: "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v1/image.json",
serverResp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("invalid json")),
},
wantErr: true,
},
"invalid list returned": {
list: versionsapi.List{
Ref: "test-ref",
Stream: "nightly",
Granularity: versionsapi.GranularityMajor,
Base: "v2",
Kind: versionsapi.VersionKindImage,
},
serverPath: "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v2/image.json",
serverResp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(inconsistentListJSON)),
},
wantErr: true,
},
// TODO(katexochen): Remove or find strategy to implement this check in a generic way
// "response does not match request": {
// list: versionsapi.List{
// Ref: "test-ref",
// Stream: "nightly",
// Granularity: versionsapi.GranularityMajor,
// Base: "v3",
// Kind: versionsapi.VersionKindImage,
// },
// serverPath: "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v3/image.json",
// serverResp: &http.Response{
// StatusCode: http.StatusOK,
// Body: io.NopCloser(bytes.NewBuffer(minorListJSON)),
// },
// wantErr: true,
// },
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := newTestClient(func(req *http.Request) *http.Response {
if req.URL.Path != tc.serverPath {
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewBufferString("Not found.")),
}
}
return tc.serverResp
})
fetcher := &Fetcher{httpc: client}
list, err := fetcher.FetchVersionList(context.Background(), tc.list)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantList, list)
})
}
}
type roundTripFunc func(req *http.Request) *http.Response
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}