constellation/internal/update/update_test.go

280 lines
6.4 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package update
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestValidate(t *testing.T) {
testCases := map[string]struct {
listFunc func() *VersionsList
overrideFunc func(list *VersionsList)
wantErr bool
}{
"valid major list": {
listFunc: majorList,
},
"valid minor list": {
listFunc: minorList,
},
"invalid stream": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Stream = "invalid"
},
wantErr: true,
},
"invalid granularity": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Granularity = "invalid"
},
wantErr: true,
},
"invalid kind": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Kind = "invalid"
},
wantErr: true,
},
"base ver is not semantic version": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Base = "invalid"
},
wantErr: true,
},
"base ver does not reflect major granularity": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Base = "v1.0"
},
wantErr: true,
},
"base ver does not reflect minor granularity": {
listFunc: minorList,
overrideFunc: func(list *VersionsList) {
list.Base = "v1"
},
wantErr: true,
},
"version in list is not semantic version": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Versions[0] = "invalid"
},
wantErr: true,
},
"version in list is not sub version of base": {
listFunc: majorList,
overrideFunc: func(list *VersionsList) {
list.Versions[0] = "v2.1"
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
list := tc.listFunc()
if tc.overrideFunc != nil {
tc.overrideFunc(list)
}
err := list.Validate()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
func TestList(t *testing.T) {
majorListJSON, err := json.Marshal(majorList())
require.NoError(t, err)
minorListJSON, err := json.Marshal(minorList())
require.NoError(t, err)
inconsistentList := majorList()
inconsistentList.Base = "v2"
inconsistentListJSON, err := json.Marshal(inconsistentList)
require.NoError(t, err)
client := newTestClient(func(req *http.Request) *http.Response {
switch req.URL.Path {
case "/constellation/v1/updates/stable/major/v1/image.json":
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(majorListJSON)),
Header: make(http.Header),
}
case "/constellation/v1/updates/stable/minor/v1.1/image.json":
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(minorListJSON)),
Header: make(http.Header),
}
case "/constellation/v1/updates/stable/major/v1/500.json": // 500 error
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("Server Error.")),
Header: make(http.Header),
}
case "/constellation/v1/updates/stable/major/v1/nojson.json": // invalid format
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString("not json")),
Header: make(http.Header),
}
case "/constellation/v1/updates/stable/major/v2/image.json": // inconsistent list
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(inconsistentListJSON)),
Header: make(http.Header),
}
case "/constellation/v1/updates/stable/major/v3/image.json": // does not match requested version
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(minorListJSON)),
Header: make(http.Header),
}
}
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewBufferString("Not found.")),
Header: make(http.Header),
}
})
testCases := map[string]struct {
stream, granularity, base, kind string
overrideFile string
wantList VersionsList
wantErr bool
}{
"major list fetched remotely": {
wantList: *majorList(),
},
"minor list fetched remotely": {
granularity: "minor",
base: "v1.1",
wantList: *minorList(),
},
"list does not exist": {
stream: "unknown",
wantErr: true,
},
"unexpected error code": {
kind: "500",
wantErr: true,
},
"invalid json returned": {
kind: "nojson",
wantErr: true,
},
"invalid list returned": {
base: "v2",
wantErr: true,
},
"response does not match request": {
base: "v3",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
stream := "stable"
granularity := "major"
base := "v1"
kind := "image"
if tc.stream != "" {
stream = tc.stream
}
if tc.granularity != "" {
granularity = tc.granularity
}
if tc.base != "" {
base = tc.base
}
if tc.kind != "" {
kind = tc.kind
}
fetcher := &VersionsFetcher{
httpc: client,
}
list, err := fetcher.list(context.Background(), stream, granularity, base, kind)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantList, *list)
})
}
}
// roundTripFunc .
type roundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}
func majorList() *VersionsList {
return &VersionsList{
Stream: "stable",
Granularity: "major",
Base: "v1",
Kind: "image",
Versions: []string{
"v1.0", "v1.1", "v1.2",
},
}
}
func minorList() *VersionsList {
return &VersionsList{
Stream: "stable",
Granularity: "minor",
Base: "v1.1",
Kind: "image",
Versions: []string{
"v1.1.0", "v1.1.1", "v1.1.2",
},
}
}