// Mirror of https://github.com/edgelesssys/constellation.git
// (synced 2024-12-15 10:54:29 -05:00; 144 lines, 3.4 KiB, Go).
/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/
|
|
|
|
package imagefetcher
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/schollz/progressbar/v3"
|
|
"github.com/spf13/afero"
|
|
)
|
|
|
|
// Downloader downloads raw images.
|
|
type Downloader struct {
|
|
httpc httpc
|
|
fs *afero.Afero
|
|
}
|
|
|
|
// NewDownloader creates a new Downloader.
|
|
func NewDownloader() *Downloader {
|
|
return &Downloader{
|
|
httpc: http.DefaultClient,
|
|
fs: &afero.Afero{Fs: afero.NewOsFs()},
|
|
}
|
|
}
|
|
|
|
// Download downloads the raw image from source.
|
|
func (d *Downloader) Download(ctx context.Context, errWriter io.Writer, showBar bool, source, imageName string) (string, error) {
|
|
url, err := url.Parse(source)
|
|
if err != nil {
|
|
return "", fmt.Errorf("parsing image source URL: %w", err)
|
|
}
|
|
imageName = filepath.Base(imageName)
|
|
var partfile, destination string
|
|
switch url.Scheme {
|
|
case "http", "https":
|
|
cwd, err := os.Getwd()
|
|
if err != nil {
|
|
return "", fmt.Errorf("getting current working directory: %w", err)
|
|
}
|
|
partfile = filepath.Join(cwd, imageName+".raw.part")
|
|
destination = filepath.Join(cwd, imageName+".raw")
|
|
case "file":
|
|
return url.Path, nil
|
|
default:
|
|
return "", fmt.Errorf("unsupported image source URL scheme: %s", url.Scheme)
|
|
}
|
|
if !d.shouldDownload(destination) {
|
|
return destination, nil
|
|
}
|
|
if err := d.downloadWithProgress(ctx, errWriter, showBar, source, partfile); err != nil {
|
|
return "", err
|
|
}
|
|
return destination, d.fs.Rename(partfile, destination)
|
|
}
|
|
|
|
// shouldDownload checks if the image should be downloaded.
|
|
func (d *Downloader) shouldDownload(destination string) bool {
|
|
_, err := d.fs.Stat(destination)
|
|
return errors.Is(err, fs.ErrNotExist)
|
|
}
|
|
|
|
// downloadWithProgress downloads the raw image from source to the destination.
|
|
func (d *Downloader) downloadWithProgress(ctx context.Context, errWriter io.Writer, showBar bool, source, destination string) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, source, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("creating request: %w", err)
|
|
}
|
|
|
|
resp, err := d.httpc.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("doing request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("downloading from %q: %s", source, resp.Status)
|
|
}
|
|
|
|
f, err := d.fs.OpenFile(destination, os.O_CREATE|os.O_WRONLY, 0o644)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
var bar io.WriteCloser
|
|
if showBar {
|
|
bar = prepareBar(errWriter, resp.ContentLength)
|
|
} else {
|
|
bar = &nopWriteCloser{}
|
|
}
|
|
defer bar.Close()
|
|
|
|
_, err = io.Copy(io.MultiWriter(f, bar), resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func prepareBar(writer io.Writer, total int64) io.WriteCloser {
|
|
return progressbar.NewOptions64(
|
|
total,
|
|
progressbar.OptionSetWriter(writer),
|
|
progressbar.OptionShowBytes(true),
|
|
progressbar.OptionSetPredictTime(true),
|
|
progressbar.OptionFullWidth(),
|
|
progressbar.OptionSetTheme(progressbar.Theme{
|
|
Saucer: "=",
|
|
SaucerHead: ">",
|
|
SaucerPadding: " ",
|
|
BarStart: "[",
|
|
BarEnd: "]",
|
|
}),
|
|
progressbar.OptionClearOnFinish(),
|
|
progressbar.OptionOnCompletion(func() { fmt.Fprintf(writer, "Done.\n\n") }),
|
|
)
|
|
}
|
|
|
|
// nopWriteCloser is a write-closer that discards everything written to it.
type nopWriteCloser struct{}

// Write pretends to have consumed all of p and never fails.
func (*nopWriteCloser) Write(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}

// Close does nothing and never fails.
func (*nopWriteCloser) Close() (err error) {
	return nil
}
|
|
|
|
// httpc is the part of an HTTP client the Downloader depends on: sending a
// single request and returning its response. *http.Client satisfies it.
type httpc interface {
	Do(*http.Request) (*http.Response, error)
}
|