constellation/internal/imagefetcher/raw.go
2023-05-25 15:01:15 +02:00

144 lines
3.4 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package imagefetcher
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path/filepath"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
)
// Downloader fetches raw images from a remote source and stores them on a
// filesystem so they can be reused by later calls.
type Downloader struct {
	// httpc issues the HTTP requests used to fetch images.
	httpc httpc
	// fs is the filesystem that downloaded images are looked up in and
	// written to.
	fs *afero.Afero
}
// NewDownloader returns a Downloader that sends its requests through
// http.DefaultClient and reads/writes the host operating system's filesystem.
func NewDownloader() *Downloader {
	hostFS := &afero.Afero{Fs: afero.NewOsFs()}
	d := &Downloader{fs: hostFS}
	d.httpc = http.DefaultClient
	return d
}
// Download makes the raw image referenced by source available locally and
// returns its path.
//
// Supported source URL schemes:
//   - "file": the image is already local. The URL's path is returned as-is
//     and nothing is downloaded. A file URL without a path is an error.
//   - "http" / "https": the image is stored in the current working directory
//     as "<imageName>.raw". Only the last path element of imageName is used.
//     If that file already exists it is reused and no request is made.
//     Otherwise the image is first written to "<imageName>.raw.part" and only
//     renamed to the final name after the download succeeded, so a failed
//     download never leaves a truncated image under the final name.
//
// When showBar is true, download progress is rendered to errWriter.
// On error the returned path is always empty.
func (d *Downloader) Download(ctx context.Context, errWriter io.Writer, showBar bool, source, imageName string) (string, error) {
	// Named sourceURL (not url) so the net/url package is not shadowed.
	sourceURL, err := url.Parse(source)
	if err != nil {
		return "", fmt.Errorf("parsing image source URL: %w", err)
	}
	// Keep only the final path element so imageName cannot steer the file
	// outside the working directory.
	imageName = filepath.Base(imageName)

	var partfile, destination string
	switch sourceURL.Scheme {
	case "http", "https":
		cwd, err := os.Getwd()
		if err != nil {
			return "", fmt.Errorf("getting current working directory: %w", err)
		}
		partfile = filepath.Join(cwd, imageName+".raw.part")
		destination = filepath.Join(cwd, imageName+".raw")
	case "file":
		// e.g. "file://image.raw" puts the name in the host and leaves Path
		// empty; returning "" with a nil error would only fail later.
		if sourceURL.Path == "" {
			return "", fmt.Errorf("image source URL %q has no path", source)
		}
		return sourceURL.Path, nil
	default:
		return "", fmt.Errorf("unsupported image source URL scheme: %s", sourceURL.Scheme)
	}

	if !d.shouldDownload(destination) {
		return destination, nil
	}
	if err := d.downloadWithProgress(ctx, errWriter, showBar, source, partfile); err != nil {
		// Best effort cleanup of the partial download; the download error is
		// the one the caller needs, so a failed Remove is deliberately ignored.
		_ = d.fs.Remove(partfile)
		return "", err
	}
	if err := d.fs.Rename(partfile, destination); err != nil {
		return "", fmt.Errorf("moving downloaded image %q to %q: %w", partfile, destination, err)
	}
	return destination, nil
}
// shouldDownload reports whether destination still needs to be fetched.
// It returns true only when Stat reports that the file does not exist. An
// existing file yields false, and so does any other Stat failure.
// NOTE(review): treating e.g. a permission error as "already present" makes
// Download hand back a path it may not be able to read — confirm intended.
func (d *Downloader) shouldDownload(destination string) bool {
	_, statErr := d.fs.Stat(destination)
	if statErr == nil {
		return false
	}
	return errors.Is(statErr, fs.ErrNotExist)
}
// downloadWithProgress streams the resource at source into the file
// destination via an HTTP GET bound to ctx.
//
// The destination file is created if missing and truncated if it already
// exists. Any response status other than 200 OK is an error. When showBar is
// true, a progress bar sized by the response's Content-Length is written to
// errWriter. An error from closing the destination file is reported if no
// earlier error occurred.
func (d *Downloader) downloadWithProgress(ctx context.Context, errWriter io.Writer, showBar bool, source, destination string) (retErr error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, source, nil)
	if err != nil {
		return fmt.Errorf("creating request: %w", err)
	}
	resp, err := d.httpc.Do(req)
	if err != nil {
		return fmt.Errorf("doing request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("downloading from %q: %s", source, resp.Status)
	}

	// O_TRUNC is required: a leftover file from an earlier, interrupted
	// download may be longer than the new content and would otherwise keep
	// stale trailing bytes, corrupting the image.
	f, err := d.fs.OpenFile(destination, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
	if err != nil {
		return fmt.Errorf("opening %q: %w", destination, err)
	}
	defer func() {
		// Close can surface deferred write errors (e.g. on network
		// filesystems); don't report success if it fails.
		if closeErr := f.Close(); closeErr != nil && retErr == nil {
			retErr = fmt.Errorf("closing %q: %w", destination, closeErr)
		}
	}()

	var bar io.WriteCloser = &nopWriteCloser{}
	if showBar {
		bar = prepareBar(errWriter, resp.ContentLength)
	}
	defer bar.Close()

	if _, err := io.Copy(io.MultiWriter(f, bar), resp.Body); err != nil {
		return fmt.Errorf("writing image to %q: %w", destination, err)
	}
	return nil
}
// prepareBar builds a byte-counting progress bar for a transfer of total
// bytes that renders to writer. The bar spans the full width, predicts the
// remaining time, and is cleared once finished, after which "Done." is
// printed to writer.
func prepareBar(writer io.Writer, total int64) io.WriteCloser {
	theme := progressbar.Theme{
		Saucer:        "=",
		SaucerHead:    ">",
		SaucerPadding: " ",
		BarStart:      "[",
		BarEnd:        "]",
	}
	announceDone := func() {
		fmt.Fprint(writer, "Done.\n\n")
	}

	// Options are applied in order; keep this order stable.
	opts := []progressbar.Option{
		progressbar.OptionSetWriter(writer),
		progressbar.OptionShowBytes(true),
		progressbar.OptionSetPredictTime(true),
		progressbar.OptionFullWidth(),
		progressbar.OptionSetTheme(theme),
		progressbar.OptionClearOnFinish(),
		progressbar.OptionOnCompletion(announceDone),
	}
	return progressbar.NewOptions64(total, opts...)
}
// nopWriteCloser is a write sink that accepts and discards every byte and
// whose Close does nothing. It stands in for the progress bar when no bar is
// requested.
type nopWriteCloser struct{}

// Write claims to have consumed all of p without storing any of it.
func (*nopWriteCloser) Write(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}

// Close is a no-op and never fails.
func (*nopWriteCloser) Close() error { return nil }
// httpc is the single capability Downloader needs from an HTTP client:
// sending a request and returning its response. *http.Client satisfies it.
// Depending on this interface rather than *http.Client lets a different
// implementation (for example a fake) be injected.
type httpc interface {
	Do(r *http.Request) (*http.Response, error)
}