mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
bd63aa3c6b
sed -i '1i/*\nCopyright (c) Edgeless Systems GmbH\n\nSPDX-License-Identifier: AGPL-3.0-only\n*/\n' `grep -rL --include='*.go' 'DO NOT EDIT'` gofumpt -w .
294 lines
8.8 KiB
Go
294 lines
8.8 KiB
Go
/*
|
|
Copyright (c) Edgeless Systems GmbH
|
|
|
|
SPDX-License-Identifier: AGPL-3.0-only
|
|
*/
|
|
|
|
package k8sapi
|
|
|
|
import (
|
|
"archive/tar"
|
|
"compress/gzip"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"time"
|
|
|
|
"github.com/edgelesssys/constellation/internal/retry"
|
|
"github.com/spf13/afero"
|
|
"golang.org/x/text/transform"
|
|
"k8s.io/utils/clock"
|
|
)
|
|
|
|
// downloadInterval is how long retryDownloadToTempDir waits between two
// consecutive download attempts.
const downloadInterval = 10 * time.Millisecond
|
|
|
|
// osInstaller installs binary components of supported kubernetes versions.
|
|
type osInstaller struct {
|
|
fs *afero.Afero
|
|
hClient httpClient
|
|
// clock is needed for testing purposes
|
|
clock clock.WithTicker
|
|
// retriable is the function used to check if an error is retriable. Needed for testing.
|
|
retriable func(error) bool
|
|
}
|
|
|
|
// newOSInstaller creates a new osInstaller.
|
|
func newOSInstaller() *osInstaller {
|
|
return &osInstaller{
|
|
fs: &afero.Afero{Fs: afero.NewOsFs()},
|
|
hClient: &http.Client{},
|
|
clock: clock.RealClock{},
|
|
retriable: isRetriable,
|
|
}
|
|
}
|
|
|
|
// Install downloads a resource from a URL, applies any given text transformations and extracts the resulting file if required.
|
|
// The resulting file(s) are copied to all destinations.
|
|
func (i *osInstaller) Install(
|
|
ctx context.Context, sourceURL string, destinations []string, perm fs.FileMode,
|
|
extract bool, transforms ...transform.Transformer,
|
|
) error {
|
|
tempPath, err := i.retryDownloadToTempDir(ctx, sourceURL, transforms...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = i.fs.Remove(tempPath)
|
|
}()
|
|
for _, destination := range destinations {
|
|
var err error
|
|
if extract {
|
|
err = i.extractArchive(tempPath, destination, perm)
|
|
} else {
|
|
err = i.copy(tempPath, destination, perm)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("installing from %q: copying to destination %q: %w", sourceURL, destination, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// extractArchive extracts tar gz archives to a prefixed destination.
|
|
func (i *osInstaller) extractArchive(archivePath, prefix string, perm fs.FileMode) error {
|
|
archiveFile, err := i.fs.Open(archivePath)
|
|
if err != nil {
|
|
return fmt.Errorf("opening archive file: %w", err)
|
|
}
|
|
defer archiveFile.Close()
|
|
gzReader, err := gzip.NewReader(archiveFile)
|
|
if err != nil {
|
|
return fmt.Errorf("reading archive file as gzip: %w", err)
|
|
}
|
|
defer gzReader.Close()
|
|
if err := i.fs.MkdirAll(prefix, fs.ModePerm); err != nil {
|
|
return fmt.Errorf("creating prefix folder: %w", err)
|
|
}
|
|
tarReader := tar.NewReader(gzReader)
|
|
|
|
for {
|
|
header, err := tarReader.Next()
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("parsing tar header: %w", err)
|
|
}
|
|
if err := verifyTarPath(header.Name); err != nil {
|
|
return fmt.Errorf("verifying tar path %q: %w", header.Name, err)
|
|
}
|
|
switch header.Typeflag {
|
|
case tar.TypeDir:
|
|
if len(header.Name) == 0 {
|
|
return errors.New("cannot create dir for empty path")
|
|
}
|
|
prefixedPath := path.Join(prefix, header.Name)
|
|
if err := i.fs.Mkdir(prefixedPath, fs.FileMode(header.Mode)&perm); err != nil && !errors.Is(err, os.ErrExist) {
|
|
return fmt.Errorf("creating folder %q: %w", prefixedPath, err)
|
|
}
|
|
case tar.TypeReg:
|
|
if len(header.Name) == 0 {
|
|
return errors.New("cannot create file for empty path")
|
|
}
|
|
prefixedPath := path.Join(prefix, header.Name)
|
|
out, err := i.fs.OpenFile(prefixedPath, os.O_WRONLY|os.O_CREATE, fs.FileMode(header.Mode))
|
|
if err != nil {
|
|
return fmt.Errorf("creating file %q for writing: %w", prefixedPath, err)
|
|
}
|
|
defer out.Close()
|
|
if _, err := io.Copy(out, tarReader); err != nil {
|
|
return fmt.Errorf("writing extracted file contents: %w", err)
|
|
}
|
|
case tar.TypeSymlink:
|
|
if err := verifyTarPath(header.Linkname); err != nil {
|
|
return fmt.Errorf("invalid tar path %q: %w", header.Linkname, err)
|
|
}
|
|
if len(header.Name) == 0 {
|
|
return errors.New("cannot symlink file for empty oldname")
|
|
}
|
|
if len(header.Linkname) == 0 {
|
|
return errors.New("cannot symlink file for empty newname")
|
|
}
|
|
if symlinker, ok := i.fs.Fs.(afero.Symlinker); ok {
|
|
if err := symlinker.SymlinkIfPossible(path.Join(prefix, header.Name), path.Join(prefix, header.Linkname)); err != nil {
|
|
return fmt.Errorf("creating symlink: %w", err)
|
|
}
|
|
} else {
|
|
return errors.New("fs does not support symlinks")
|
|
}
|
|
default:
|
|
return fmt.Errorf("unsupported tar record: %v", header.Typeflag)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (i *osInstaller) retryDownloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, someError error) {
|
|
doer := downloadDoer{
|
|
url: url,
|
|
transforms: transforms,
|
|
downloader: i,
|
|
}
|
|
|
|
// Retries are canceled as soon as the context is canceled.
|
|
// We need to call NewIntervalRetrier with a clock argument so that the tests can fake the clock by changing the osInstaller clock.
|
|
retrier := retry.NewIntervalRetrier(&doer, downloadInterval, i.retriable, i.clock)
|
|
if err := retrier.Do(ctx); err != nil {
|
|
return "", fmt.Errorf("retrying downloadToTempDir: %w", err)
|
|
}
|
|
|
|
return doer.path, nil
|
|
}
|
|
|
|
// downloadToTempDir downloads a file to a temporary location, applying transform on-the-fly.
|
|
func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, retErr error) {
|
|
out, err := afero.TempFile(i.fs, "", "")
|
|
if err != nil {
|
|
return "", fmt.Errorf("creating destination temp file: %w", err)
|
|
}
|
|
// Remove the created file if an error occurs.
|
|
defer func() {
|
|
if retErr != nil {
|
|
_ = i.fs.Remove(fileName)
|
|
retErr = &retriableError{err: retErr} // mark any error after this point as retriable
|
|
}
|
|
}()
|
|
defer out.Close()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("request to download %q: %w", url, err)
|
|
}
|
|
resp, err := i.hClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("request to download %q: %w", url, err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("request to download %q failed with status code: %v", url, resp.Status)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
transformReader := transform.NewReader(resp.Body, transform.Chain(transforms...))
|
|
|
|
if _, err = io.Copy(out, transformReader); err != nil {
|
|
return "", fmt.Errorf("downloading %q: %w", url, err)
|
|
}
|
|
return out.Name(), nil
|
|
}
|
|
|
|
// copy copies a file from oldname to newname.
|
|
func (i *osInstaller) copy(oldname, newname string, perm fs.FileMode) (err error) {
|
|
old, openOldErr := i.fs.OpenFile(oldname, os.O_RDONLY, fs.ModePerm)
|
|
if openOldErr != nil {
|
|
return fmt.Errorf("copying %q to %q: cannot open source file for reading: %w", oldname, newname, openOldErr)
|
|
}
|
|
defer func() { _ = old.Close() }()
|
|
// create destination path if not exists
|
|
if err := i.fs.MkdirAll(path.Dir(newname), fs.ModePerm); err != nil {
|
|
return fmt.Errorf("copying %q to %q: unable to create destination folder: %w", oldname, newname, err)
|
|
}
|
|
new, openNewErr := i.fs.OpenFile(newname, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, perm)
|
|
if openNewErr != nil {
|
|
return fmt.Errorf("copying %q to %q: cannot open destination file for writing: %w", oldname, newname, openNewErr)
|
|
}
|
|
defer func() {
|
|
_ = new.Close()
|
|
if err != nil {
|
|
_ = i.fs.Remove(newname)
|
|
}
|
|
}()
|
|
if _, err := io.Copy(new, old); err != nil {
|
|
return fmt.Errorf("copying %q to %q: copying file contents: %w", oldname, newname, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type downloadDoer struct {
|
|
url string
|
|
transforms []transform.Transformer
|
|
downloader downloader
|
|
path string
|
|
}
|
|
|
|
type downloader interface {
|
|
downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error)
|
|
}
|
|
|
|
func (d *downloadDoer) Do(ctx context.Context) error {
|
|
path, err := d.downloader.downloadToTempDir(ctx, d.url, d.transforms...)
|
|
d.path = path
|
|
return err
|
|
}
|
|
|
|
// retriableError wraps an error to mark the failed operation as safe to retry.
type retriableError struct{ err error }

// Error prefixes the wrapped error's message with "retriable error: ".
func (r *retriableError) Error() string {
	return fmt.Sprintf("retriable error: %s", r.err.Error())
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (r *retriableError) Unwrap() error {
	return r.err
}

// isRetriable reports whether err, or any error in its wrap chain, is a
// *retriableError.
func isRetriable(err error) bool {
	var target *retriableError
	return errors.As(err, &target)
}
|
|
|
|
// verifyTarPath rejects tar entry paths that contain a ".." element. The path
// is split on the OS path separator; empty and "." elements are allowed, and
// names that merely contain dots (e.g. "..a", "a..", "...") are allowed too.
func verifyTarPath(pat string) error {
	for start := 0; start < len(pat); {
		// Skip separators; consecutive ones form empty elements.
		if os.IsPathSeparator(pat[start]) {
			start++
			continue
		}
		// Advance to the end of the current element.
		end := start
		for end < len(pat) && !os.IsPathSeparator(pat[end]) {
			end++
		}
		if pat[start:end] == ".." {
			return errors.New("path contains \"..\"")
		}
		start = end
	}
	return nil
}
|
|
|
|
// httpClient is the single *http.Client method osInstaller relies on; keeping
// it an interface presumably lets tests substitute a fake transport.
type httpClient interface {
	Do(*http.Request) (*http.Response, error)
}
|