constellation/internal/installer/installer.go
Markus Rudy ae00b0a198 installer: add support for data URLs
RFC 015 proposes the introduction of data URLs to materialize static
content to files on disk. This commit adds support for data URLs to the
installer. The corresponding content will be added to versions.go in a
subsequent commit.
2023-12-13 09:35:19 +01:00

351 lines
11 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
// Package installer provides functionality to install binary components of supported kubernetes versions.
package installer
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path"
"slices"
"time"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions/components"
"github.com/spf13/afero"
"github.com/vincent-petithory/dataurl"
"k8s.io/utils/clock"
)
const (
	// downloadInterval is how long retryDownloadToTempDir waits before it
	// attempts a failed download again.
	downloadInterval = time.Millisecond * 10

	// executablePerm is the file mode applied to installed components:
	// r-x for the owner, r-- for group and others.
	executablePerm = 0o544
)
// OsInstaller installs binary components of supported kubernetes versions.
// Its dependencies are held as fields so that tests can substitute them.
type OsInstaller struct {
	// fs is the filesystem that downloads are staged on and components are installed to.
	fs *afero.Afero

	// hClient performs the HTTP requests for non-data URLs.
	hClient httpClient

	// clock drives the retry interval; tests swap in a fake clock.
	clock clock.WithTicker

	// retriable decides whether a failed download should be attempted again;
	// tests swap in their own predicate.
	retriable func(error) bool
}

// NewOSInstaller returns an OsInstaller wired to the real OS filesystem, a
// default http.Client, the real clock and the package's isRetriable predicate.
func NewOSInstaller() *OsInstaller {
	installer := OsInstaller{
		fs:        &afero.Afero{Fs: afero.NewOsFs()},
		hClient:   &http.Client{},
		clock:     clock.RealClock{},
		retriable: isRetriable,
	}
	return &installer
}
// Install downloads a component from its URL and verifies the sha256 hash of the
// download (only when the component specifies a hash). It then either extracts
// the download as a tar.gz archive into the install path (Extract set) or copies
// it to the install path.
//
// The temporary download is removed on every return path once the download has
// succeeded, including hash mismatches and installation failures.
func (i *OsInstaller) Install(ctx context.Context, kubernetesComponent *components.Component) error {
	tempPath, err := i.retryDownloadToTempDir(ctx, kubernetesComponent.Url)
	if err != nil {
		return err
	}
	// Registered right after the download so verification failures do not
	// leave the temporary file behind. Removal is best effort.
	defer func() {
		_ = i.fs.Remove(tempPath)
	}()

	file, err := i.fs.OpenFile(tempPath, os.O_RDONLY, 0)
	if err != nil {
		return fmt.Errorf("opening file %q: %w", tempPath, err)
	}
	sha := sha256.New()
	_, err = io.Copy(sha, file)
	// The handle is read-only, so a Close error cannot lose data. Close it now
	// rather than holding it open while the file is extracted or copied.
	_ = file.Close()
	if err != nil {
		return fmt.Errorf("reading file %q: %w", tempPath, err)
	}
	calculatedHash := fmt.Sprintf("sha256:%x", sha.Sum(nil))
	if len(kubernetesComponent.Hash) > 0 && calculatedHash != kubernetesComponent.Hash {
		return fmt.Errorf("hash of file %q %s does not match expected hash %s", tempPath, calculatedHash, kubernetesComponent.Hash)
	}

	if kubernetesComponent.Extract {
		err = i.extractArchive(tempPath, kubernetesComponent.InstallPath, executablePerm)
	} else {
		err = i.copy(tempPath, kubernetesComponent.InstallPath, executablePerm)
	}
	if err != nil {
		return fmt.Errorf("installing from %q: copying to destination %q: %w", kubernetesComponent.Url, kubernetesComponent.InstallPath, err)
	}
	return nil
}
// extractArchive extracts the gzip-compressed tar archive at archivePath into
// the directory prefix, creating prefix if needed.
//
// Only directories, regular files and symlinks are supported; any other entry
// type is an error. Every entry name and symlink target is checked with
// verifyTarPath, which rejects ".." path elements, and is then joined onto
// prefix. Directories are created with their tar mode masked by perm. Regular
// files are created with their tar mode and truncated if they already exist.
func (i *OsInstaller) extractArchive(archivePath, prefix string, perm fs.FileMode) error {
	archiveFile, err := i.fs.Open(archivePath)
	if err != nil {
		return fmt.Errorf("opening archive file: %w", err)
	}
	defer archiveFile.Close()
	gzReader, err := gzip.NewReader(archiveFile)
	if err != nil {
		return fmt.Errorf("reading archive file as gzip: %w", err)
	}
	defer gzReader.Close()
	if err := i.fs.MkdirAll(prefix, fs.ModePerm); err != nil {
		return fmt.Errorf("creating prefix folder: %w", err)
	}
	tarReader := tar.NewReader(gzReader)
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("parsing tar header: %w", err)
		}
		if err := verifyTarPath(header.Name); err != nil {
			return fmt.Errorf("verifying tar path %q: %w", header.Name, err)
		}
		switch header.Typeflag {
		case tar.TypeDir:
			if len(header.Name) == 0 {
				return errors.New("cannot create dir for empty path")
			}
			prefixedPath := path.Join(prefix, header.Name)
			if err := i.fs.Mkdir(prefixedPath, fs.FileMode(header.Mode)&perm); err != nil && !errors.Is(err, os.ErrExist) {
				return fmt.Errorf("creating folder %q: %w", prefixedPath, err)
			}
		case tar.TypeReg:
			if len(header.Name) == 0 {
				return errors.New("cannot create file for empty path")
			}
			// The helper closes each file before the next entry is read. A defer
			// here would keep every extracted file open until the loop ends.
			if err := i.writeExtractedFile(path.Join(prefix, header.Name), fs.FileMode(header.Mode), tarReader); err != nil {
				return err
			}
		case tar.TypeSymlink:
			if err := verifyTarPath(header.Linkname); err != nil {
				return fmt.Errorf("invalid tar path %q: %w", header.Linkname, err)
			}
			if len(header.Name) == 0 {
				return errors.New("cannot create symlink for empty path")
			}
			if len(header.Linkname) == 0 {
				return errors.New("cannot create symlink with empty target")
			}
			symlinker, ok := i.fs.Fs.(afero.Symlinker)
			if !ok {
				return errors.New("fs does not support symlinks")
			}
			// In a tar header, Name is the path of the link itself and Linkname
			// is its target. SymlinkIfPossible(oldname, newname) creates newname
			// pointing at oldname, so the target is passed first.
			// NOTE(review): the target is anchored at prefix, not at the link's
			// own directory as relative tar link targets usually are.
			if err := symlinker.SymlinkIfPossible(path.Join(prefix, header.Linkname), path.Join(prefix, header.Name)); err != nil {
				return fmt.Errorf("creating symlink: %w", err)
			}
		default:
			return fmt.Errorf("unsupported tar record: %v", header.Typeflag)
		}
	}
}

// writeExtractedFile writes the contents of src to dst with mode mode. It
// truncates dst if it exists and closes dst before returning, reporting a
// failed Close as an error.
func (i *OsInstaller) writeExtractedFile(dst string, mode fs.FileMode, src io.Reader) (err error) {
	out, err := i.fs.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return fmt.Errorf("creating file %q for writing: %w", dst, err)
	}
	defer func() {
		if closeErr := out.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("closing extracted file %q: %w", dst, closeErr)
		}
	}()
	if _, copyErr := io.Copy(out, src); copyErr != nil {
		return fmt.Errorf("writing extracted file contents: %w", copyErr)
	}
	return nil
}
// retryDownloadToTempDir downloads url to a temporary file, attempting again
// every downloadInterval for as long as i.retriable accepts the failure. It
// returns the path of the downloaded file.
func (i *OsInstaller) retryDownloadToTempDir(ctx context.Context, url string) (fileName string, someError error) {
	attempt := &downloadDoer{url: url, downloader: i}

	// The installer's own clock is handed to the retrier so tests can advance
	// time with a fake clock. Retrying stops as soon as ctx is canceled.
	r := retry.NewIntervalRetrier(attempt, downloadInterval, i.retriable, i.clock)
	if err := r.Do(ctx); err != nil {
		return "", fmt.Errorf("retrying downloadToTempDir: %w", err)
	}
	return attempt.path, nil
}
// retriableHTTPStatusCodes lists the HTTP status codes that downloadHTTP treats
// as transient: a repeated request could plausibly return 200. Which codes
// qualify depends on the server, so this is a pragmatic selection; see
// https://stackoverflow.com/a/74627395 for the reasoning.
var retriableHTTPStatusCodes = []int{
	http.StatusRequestTimeout,     // 408
	http.StatusTooEarly,           // 425
	http.StatusTooManyRequests,    // 429
	http.StatusBadGateway,         // 502
	http.StatusServiceUnavailable, // 503
	http.StatusGatewayTimeout,     // 504
}
// downloadHTTP fetches url with the installer's HTTP client and streams the
// response body into out.
//
// Some errors are wrapped in *retriableError so retryDownloadToTempDir attempts
// the download again:
//   - transport failures
//   - non-200 responses whose status code is in retriableHTTPStatusCodes
//   - failures while copying the body
//
// Other non-200 responses are returned as plain errors.
func (i *OsInstaller) downloadHTTP(ctx context.Context, url string, out io.Writer) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fmt.Errorf("request to download %q: %w", url, err)
	}
	resp, err := i.hClient.Do(req)
	if err != nil {
		// A failure at this point might be transient, such as network connectivity.
		return fmt.Errorf("request to download %q: %w", url, &retriableError{err: err})
	}
	// Close the body on every path, including non-200 responses. Otherwise each
	// failed attempt leaks its response body and connection.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// The HTTP request went through, but the result is not what we
		// expected. Wrap the error return in case we think the request could
		// be retried.
		err = fmt.Errorf("request to download %q failed with status code: %v", url, resp.Status)
		if slices.Contains(retriableHTTPStatusCodes, resp.StatusCode) {
			err = &retriableError{err: err}
		}
		return err
	}
	if _, err := io.Copy(out, resp.Body); err != nil {
		return fmt.Errorf("downloading %q: %w", url, &retriableError{err: err})
	}
	return nil
}
// unpackData decodes the data URL url and writes its decoded payload to out.
// It returns an error if url is not a valid data URL or if writing fails.
func (i *OsInstaller) unpackData(url string, out io.Writer) error {
	decoded, err := dataurl.DecodeString(url)
	if err != nil {
		return fmt.Errorf("parsing data URL: %w", err)
	}
	if _, err := io.Copy(out, bytes.NewReader(decoded.Data)); err != nil {
		return fmt.Errorf("writing content of data URL %q: %w", url, err)
	}
	return nil
}
// downloadToTempDir writes the resource behind u to a new temporary file and
// returns that file's path.
//
// URLs with the "data" scheme are decoded in place via unpackData. All other
// URLs are fetched with downloadHTTP. On any failure, including a failed Close
// of the temporary file, the temporary file is removed and the joined errors
// are returned.
func (i *OsInstaller) downloadToTempDir(ctx context.Context, u string) (string, error) {
	// Named "parsed" so the net/url package is not shadowed.
	parsed, err := url.Parse(u)
	if err != nil {
		return "", fmt.Errorf("parsing component URL: %w", err)
	}
	out, err := afero.TempFile(i.fs, "", "")
	if err != nil {
		return "", fmt.Errorf("creating destination temp file: %w", err)
	}
	// url.Parse lowercases the scheme, so "DATA:" URLs match as well.
	if parsed.Scheme == "data" {
		err = i.unpackData(u, out)
	} else {
		err = i.downloadHTTP(ctx, u, out)
	}
	// A failing Close can mean written data never reached the file, so it must
	// not be reported as a successful download.
	if closeErr := out.Close(); closeErr != nil {
		err = errors.Join(err, fmt.Errorf("closing temp file %q: %w", out.Name(), closeErr))
	}
	if err != nil {
		removeErr := i.fs.Remove(out.Name())
		return "", errors.Join(err, removeErr)
	}
	return out.Name(), nil
}
// copy copies the file at oldname to newname, creating newname's parent
// directories as needed. If newname does not exist it is created with mode
// perm. If it exists it is truncated and overwritten.
//
// If anything fails after newname has been opened, newname is removed (best
// effort) so no partial file is left behind. This includes a failed Close of
// newname, which can mean written data did not reach the file.
func (i *OsInstaller) copy(oldname, newname string, perm fs.FileMode) (err error) {
	old, openOldErr := i.fs.OpenFile(oldname, os.O_RDONLY, fs.ModePerm)
	if openOldErr != nil {
		return fmt.Errorf("copying %q to %q: cannot open source file for reading: %w", oldname, newname, openOldErr)
	}
	// The source is only read, so its Close error carries no information.
	defer func() { _ = old.Close() }()

	// create destination path if not exists
	if err := i.fs.MkdirAll(path.Dir(newname), fs.ModePerm); err != nil {
		return fmt.Errorf("copying %q to %q: unable to create destination folder: %w", oldname, newname, err)
	}
	newFile, openNewErr := i.fs.OpenFile(newname, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, perm)
	if openNewErr != nil {
		return fmt.Errorf("copying %q to %q: cannot open destination file for writing: %w", oldname, newname, openNewErr)
	}
	defer func() {
		if closeErr := newFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("copying %q to %q: closing destination file: %w", oldname, newname, closeErr)
		}
		if err != nil {
			_ = i.fs.Remove(newname)
		}
	}()

	if _, copyErr := io.Copy(newFile, old); copyErr != nil {
		return fmt.Errorf("copying %q to %q: copying file contents: %w", oldname, newname, copyErr)
	}
	return nil
}
// downloadDoer adapts a downloader to the retrier's Do(ctx) method. After each
// attempt, path holds whatever path that attempt returned.
type downloadDoer struct {
	url        string
	downloader downloader
	path       string
}

// downloader fetches a URL into a temporary file and reports the file's path.
type downloader interface {
	downloadToTempDir(ctx context.Context, url string) (string, error)
}

// Do runs a single download attempt for d.url. It stores the returned path in
// d.path even when the attempt fails, and returns the attempt's error as-is.
func (d *downloadDoer) Do(ctx context.Context) error {
	var err error
	d.path, err = d.downloader.downloadToTempDir(ctx, d.url)
	return err
}
// retriableError marks the wrapped error as transient: the action that
// produced it may succeed if attempted again.
type retriableError struct{ err error }

// Error returns the wrapped message prefixed with "retriable error: ".
func (e *retriableError) Error() string {
	return fmt.Sprint("retriable error: ", e.err.Error())
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *retriableError) Unwrap() error { return e.err }

// isRetriable reports whether err, or any error in its chain, is a
// *retriableError.
func isRetriable(err error) bool {
	var target *retriableError
	return errors.As(err, &target)
}
// verifyTarPath returns an error if pat contains ".." as a complete path
// element. Elements are delimited by any byte for which os.IsPathSeparator
// reports true. Empty elements, "." and names that merely contain dots (such
// as "...", "..a" or "a..") are accepted.
func verifyTarPath(pat string) error {
	start := 0
	for end := 0; end <= len(pat); end++ {
		// Advance until end sits on a separator or just past the last byte;
		// pat[start:end] is then one complete element.
		if end < len(pat) && !os.IsPathSeparator(pat[end]) {
			continue
		}
		if pat[start:end] == ".." {
			return errors.New("path contains \"..\"")
		}
		start = end + 1
	}
	return nil
}
// httpClient is the subset of *http.Client the installer relies on. Tests can
// substitute a stub that returns canned responses.
type httpClient interface {
	// Do sends req and returns the server's response.
	Do(req *http.Request) (*http.Response, error)
}