mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-12-10 05:31:11 -05:00
AB#2544 add upgrade agent for automatic version updates (#745)
This commit is contained in:
parent
cff735b0ee
commit
9859b30c4d
13 changed files with 715 additions and 24 deletions
|
|
@ -1,299 +0,0 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package k8sapi
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/retry"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
"github.com/spf13/afero"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
const (
|
||||
// determines the period after which retryDownloadToTempDir will retry a download.
|
||||
downloadInterval = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
// osInstaller installs binary components of supported kubernetes versions.
|
||||
type osInstaller struct {
|
||||
fs *afero.Afero
|
||||
hClient httpClient
|
||||
// clock is needed for testing purposes
|
||||
clock clock.WithTicker
|
||||
// retriable is the function used to check if an error is retriable. Needed for testing.
|
||||
retriable func(error) bool
|
||||
}
|
||||
|
||||
// newOSInstaller creates a new osInstaller.
|
||||
func newOSInstaller() *osInstaller {
|
||||
return &osInstaller{
|
||||
fs: &afero.Afero{Fs: afero.NewOsFs()},
|
||||
hClient: &http.Client{},
|
||||
clock: clock.RealClock{},
|
||||
retriable: isRetriable,
|
||||
}
|
||||
}
|
||||
|
||||
// Install downloads a resource from a URL, applies any given text transformations and extracts the resulting file if required.
|
||||
// The resulting file(s) are copied to the destination. It also verifies the sha256 hash of the downloaded file.
|
||||
func (i *osInstaller) Install(ctx context.Context, kubernetesComponent versions.ComponentVersion) error {
|
||||
tempPath, err := i.retryDownloadToTempDir(ctx, kubernetesComponent.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := i.fs.OpenFile(tempPath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file %q: %w", tempPath, err)
|
||||
}
|
||||
sha := sha256.New()
|
||||
if _, err := io.Copy(sha, file); err != nil {
|
||||
return fmt.Errorf("reading file %q: %w", tempPath, err)
|
||||
}
|
||||
calculatedHash := fmt.Sprintf("sha256:%x", sha.Sum(nil))
|
||||
if calculatedHash != kubernetesComponent.Hash {
|
||||
return fmt.Errorf("hash of file %q %s does not match expected hash %s", tempPath, calculatedHash, kubernetesComponent.Hash)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = i.fs.Remove(tempPath)
|
||||
}()
|
||||
if kubernetesComponent.Extract {
|
||||
err = i.extractArchive(tempPath, kubernetesComponent.InstallPath, executablePerm)
|
||||
} else {
|
||||
err = i.copy(tempPath, kubernetesComponent.InstallPath, executablePerm)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("installing from %q: copying to destination %q: %w", kubernetesComponent.URL, kubernetesComponent.InstallPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractArchive extracts tar gz archives to a prefixed destination.
|
||||
func (i *osInstaller) extractArchive(archivePath, prefix string, perm fs.FileMode) error {
|
||||
archiveFile, err := i.fs.Open(archivePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening archive file: %w", err)
|
||||
}
|
||||
defer archiveFile.Close()
|
||||
gzReader, err := gzip.NewReader(archiveFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading archive file as gzip: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
if err := i.fs.MkdirAll(prefix, fs.ModePerm); err != nil {
|
||||
return fmt.Errorf("creating prefix folder: %w", err)
|
||||
}
|
||||
tarReader := tar.NewReader(gzReader)
|
||||
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing tar header: %w", err)
|
||||
}
|
||||
if err := verifyTarPath(header.Name); err != nil {
|
||||
return fmt.Errorf("verifying tar path %q: %w", header.Name, err)
|
||||
}
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if len(header.Name) == 0 {
|
||||
return errors.New("cannot create dir for empty path")
|
||||
}
|
||||
prefixedPath := path.Join(prefix, header.Name)
|
||||
if err := i.fs.Mkdir(prefixedPath, fs.FileMode(header.Mode)&perm); err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return fmt.Errorf("creating folder %q: %w", prefixedPath, err)
|
||||
}
|
||||
case tar.TypeReg:
|
||||
if len(header.Name) == 0 {
|
||||
return errors.New("cannot create file for empty path")
|
||||
}
|
||||
prefixedPath := path.Join(prefix, header.Name)
|
||||
out, err := i.fs.OpenFile(prefixedPath, os.O_WRONLY|os.O_CREATE, fs.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating file %q for writing: %w", prefixedPath, err)
|
||||
}
|
||||
defer out.Close()
|
||||
if _, err := io.Copy(out, tarReader); err != nil {
|
||||
return fmt.Errorf("writing extracted file contents: %w", err)
|
||||
}
|
||||
case tar.TypeSymlink:
|
||||
if err := verifyTarPath(header.Linkname); err != nil {
|
||||
return fmt.Errorf("invalid tar path %q: %w", header.Linkname, err)
|
||||
}
|
||||
if len(header.Name) == 0 {
|
||||
return errors.New("cannot symlink file for empty oldname")
|
||||
}
|
||||
if len(header.Linkname) == 0 {
|
||||
return errors.New("cannot symlink file for empty newname")
|
||||
}
|
||||
if symlinker, ok := i.fs.Fs.(afero.Symlinker); ok {
|
||||
if err := symlinker.SymlinkIfPossible(path.Join(prefix, header.Name), path.Join(prefix, header.Linkname)); err != nil {
|
||||
return fmt.Errorf("creating symlink: %w", err)
|
||||
}
|
||||
} else {
|
||||
return errors.New("fs does not support symlinks")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported tar record: %v", header.Typeflag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *osInstaller) retryDownloadToTempDir(ctx context.Context, url string) (fileName string, someError error) {
|
||||
doer := downloadDoer{
|
||||
url: url,
|
||||
downloader: i,
|
||||
}
|
||||
|
||||
// Retries are canceled as soon as the context is canceled.
|
||||
// We need to call NewIntervalRetrier with a clock argument so that the tests can fake the clock by changing the osInstaller clock.
|
||||
retrier := retry.NewIntervalRetrier(&doer, downloadInterval, i.retriable, i.clock)
|
||||
if err := retrier.Do(ctx); err != nil {
|
||||
return "", fmt.Errorf("retrying downloadToTempDir: %w", err)
|
||||
}
|
||||
|
||||
return doer.path, nil
|
||||
}
|
||||
|
||||
// downloadToTempDir downloads a file to a temporary location.
|
||||
func (i *osInstaller) downloadToTempDir(ctx context.Context, url string) (fileName string, retErr error) {
|
||||
out, err := afero.TempFile(i.fs, "", "")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating destination temp file: %w", err)
|
||||
}
|
||||
// Remove the created file if an error occurs.
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
_ = i.fs.Remove(fileName)
|
||||
retErr = &retriableError{err: retErr} // mark any error after this point as retriable
|
||||
}
|
||||
}()
|
||||
defer out.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request to download %q: %w", url, err)
|
||||
}
|
||||
resp, err := i.hClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request to download %q: %w", url, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("request to download %q failed with status code: %v", url, resp.Status)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if _, err = io.Copy(out, resp.Body); err != nil {
|
||||
return "", fmt.Errorf("downloading %q: %w", url, err)
|
||||
}
|
||||
return out.Name(), nil
|
||||
}
|
||||
|
||||
// copy copies a file from oldname to newname.
|
||||
func (i *osInstaller) copy(oldname, newname string, perm fs.FileMode) (err error) {
|
||||
old, openOldErr := i.fs.OpenFile(oldname, os.O_RDONLY, fs.ModePerm)
|
||||
if openOldErr != nil {
|
||||
return fmt.Errorf("copying %q to %q: cannot open source file for reading: %w", oldname, newname, openOldErr)
|
||||
}
|
||||
defer func() { _ = old.Close() }()
|
||||
// create destination path if not exists
|
||||
if err := i.fs.MkdirAll(path.Dir(newname), fs.ModePerm); err != nil {
|
||||
return fmt.Errorf("copying %q to %q: unable to create destination folder: %w", oldname, newname, err)
|
||||
}
|
||||
new, openNewErr := i.fs.OpenFile(newname, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, perm)
|
||||
if openNewErr != nil {
|
||||
return fmt.Errorf("copying %q to %q: cannot open destination file for writing: %w", oldname, newname, openNewErr)
|
||||
}
|
||||
defer func() {
|
||||
_ = new.Close()
|
||||
if err != nil {
|
||||
_ = i.fs.Remove(newname)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(new, old); err != nil {
|
||||
return fmt.Errorf("copying %q to %q: copying file contents: %w", oldname, newname, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type downloadDoer struct {
|
||||
url string
|
||||
downloader downloader
|
||||
path string
|
||||
}
|
||||
|
||||
type downloader interface {
|
||||
downloadToTempDir(ctx context.Context, url string) (string, error)
|
||||
}
|
||||
|
||||
func (d *downloadDoer) Do(ctx context.Context) error {
|
||||
path, err := d.downloader.downloadToTempDir(ctx, d.url)
|
||||
d.path = path
|
||||
return err
|
||||
}
|
||||
|
||||
// retriableError is an error that can be retried.
|
||||
type retriableError struct{ err error }
|
||||
|
||||
func (e *retriableError) Error() string {
|
||||
return fmt.Sprintf("retriable error: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *retriableError) Unwrap() error { return e.err }
|
||||
|
||||
// isRetriable returns true if the action resulting in this error can be retried.
|
||||
func isRetriable(err error) bool {
|
||||
retriableError := &retriableError{}
|
||||
return errors.As(err, &retriableError)
|
||||
}
|
||||
|
||||
// verifyTarPath checks if a tar path is valid (must not contain ".." as path element).
|
||||
func verifyTarPath(pat string) error {
|
||||
n := len(pat)
|
||||
r := 0
|
||||
for r < n {
|
||||
switch {
|
||||
case os.IsPathSeparator(pat[r]):
|
||||
// empty path element
|
||||
r++
|
||||
case pat[r] == '.' && (r+1 == n || os.IsPathSeparator(pat[r+1])):
|
||||
// . element
|
||||
r++
|
||||
case pat[r] == '.' && pat[r+1] == '.' && (r+2 == n || os.IsPathSeparator(pat[r+2])):
|
||||
// .. element
|
||||
return errors.New("path contains \"..\"")
|
||||
default:
|
||||
// skip to next path element
|
||||
for r < n && !os.IsPathSeparator(pat[r]) {
|
||||
r++
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type httpClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
|
@ -1,750 +0,0 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package k8sapi
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/multierr"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
testclock "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestInstall(t *testing.T) {
|
||||
serverURL := "http://server/path"
|
||||
testCases := map[string]struct {
|
||||
server httpBufconnServer
|
||||
component versions.ComponentVersion
|
||||
hash string
|
||||
destination string
|
||||
extract bool
|
||||
wantErr bool
|
||||
wantFiles map[string][]byte
|
||||
}{
|
||||
"download works": {
|
||||
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
|
||||
component: versions.ComponentVersion{
|
||||
URL: serverURL,
|
||||
Hash: "sha256:f03779b36bece74893fd6533a67549675e21573eb0e288d87158738f9c24594e",
|
||||
InstallPath: "/destination",
|
||||
},
|
||||
wantFiles: map[string][]byte{"/destination": []byte("file-contents")},
|
||||
},
|
||||
"download with extract works": {
|
||||
server: newHTTPBufconnServerWithBody(createTarGz([]byte("file-contents"), "/destination")),
|
||||
component: versions.ComponentVersion{
|
||||
URL: serverURL,
|
||||
Hash: "sha256:a52a1664ca0a6ec9790384e3d058852ab8b3a8f389a9113d150fdc6ab308d949",
|
||||
InstallPath: "/prefix",
|
||||
Extract: true,
|
||||
},
|
||||
wantFiles: map[string][]byte{"/prefix/destination": []byte("file-contents")},
|
||||
},
|
||||
"hash validation fails": {
|
||||
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
|
||||
component: versions.ComponentVersion{
|
||||
URL: serverURL,
|
||||
Hash: "sha256:abc",
|
||||
InstallPath: "/destination",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"download fails": {
|
||||
server: newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) }),
|
||||
component: versions.ComponentVersion{
|
||||
URL: serverURL,
|
||||
Hash: "sha256:abc",
|
||||
InstallPath: "/destination",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
defer tc.server.Close()
|
||||
|
||||
hClient := http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: tc.server.DialContext,
|
||||
Dial: tc.server.Dial,
|
||||
DialTLSContext: tc.server.DialContext,
|
||||
DialTLS: tc.server.Dial,
|
||||
},
|
||||
}
|
||||
|
||||
// This test was written before retriability was added to Install. It makes sense to test Install as if it wouldn't retry requests.
|
||||
inst := osInstaller{
|
||||
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
|
||||
hClient: &hClient,
|
||||
clock: testclock.NewFakeClock(time.Time{}),
|
||||
retriable: func(err error) bool { return false },
|
||||
}
|
||||
|
||||
err := inst.Install(context.Background(), tc.component)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
for path, wantContents := range tc.wantFiles {
|
||||
contents, err := inst.fs.ReadFile(path)
|
||||
assert.NoError(err)
|
||||
assert.Equal(wantContents, contents)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractArchive(t *testing.T) {
|
||||
tarGzTestFile := createTarGz([]byte("file-contents"), "/destination")
|
||||
tarGzTestWithFolder := createTarGzWithFolder([]byte("file-contents"), "/folder/destination", nil)
|
||||
|
||||
testCases := map[string]struct {
|
||||
source string
|
||||
destination string
|
||||
contents []byte
|
||||
readonly bool
|
||||
wantErr bool
|
||||
wantFiles map[string][]byte
|
||||
}{
|
||||
"extract works": {
|
||||
source: "in.tar.gz",
|
||||
destination: "/prefix",
|
||||
contents: tarGzTestFile,
|
||||
wantFiles: map[string][]byte{
|
||||
"/prefix/destination": []byte("file-contents"),
|
||||
},
|
||||
},
|
||||
"extract with folder works": {
|
||||
source: "in.tar.gz",
|
||||
destination: "/prefix",
|
||||
contents: tarGzTestWithFolder,
|
||||
wantFiles: map[string][]byte{
|
||||
"/prefix/folder/destination": []byte("file-contents"),
|
||||
},
|
||||
},
|
||||
"source missing": {
|
||||
source: "in.tar.gz",
|
||||
destination: "/prefix",
|
||||
wantErr: true,
|
||||
},
|
||||
"non-gzip file contents": {
|
||||
source: "in.tar.gz",
|
||||
contents: []byte("invalid bytes"),
|
||||
destination: "/prefix",
|
||||
wantErr: true,
|
||||
},
|
||||
"non-tar file contents": {
|
||||
source: "in.tar.gz",
|
||||
contents: createGz([]byte("file-contents")),
|
||||
destination: "/prefix",
|
||||
wantErr: true,
|
||||
},
|
||||
"mkdir prefix dir fails on RO fs": {
|
||||
source: "in.tar.gz",
|
||||
contents: tarGzTestFile,
|
||||
destination: "/prefix",
|
||||
readonly: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"mkdir tar dir fails on RO fs": {
|
||||
source: "in.tar.gz",
|
||||
contents: tarGzTestWithFolder,
|
||||
destination: "/",
|
||||
readonly: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"writing tar file fails on RO fs": {
|
||||
source: "in.tar.gz",
|
||||
contents: tarGzTestFile,
|
||||
destination: "/",
|
||||
readonly: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"symlink can be detected (but is unsupported on memmapfs)": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGzWithSymlink("source", "dest"),
|
||||
destination: "/prefix",
|
||||
wantErr: true,
|
||||
},
|
||||
"unsupported tar header type is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGzWithFifo("/destination"),
|
||||
destination: "/prefix",
|
||||
wantErr: true,
|
||||
},
|
||||
"path traversal is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGz([]byte{}, "../destination"),
|
||||
wantErr: true,
|
||||
},
|
||||
"path traversal in symlink is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGzWithSymlink("/source", "../destination"),
|
||||
wantErr: true,
|
||||
},
|
||||
"empty file name is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGz([]byte{}, ""),
|
||||
wantErr: true,
|
||||
},
|
||||
"empty folder name is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGzWithFolder([]byte{}, "source", stringPtr("")),
|
||||
wantErr: true,
|
||||
},
|
||||
"empty symlink source is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGzWithSymlink("", "/target"),
|
||||
wantErr: true,
|
||||
},
|
||||
"empty symlink target is detected": {
|
||||
source: "in.tar.gz",
|
||||
contents: createTarGzWithSymlink("/source", ""),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
afs := afero.NewMemMapFs()
|
||||
if len(tc.source) > 0 && len(tc.contents) > 0 {
|
||||
require.NoError(afero.WriteFile(afs, tc.source, tc.contents, fs.ModePerm))
|
||||
}
|
||||
|
||||
if tc.readonly {
|
||||
afs = afero.NewReadOnlyFs(afs)
|
||||
}
|
||||
|
||||
inst := osInstaller{
|
||||
fs: &afero.Afero{Fs: afs},
|
||||
}
|
||||
err := inst.extractArchive(tc.source, tc.destination, fs.ModePerm)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
for path, wantContents := range tc.wantFiles {
|
||||
contents, err := inst.fs.ReadFile(path)
|
||||
assert.NoError(err)
|
||||
assert.Equal(wantContents, contents)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryDownloadToTempDir(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
responses []int
|
||||
cancelCtx bool
|
||||
wantErr bool
|
||||
wantFile []byte
|
||||
}{
|
||||
"Succeed on third try": {
|
||||
responses: []int{500, 500, 200},
|
||||
wantFile: []byte("file-content"),
|
||||
},
|
||||
"Cancel after second try": {
|
||||
responses: []int{500, 500},
|
||||
cancelCtx: true,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
// control the server's responses through stateCh
|
||||
stateCh := make(chan int)
|
||||
server := newHTTPBufconnServerWithState(stateCh, tc.wantFile)
|
||||
defer server.Close()
|
||||
|
||||
hClient := http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: server.DialContext,
|
||||
Dial: server.Dial,
|
||||
DialTLSContext: server.DialContext,
|
||||
DialTLS: server.Dial,
|
||||
},
|
||||
}
|
||||
|
||||
afs := afero.NewMemMapFs()
|
||||
|
||||
// control download retries through FakeClock clock
|
||||
clock := testclock.NewFakeClock(time.Now())
|
||||
inst := osInstaller{
|
||||
fs: &afero.Afero{Fs: afs},
|
||||
hClient: &hClient,
|
||||
clock: clock,
|
||||
retriable: func(error) bool { return true },
|
||||
}
|
||||
|
||||
// abort retryDownloadToTempDir in some test cases by using the context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var downloadErr error
|
||||
var path string
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
path, downloadErr = inst.retryDownloadToTempDir(ctx, "http://server/path")
|
||||
}()
|
||||
|
||||
// control the server's responses through stateCh.
|
||||
for _, resp := range tc.responses {
|
||||
stateCh <- resp
|
||||
clock.Step(downloadInterval)
|
||||
}
|
||||
if tc.cancelCtx {
|
||||
cancel()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(downloadErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(downloadErr)
|
||||
content, err := inst.fs.ReadFile(path)
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantFile, content)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToTempDir(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
server httpBufconnServer
|
||||
readonly bool
|
||||
wantErr bool
|
||||
wantFile []byte
|
||||
}{
|
||||
"download works": {
|
||||
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
|
||||
wantFile: []byte("file-contents"),
|
||||
},
|
||||
"download fails": {
|
||||
server: newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) }),
|
||||
wantErr: true,
|
||||
},
|
||||
"creating temp file fails on RO fs": {
|
||||
server: newHTTPBufconnServerWithBody([]byte("file-contents")),
|
||||
readonly: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"content length mismatch": {
|
||||
server: newHTTPBufconnServer(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set("Content-Length", "1337")
|
||||
writer.WriteHeader(200)
|
||||
}),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
defer tc.server.Close()
|
||||
|
||||
hClient := http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: tc.server.DialContext,
|
||||
Dial: tc.server.Dial,
|
||||
DialTLSContext: tc.server.DialContext,
|
||||
DialTLS: tc.server.Dial,
|
||||
},
|
||||
}
|
||||
|
||||
afs := afero.NewMemMapFs()
|
||||
if tc.readonly {
|
||||
afs = afero.NewReadOnlyFs(afs)
|
||||
}
|
||||
inst := osInstaller{
|
||||
fs: &afero.Afero{Fs: afs},
|
||||
hClient: &hClient,
|
||||
}
|
||||
path, err := inst.downloadToTempDir(context.Background(), "http://server/path")
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
contents, err := inst.fs.ReadFile(path)
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantFile, contents)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
contents := []byte("file-contents")
|
||||
existingFile := "/source"
|
||||
testCases := map[string]struct {
|
||||
oldname string
|
||||
newname string
|
||||
perm fs.FileMode
|
||||
readonly bool
|
||||
wantErr bool
|
||||
}{
|
||||
"copy works": {
|
||||
oldname: existingFile,
|
||||
newname: "/destination",
|
||||
perm: fs.ModePerm,
|
||||
},
|
||||
"oldname does not exist": {
|
||||
oldname: "missing",
|
||||
newname: "/destination",
|
||||
wantErr: true,
|
||||
},
|
||||
"copy on readonly fs fails": {
|
||||
oldname: existingFile,
|
||||
newname: "/destination",
|
||||
perm: fs.ModePerm,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
afs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(afs, existingFile, contents, fs.ModePerm))
|
||||
|
||||
if tc.readonly {
|
||||
afs = afero.NewReadOnlyFs(afs)
|
||||
}
|
||||
|
||||
inst := osInstaller{fs: &afero.Afero{Fs: afs}}
|
||||
err := inst.copy(tc.oldname, tc.newname, tc.perm)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
|
||||
oldfile, err := afs.Open(tc.oldname)
|
||||
assert.NoError(err)
|
||||
newfile, err := afs.Open(tc.newname)
|
||||
assert.NoError(err)
|
||||
|
||||
oldContents, _ := io.ReadAll(oldfile)
|
||||
newContents, _ := io.ReadAll(newfile)
|
||||
assert.Equal(oldContents, newContents)
|
||||
|
||||
newStat, _ := newfile.Stat()
|
||||
assert.Equal(tc.perm, newStat.Mode())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyTarPath(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
path string
|
||||
wantErr bool
|
||||
}{
|
||||
"valid relative path": {
|
||||
path: "a/b/c",
|
||||
},
|
||||
"valid absolute path": {
|
||||
path: "/a/b/c",
|
||||
},
|
||||
"valid path with dot": {
|
||||
path: "/a/b/.d",
|
||||
},
|
||||
"valid path with dots": {
|
||||
path: "/a/b/..d",
|
||||
},
|
||||
"single dot in path is allowed": {
|
||||
path: ".",
|
||||
},
|
||||
"simple path traversal": {
|
||||
path: "..",
|
||||
wantErr: true,
|
||||
},
|
||||
"simple path traversal 2": {
|
||||
path: "../",
|
||||
wantErr: true,
|
||||
},
|
||||
"simple path traversal 3": {
|
||||
path: "/..",
|
||||
wantErr: true,
|
||||
},
|
||||
"simple path traversal 4": {
|
||||
path: "/../",
|
||||
wantErr: true,
|
||||
},
|
||||
"complex relative path traversal": {
|
||||
path: "a/b/c/../../../../c/d/e",
|
||||
wantErr: true,
|
||||
},
|
||||
"complex absolute path traversal": {
|
||||
path: "/a/b/c/../../../../c/d/e",
|
||||
wantErr: true,
|
||||
},
|
||||
"path traversal at the end": {
|
||||
path: "a/..",
|
||||
wantErr: true,
|
||||
},
|
||||
"path traversal at the end with trailing /": {
|
||||
path: "a/../",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
err := verifyTarPath(tc.path)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
assert.Equal(tc.path, path.Clean(tc.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type httpBufconnServer struct {
|
||||
*httptest.Server
|
||||
*bufconn.Listener
|
||||
}
|
||||
|
||||
func (s *httpBufconnServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return s.Listener.DialContext(ctx)
|
||||
}
|
||||
|
||||
func (s *httpBufconnServer) Dial(network, addr string) (net.Conn, error) {
|
||||
return s.Listener.Dial()
|
||||
}
|
||||
|
||||
func (s *httpBufconnServer) Close() {
|
||||
s.Server.Close()
|
||||
s.Listener.Close()
|
||||
}
|
||||
|
||||
func newHTTPBufconnServer(handlerFunc http.HandlerFunc) httpBufconnServer {
|
||||
server := httptest.NewUnstartedServer(handlerFunc)
|
||||
listener := bufconn.Listen(1024)
|
||||
server.Listener = listener
|
||||
server.Start()
|
||||
return httpBufconnServer{
|
||||
Server: server,
|
||||
Listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPBufconnServerWithBody(body []byte) httpBufconnServer {
|
||||
return newHTTPBufconnServer(func(writer http.ResponseWriter, request *http.Request) {
|
||||
if _, err := writer.Write(body); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newHTTPBufconnServerWithState(state chan int, body []byte) httpBufconnServer {
|
||||
return newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch <-state {
|
||||
case 500:
|
||||
w.WriteHeader(500)
|
||||
case 200:
|
||||
if _, err := w.Write(body); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
default:
|
||||
w.WriteHeader(402)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createTarGz(contents []byte, path string) []byte {
|
||||
tgzWriter := newTarGzWriter()
|
||||
defer func() { _ = tgzWriter.Close() }()
|
||||
|
||||
if err := tgzWriter.writeHeader(&tar.Header{
|
||||
Typeflag: tar.TypeReg,
|
||||
Name: path,
|
||||
Size: int64(len(contents)),
|
||||
Mode: int64(fs.ModePerm),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if _, err := tgzWriter.writeTar(contents); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tgzWriter.Bytes()
|
||||
}
|
||||
|
||||
func createTarGzWithFolder(contents []byte, pat string, dirnameOverride *string) []byte {
|
||||
tgzWriter := newTarGzWriter()
|
||||
defer func() { _ = tgzWriter.Close() }()
|
||||
|
||||
dir := path.Dir(pat)
|
||||
if dirnameOverride != nil {
|
||||
dir = *dirnameOverride
|
||||
}
|
||||
|
||||
if err := tgzWriter.writeHeader(&tar.Header{
|
||||
Typeflag: tar.TypeDir,
|
||||
Name: dir,
|
||||
Mode: int64(fs.ModePerm),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := tgzWriter.writeHeader(&tar.Header{
|
||||
Typeflag: tar.TypeReg,
|
||||
Name: pat,
|
||||
Size: int64(len(contents)),
|
||||
Mode: int64(fs.ModePerm),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if _, err := tgzWriter.writeTar(contents); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tgzWriter.Bytes()
|
||||
}
|
||||
|
||||
func createTarGzWithSymlink(oldname, newname string) []byte {
|
||||
tgzWriter := newTarGzWriter()
|
||||
defer func() { _ = tgzWriter.Close() }()
|
||||
|
||||
if err := tgzWriter.writeHeader(&tar.Header{
|
||||
Typeflag: tar.TypeSymlink,
|
||||
Name: oldname,
|
||||
Linkname: newname,
|
||||
Mode: int64(fs.ModePerm),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tgzWriter.Bytes()
|
||||
}
|
||||
|
||||
func createTarGzWithFifo(name string) []byte {
|
||||
tgzWriter := newTarGzWriter()
|
||||
defer func() { _ = tgzWriter.Close() }()
|
||||
|
||||
if err := tgzWriter.writeHeader(&tar.Header{
|
||||
Typeflag: tar.TypeFifo,
|
||||
Name: name,
|
||||
Mode: int64(fs.ModePerm),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tgzWriter.Bytes()
|
||||
}
|
||||
|
||||
func createGz(contents []byte) []byte {
|
||||
tgzWriter := newTarGzWriter()
|
||||
defer func() { _ = tgzWriter.Close() }()
|
||||
|
||||
if _, err := tgzWriter.writeGz(contents); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tgzWriter.Bytes()
|
||||
}
|
||||
|
||||
type tarGzWriter struct {
|
||||
buf *bytes.Buffer
|
||||
bufWriter *bufio.Writer
|
||||
gzWriter *gzip.Writer
|
||||
tarWriter *tar.Writer
|
||||
}
|
||||
|
||||
func newTarGzWriter() *tarGzWriter {
|
||||
var buf bytes.Buffer
|
||||
bufWriter := bufio.NewWriter(&buf)
|
||||
gzipWriter := gzip.NewWriter(bufWriter)
|
||||
tarWriter := tar.NewWriter(gzipWriter)
|
||||
|
||||
return &tarGzWriter{
|
||||
buf: &buf,
|
||||
bufWriter: bufWriter,
|
||||
gzWriter: gzipWriter,
|
||||
tarWriter: tarWriter,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *tarGzWriter) writeHeader(hdr *tar.Header) error {
|
||||
return w.tarWriter.WriteHeader(hdr)
|
||||
}
|
||||
|
||||
func (w *tarGzWriter) writeTar(b []byte) (int, error) {
|
||||
return w.tarWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *tarGzWriter) writeGz(b []byte) (int, error) {
|
||||
return w.gzWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *tarGzWriter) Bytes() []byte {
|
||||
_ = w.tarWriter.Flush()
|
||||
_ = w.gzWriter.Flush()
|
||||
_ = w.gzWriter.Close() // required to ensure clean EOF in gz reader
|
||||
_ = w.bufWriter.Flush()
|
||||
return w.buf.Bytes()
|
||||
}
|
||||
|
||||
func (w *tarGzWriter) Close() (result error) {
|
||||
if err := w.tarWriter.Close(); err != nil {
|
||||
result = multierr.Append(result, err)
|
||||
}
|
||||
if err := w.gzWriter.Close(); err != nil {
|
||||
result = multierr.Append(result, err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
|
@ -31,6 +31,7 @@ import (
|
|||
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/installer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
"github.com/spf13/afero"
|
||||
|
|
@ -40,9 +41,7 @@ import (
|
|||
const (
|
||||
// kubeletStartTimeout is the maximum time given to the kubelet service to (re)start.
|
||||
kubeletStartTimeout = 10 * time.Minute
|
||||
// crdTimeout is the maximum time given to the CRDs to be created.
|
||||
crdTimeout = 30 * time.Second
|
||||
executablePerm = 0o544
|
||||
|
||||
kubeletServicePath = "/usr/lib/systemd/system/kubelet.service"
|
||||
)
|
||||
|
||||
|
|
@ -56,7 +55,7 @@ type Client interface {
|
|||
AnnotateNode(ctx context.Context, nodeName, annotationKey, annotationValue string) error
|
||||
}
|
||||
|
||||
type installer interface {
|
||||
type componentsInstaller interface {
|
||||
Install(
|
||||
ctx context.Context, kubernetesComponent versions.ComponentVersion,
|
||||
) error
|
||||
|
|
@ -64,14 +63,14 @@ type installer interface {
|
|||
|
||||
// KubernetesUtil provides low level management of the kubernetes cluster.
|
||||
type KubernetesUtil struct {
|
||||
inst installer
|
||||
inst componentsInstaller
|
||||
file file.Handler
|
||||
}
|
||||
|
||||
// NewKubernetesUtil creates a new KubernetesUtil.
|
||||
func NewKubernetesUtil() *KubernetesUtil {
|
||||
return &KubernetesUtil{
|
||||
inst: newOSInstaller(),
|
||||
inst: installer.NewOSInstaller(),
|
||||
file: file.NewHandler(afero.NewOsFs()),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue