// constellation/internal/installer/installer_test.go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package installer
import (
"archive/tar"
"bufio"
"bytes"
"compress/gzip"
"context"
"errors"
"io"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"path"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/v2/internal/versions/components"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing"
)
func TestInstall(t *testing.T) {
serverURL := "http://server/path"
testCases := map[string]struct {
server httpBufconnServer
component *components.Component
hash string
destination string
extract bool
wantErr bool
wantFiles map[string][]byte
}{
"download works": {
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
component: &components.Component{
Url: serverURL,
Hash: "sha256:f03779b36bece74893fd6533a67549675e21573eb0e288d87158738f9c24594e",
InstallPath: "/destination",
},
wantFiles: map[string][]byte{"/destination": []byte("file-contents")},
},
"download with extract works": {
server: newHTTPBufconnServerWithBody(createTarGz([]byte("file-contents"), "/destination")),
component: &components.Component{
Url: serverURL,
Hash: "sha256:a52a1664ca0a6ec9790384e3d058852ab8b3a8f389a9113d150fdc6ab308d949",
InstallPath: "/prefix",
Extract: true,
},
wantFiles: map[string][]byte{"/prefix/destination": []byte("file-contents")},
},
"hash validation fails": {
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
component: &components.Component{
Url: serverURL,
Hash: "sha256:abc",
InstallPath: "/destination",
},
wantErr: true,
},
2023-12-08 18:19:24 +01:00
"hash is not mandatory": {
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
component: &components.Component{
Url: serverURL,
2023-12-08 18:19:24 +01:00
Hash: "",
InstallPath: "/destination",
},
wantFiles: map[string][]byte{"/destination": []byte("file-contents")},
},
"download fails": {
server: newHTTPBufconnServer(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) }),
component: &components.Component{
Url: serverURL,
Hash: "sha256:abc",
InstallPath: "/destination",
},
wantErr: true,
},
"dataurl works": {
server: newHTTPBufconnServer(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) }),
component: &components.Component{
Url: "data:text/plain,file-contents",
Hash: "",
InstallPath: "/destination",
},
wantFiles: map[string][]byte{"/destination": []byte("file-contents")},
},
"broken dataurl fails": {
server: newHTTPBufconnServer(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) }),
component: &components.Component{
Url: "data:file-contents",
Hash: "",
InstallPath: "/destination",
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
defer tc.server.Close()
hClient := http.Client{
Transport: &http.Transport{
DialContext: tc.server.DialContext,
Dial: tc.server.Dial,
DialTLSContext: tc.server.DialContext,
DialTLS: tc.server.Dial,
},
}
// This test was written before retriability was added to Install. It makes sense to test Install as if it wouldn't retry requests.
inst := OsInstaller{
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
hClient: &hClient,
clock: testclock.NewFakeClock(time.Time{}),
retriable: func(_ error) bool { return false },
}
err := inst.Install(context.Background(), tc.component)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
for path, wantContents := range tc.wantFiles {
contents, err := inst.fs.ReadFile(path)
assert.NoError(err)
assert.Equal(wantContents, contents)
}
})
}
}
// TestExtractArchive runs OsInstaller.extractArchive against hand-crafted
// tar.gz fixtures. It covers the happy paths, malformed input, read-only
// filesystems and malicious entries (path traversal, empty names).
func TestExtractArchive(t *testing.T) {
	singleFileArchive := createTarGz([]byte("file-contents"), "/destination")
	folderArchive := createTarGzWithFolder([]byte("file-contents"), "/folder/destination", nil)

	testCases := map[string]struct {
		source      string
		destination string
		contents    []byte
		readonly    bool
		wantErr     bool
		wantFiles   map[string][]byte
	}{
		"extract works": {
			source:      "in.tar.gz",
			destination: "/prefix",
			contents:    singleFileArchive,
			wantFiles: map[string][]byte{
				"/prefix/destination": []byte("file-contents"),
			},
		},
		"extract with folder works": {
			source:      "in.tar.gz",
			destination: "/prefix",
			contents:    folderArchive,
			wantFiles: map[string][]byte{
				"/prefix/folder/destination": []byte("file-contents"),
			},
		},
		"source missing": {
			source:      "in.tar.gz",
			destination: "/prefix",
			wantErr:     true,
		},
		"non-gzip file contents": {
			source:      "in.tar.gz",
			contents:    []byte("invalid bytes"),
			destination: "/prefix",
			wantErr:     true,
		},
		"non-tar file contents": {
			source:      "in.tar.gz",
			contents:    createGz([]byte("file-contents")),
			destination: "/prefix",
			wantErr:     true,
		},
		"mkdir prefix dir fails on RO fs": {
			source:      "in.tar.gz",
			contents:    singleFileArchive,
			destination: "/prefix",
			readonly:    true,
			wantErr:     true,
		},
		"mkdir tar dir fails on RO fs": {
			source:      "in.tar.gz",
			contents:    folderArchive,
			destination: "/",
			readonly:    true,
			wantErr:     true,
		},
		"writing tar file fails on RO fs": {
			source:      "in.tar.gz",
			contents:    singleFileArchive,
			destination: "/",
			readonly:    true,
			wantErr:     true,
		},
		"symlink can be detected (but is unsupported on memmapfs)": {
			source:      "in.tar.gz",
			contents:    createTarGzWithSymlink("source", "dest"),
			destination: "/prefix",
			wantErr:     true,
		},
		"unsupported tar header type is detected": {
			source:      "in.tar.gz",
			contents:    createTarGzWithFifo("/destination"),
			destination: "/prefix",
			wantErr:     true,
		},
		"path traversal is detected": {
			source:   "in.tar.gz",
			contents: createTarGz([]byte{}, "../destination"),
			wantErr:  true,
		},
		"path traversal in symlink is detected": {
			source:   "in.tar.gz",
			contents: createTarGzWithSymlink("/source", "../destination"),
			wantErr:  true,
		},
		"empty file name is detected": {
			source:   "in.tar.gz",
			contents: createTarGz([]byte{}, ""),
			wantErr:  true,
		},
		"empty folder name is detected": {
			source:   "in.tar.gz",
			contents: createTarGzWithFolder([]byte{}, "source", stringPtr("")),
			wantErr:  true,
		},
		"empty symlink source is detected": {
			source:   "in.tar.gz",
			contents: createTarGzWithSymlink("", "/target"),
			wantErr:  true,
		},
		"empty symlink target is detected": {
			source:   "in.tar.gz",
			contents: createTarGzWithSymlink("/source", ""),
			wantErr:  true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			require := require.New(t)
			assert := assert.New(t)

			var afs afero.Fs = afero.NewMemMapFs()
			// Only seed the archive when the case provides one; "source missing"
			// deliberately leaves the filesystem empty.
			if tc.source != "" && len(tc.contents) != 0 {
				require.NoError(afero.WriteFile(afs, tc.source, tc.contents, fs.ModePerm))
			}
			// The read-only wrapper is applied after seeding so the archive itself exists.
			if tc.readonly {
				afs = afero.NewReadOnlyFs(afs)
			}

			inst := OsInstaller{fs: &afero.Afero{Fs: afs}}
			err := inst.extractArchive(tc.source, tc.destination, fs.ModePerm)
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			for wantPath, wantContents := range tc.wantFiles {
				gotContents, readErr := inst.fs.ReadFile(wantPath)
				assert.NoError(readErr)
				assert.Equal(wantContents, gotContents)
			}
		})
	}
}
// TestRetryDownloadToTempDir drives retryDownloadToTempDir step by step. For
// every request, the test feeds the status code the server should answer with
// through a channel, then advances the fake clock by downloadInterval. Retries
// therefore happen without real sleeps.
func TestRetryDownloadToTempDir(t *testing.T) {
	testCases := map[string]struct {
		responses []int
		cancelCtx bool
		wantErr   bool
		wantFile  []byte
	}{
		"Succeed on third try": {
			responses: []int{500, 500, 200},
			wantFile:  []byte("file-content"),
		},
		"Cancel after second try": {
			responses: []int{500, 500},
			cancelCtx: true,
			wantErr:   true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Each request handled by the server consumes one status code from statusCh.
			statusCh := make(chan int)
			server := newHTTPBufconnServerWithState(statusCh, tc.wantFile)
			defer server.Close()

			hClient := http.Client{
				Transport: &http.Transport{
					DialContext:    server.DialContext,
					Dial:           server.Dial,
					DialTLSContext: server.DialContext,
					DialTLS:        server.Dial,
				},
			}

			// The fake clock lets the test decide when the next retry may fire.
			fakeClock := testclock.NewFakeClock(time.Now())
			inst := OsInstaller{
				fs:        &afero.Afero{Fs: afero.NewMemMapFs()},
				hClient:   &hClient,
				clock:     fakeClock,
				retriable: func(error) bool { return true },
			}

			// Cancelling ctx is how some cases abort the retry loop.
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			var (
				wg           sync.WaitGroup
				downloadPath string
				downloadErr  error
			)
			wg.Add(1)
			go func() {
				defer wg.Done()
				downloadPath, downloadErr = inst.retryDownloadToTempDir(ctx, "http://server/path")
			}()

			for _, status := range tc.responses {
				statusCh <- status
				fakeClock.Step(downloadInterval)
			}
			if tc.cancelCtx {
				cancel()
			}
			wg.Wait()

			if tc.wantErr {
				assert.Error(downloadErr)
				return
			}
			require.NoError(downloadErr)

			got, readErr := inst.fs.ReadFile(downloadPath)
			assert.NoError(readErr)
			assert.Equal(tc.wantFile, got)
		})
	}
}
// TestDownloadToTempDir covers a single, non-retried download into a temporary
// file. Cases: success, an HTTP error status, a read-only filesystem, and a
// response body shorter than its declared Content-Length.
func TestDownloadToTempDir(t *testing.T) {
	testCases := map[string]struct {
		server   httpBufconnServer
		readonly bool
		wantErr  bool
		wantFile []byte
	}{
		"download works": {
			server:   newHTTPBufconnServerWithBody([]byte("file-contents")),
			wantFile: []byte("file-contents"),
		},
		"download fails": {
			server:  newHTTPBufconnServer(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) }),
			wantErr: true,
		},
		"creating temp file fails on RO fs": {
			server:   newHTTPBufconnServerWithBody([]byte("file-contents")),
			readonly: true,
			wantErr:  true,
		},
		"content length mismatch": {
			server: newHTTPBufconnServer(func(writer http.ResponseWriter, _ *http.Request) {
				writer.Header().Set("Content-Length", "1337")
				writer.WriteHeader(200)
			}),
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)
			defer tc.server.Close()

			var afs afero.Fs = afero.NewMemMapFs()
			if tc.readonly {
				afs = afero.NewReadOnlyFs(afs)
			}

			inst := OsInstaller{
				fs: &afero.Afero{Fs: afs},
				hClient: &http.Client{
					Transport: &http.Transport{
						DialContext:    tc.server.DialContext,
						Dial:           tc.server.Dial,
						DialTLSContext: tc.server.DialContext,
						DialTLS:        tc.server.Dial,
					},
				},
			}

			tmpPath, err := inst.downloadToTempDir(context.Background(), "http://server/path")
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			got, err := inst.fs.ReadFile(tmpPath)
			assert.NoError(err)
			assert.Equal(tc.wantFile, got)
		})
	}
}
// TestCopy checks OsInstaller.copy. It verifies that the copy has the source's
// contents and the requested permissions. It also checks that copying fails for
// a missing source and on a read-only filesystem.
func TestCopy(t *testing.T) {
	contents := []byte("file-contents")
	existingFile := "/source"
	testCases := map[string]struct {
		oldname  string
		newname  string
		perm     fs.FileMode
		readonly bool
		wantErr  bool
	}{
		"copy works": {
			oldname: existingFile,
			newname: "/destination",
			perm:    fs.ModePerm,
		},
		"oldname does not exist": {
			oldname: "missing",
			newname: "/destination",
			wantErr: true,
		},
		// Previously this case set neither readonly nor wantErr, so it silently
		// duplicated "copy works" and never exercised the read-only path.
		"copy on readonly fs fails": {
			oldname:  existingFile,
			newname:  "/destination",
			perm:     fs.ModePerm,
			readonly: true,
			wantErr:  true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			afs := afero.NewMemMapFs()
			require.NoError(afero.WriteFile(afs, existingFile, contents, fs.ModePerm))
			if tc.readonly {
				afs = afero.NewReadOnlyFs(afs)
			}

			inst := OsInstaller{fs: &afero.Afero{Fs: afs}}
			err := inst.copy(tc.oldname, tc.newname, tc.perm)
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			// require (not assert): a nil file would panic in the reads below.
			oldfile, err := afs.Open(tc.oldname)
			require.NoError(err)
			defer func() { _ = oldfile.Close() }()
			newfile, err := afs.Open(tc.newname)
			require.NoError(err)
			defer func() { _ = newfile.Close() }()

			oldContents, err := io.ReadAll(oldfile)
			require.NoError(err)
			newContents, err := io.ReadAll(newfile)
			require.NoError(err)
			assert.Equal(oldContents, newContents)

			newStat, err := newfile.Stat()
			require.NoError(err)
			assert.Equal(tc.perm, newStat.Mode())
		})
	}
}
func TestVerifyTarPath(t *testing.T) {
testCases := map[string]struct {
path string
wantErr bool
}{
"valid relative path": {
path: "a/b/c",
},
"valid absolute path": {
path: "/a/b/c",
},
"valid path with dot": {
path: "/a/b/.d",
},
"valid path with dots": {
path: "/a/b/..d",
},
"single dot in path is allowed": {
path: ".",
},
"simple path traversal": {
path: "..",
wantErr: true,
},
"simple path traversal 2": {
path: "../",
wantErr: true,
},
"simple path traversal 3": {
path: "/..",
wantErr: true,
},
"simple path traversal 4": {
path: "/../",
wantErr: true,
},
"complex relative path traversal": {
path: "a/b/c/../../../../c/d/e",
wantErr: true,
},
"complex absolute path traversal": {
path: "/a/b/c/../../../../c/d/e",
wantErr: true,
},
"path traversal at the end": {
path: "a/..",
wantErr: true,
},
"path traversal at the end with trailing /": {
path: "a/../",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
err := verifyTarPath(tc.path)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.path, path.Clean(tc.path))
})
}
}
// httpBufconnServer is an httptest.Server whose listener is an in-memory
// bufconn.Listener. Its Dial and DialContext methods match the signatures of the
// corresponding http.Transport hooks. An http.Client can therefore reach the
// server without opening network sockets.
type httpBufconnServer struct {
	*httptest.Server
	*bufconn.Listener
}

// DialContext opens a new in-memory connection to the server. The network and
// address arguments are ignored.
func (s *httpBufconnServer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) {
	conn, err := s.Listener.DialContext(ctx)
	return conn, err
}

// Dial opens a new in-memory connection to the server without a context. The
// network and address arguments are ignored.
func (s *httpBufconnServer) Dial(_, _ string) (net.Conn, error) {
	conn, err := s.Listener.Dial()
	return conn, err
}

// Close shuts down the HTTP server first and then the in-memory listener.
func (s *httpBufconnServer) Close() {
	defer s.Listener.Close()
	s.Server.Close()
}
// newHTTPBufconnServer starts an httptest server that serves handlerFunc over
// an in-memory bufconn listener instead of a network socket. Clients must dial
// through the returned server's Dial/DialContext methods.
func newHTTPBufconnServer(handlerFunc http.HandlerFunc) httpBufconnServer {
	server := httptest.NewUnstartedServer(handlerFunc)
	// NewUnstartedServer has already opened a real loopback listener. Close it
	// before swapping in the in-memory one; otherwise every server leaks a socket.
	_ = server.Listener.Close()
	listener := bufconn.Listen(1024)
	server.Listener = listener
	server.Start()
	return httpBufconnServer{
		Server:   server,
		Listener: listener,
	}
}
// newHTTPBufconnServerWithBody starts a bufconn-backed server that answers
// every request with the given body and the implicit 200 status. A failed write
// panics inside the handler.
func newHTTPBufconnServerWithBody(body []byte) httpBufconnServer {
	writeBody := func(w http.ResponseWriter, _ *http.Request) {
		_, err := w.Write(body)
		if err != nil {
			panic(err)
		}
	}
	return newHTTPBufconnServer(writeBody)
}
// newHTTPBufconnServerWithState starts a bufconn-backed server whose responses
// are scripted by the test. Each request waits for one value on state:
//   - 500 writes status 500,
//   - 200 writes body with the implicit status 200,
//   - any other value writes status 402.
//
// If the client abandons the request first (for example because its context
// was canceled), the handler returns without reading from state.
func newHTTPBufconnServerWithState(state chan int, body []byte) httpBufconnServer {
	return newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) {
		var status int
		select {
		case status = <-state:
		case <-r.Context().Done():
			// Without this case the handler could block forever on a request
			// the test never feeds. httptest.Server.Close waits for outstanding
			// requests, so it would then hang.
			return
		}
		switch status {
		case 500:
			w.WriteHeader(500)
		case 200:
			if _, err := w.Write(body); err != nil {
				panic(err)
			}
		default:
			w.WriteHeader(402)
		}
	})
}
// createTarGz returns a tar.gz archive with a single regular-file entry. The
// entry is named name, holds contents and has mode fs.ModePerm. It panics if
// writing the archive fails.
func createTarGz(contents []byte, name string) []byte {
	w := newTarGzWriter()
	defer func() { _ = w.Close() }()

	hdr := &tar.Header{
		Typeflag: tar.TypeReg,
		Name:     name,
		Size:     int64(len(contents)),
		Mode:     int64(fs.ModePerm),
	}
	if err := w.writeHeader(hdr); err != nil {
		panic(err)
	}
	if _, err := w.writeTar(contents); err != nil {
		panic(err)
	}
	return w.Bytes()
}
// createTarGzWithFolder returns a tar.gz archive with two entries: a directory,
// then a regular file at filePath holding contents.
//
// The directory entry is named path.Dir(filePath). When dirnameOverride is
// non-nil its value is used instead, which allows building malformed archives
// (e.g. an empty directory name). It panics if writing the archive fails.
func createTarGzWithFolder(contents []byte, filePath string, dirnameOverride *string) []byte {
	w := newTarGzWriter()
	defer func() { _ = w.Close() }()

	dirName := path.Dir(filePath)
	if dirnameOverride != nil {
		dirName = *dirnameOverride
	}

	headers := []*tar.Header{
		{
			Typeflag: tar.TypeDir,
			Name:     dirName,
			Mode:     int64(fs.ModePerm),
		},
		{
			Typeflag: tar.TypeReg,
			Name:     filePath,
			Size:     int64(len(contents)),
			Mode:     int64(fs.ModePerm),
		},
	}
	for _, hdr := range headers {
		if err := w.writeHeader(hdr); err != nil {
			panic(err)
		}
	}
	// The body belongs to the last header written, i.e. the regular file.
	if _, err := w.writeTar(contents); err != nil {
		panic(err)
	}
	return w.Bytes()
}
// createTarGzWithSymlink returns a tar.gz archive holding a single symlink
// entry, with header Name set to name and Linkname set to linkname. Note that
// the first argument becomes the entry's own path, not the link target. It
// panics if writing the archive fails.
func createTarGzWithSymlink(name, linkname string) []byte {
	w := newTarGzWriter()
	defer func() { _ = w.Close() }()

	err := w.writeHeader(&tar.Header{
		Typeflag: tar.TypeSymlink,
		Name:     name,
		Linkname: linkname,
		Mode:     int64(fs.ModePerm),
	})
	if err != nil {
		panic(err)
	}
	return w.Bytes()
}
// createTarGzWithFifo returns a tar.gz archive holding a single named-pipe
// (FIFO) entry called name. It is used to check that unsupported entry types
// are rejected. It panics if writing the archive fails.
func createTarGzWithFifo(name string) []byte {
	w := newTarGzWriter()
	defer func() { _ = w.Close() }()

	fifo := &tar.Header{
		Typeflag: tar.TypeFifo,
		Name:     name,
		Mode:     int64(fs.ModePerm),
	}
	err := w.writeHeader(fifo)
	if err != nil {
		panic(err)
	}
	return w.Bytes()
}
// createGz returns contents gzip-compressed but NOT wrapped in a tar archive.
// Tests use it to check that non-tar payloads are rejected. It panics if
// writing fails.
func createGz(contents []byte) []byte {
	w := newTarGzWriter()
	defer func() { _ = w.Close() }()

	_, err := w.writeGz(contents)
	if err != nil {
		panic(err)
	}
	return w.Bytes()
}
// tarGzWriter builds an in-memory tar.gz stream for test fixtures. Data flows
// tarWriter -> gzWriter -> bufWriter -> buf.
type tarGzWriter struct {
	tarWriter *tar.Writer
	gzWriter  *gzip.Writer
	bufWriter *bufio.Writer
	buf       *bytes.Buffer
}

// newTarGzWriter wires up the writer chain on top of a fresh, empty buffer.
func newTarGzWriter() *tarGzWriter {
	w := &tarGzWriter{buf: new(bytes.Buffer)}
	w.bufWriter = bufio.NewWriter(w.buf)
	w.gzWriter = gzip.NewWriter(w.bufWriter)
	w.tarWriter = tar.NewWriter(w.gzWriter)
	return w
}

// writeHeader starts a new tar entry described by hdr.
func (w *tarGzWriter) writeHeader(hdr *tar.Header) error {
	return w.tarWriter.WriteHeader(hdr)
}

// writeTar appends b to the body of the current tar entry.
func (w *tarGzWriter) writeTar(b []byte) (int, error) {
	n, err := w.tarWriter.Write(b)
	return n, err
}

// writeGz writes b directly into the gzip stream, bypassing the tar layer.
// This is how gzip data that is not a tar archive gets built.
func (w *tarGzWriter) writeGz(b []byte) (int, error) {
	n, err := w.gzWriter.Write(b)
	return n, err
}
// Bytes pushes all pending data down the writer chain (tar -> gzip -> bufio ->
// buffer) and returns the buffered tar.gz bytes. It closes the gzip writer, so
// call it only once, after the last entry has been written. Errors from the
// individual flush/close calls are discarded.
// NOTE(review): the tar writer is only flushed here, not closed. The tar
// end-of-archive footer (written by tar.Writer.Close) is therefore not part of
// the returned bytes; confirm readers tolerate its absence.
func (w *tarGzWriter) Bytes() []byte {
	_ = w.tarWriter.Flush()
	_ = w.gzWriter.Flush()
	_ = w.gzWriter.Close() // required to ensure clean EOF in gz reader
	_ = w.bufWriter.Flush()
	return w.buf.Bytes()
}

// Close closes the tar writer, then the gzip writer, and returns their errors
// joined with errors.Join (nil if both succeed). The create* helpers call it in
// a defer after Bytes; by then the gzip writer has already been closed by Bytes.
func (w *tarGzWriter) Close() (retErr error) {
	retErr = errors.Join(retErr, w.tarWriter.Close())
	retErr = errors.Join(retErr, w.gzWriter.Close())
	return retErr
}
// stringPtr returns a pointer to a new copy of s. It is handy for optional
// string arguments in table-driven tests.
func stringPtr(s string) *string {
	ptr := new(string)
	*ptr = s
	return ptr
}