/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package imagefetcher

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"os"
	"path"
	"testing"

	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
)

func TestShouldDownload(t *testing.T) {
	testCases := map[string]struct {
		partfile, destination string
		wantDownload          bool
	}{
		"no files exist yet": {
			wantDownload: true,
		},
		"partial download": {
			partfile:     "some data",
			wantDownload: true,
		},
		"download succeeded": {
			destination: "all of the data",
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			downloader := &Downloader{
				fs: newDownloaderStubFs(t, "someVersion", tc.partfile, tc.destination),
			}
			gotDownload := downloader.shouldDownload("someVersion.raw")
			assert.Equal(tc.wantDownload, gotDownload)
		})
	}
}

// TestDownloadWithProgress runs downloadWithProgress against a stubbed HTTP
// client that serves exactly one known image URL and answers every other
// request with 404. A successful download must leave the image contents in
// "someVersion.raw"; an unknown URL or an unsupported scheme must fail.
func TestDownloadWithProgress(t *testing.T) {
	const (
		imageURL = "https://cdn.example.com/image.raw"
		rawImage = "raw image"
	)

	// respond builds a minimal HTTP response carrying the given status and body.
	respond := func(status int, body string) *http.Response {
		return &http.Response{
			StatusCode: status,
			Body:       io.NopCloser(bytes.NewBufferString(body)),
			Header:     make(http.Header),
		}
	}
	client := newTestClient(func(req *http.Request) *http.Response {
		if req.URL.String() != imageURL {
			return respond(http.StatusNotFound, "Not found.")
		}
		return respond(http.StatusOK, rawImage)
	})

	type scenario struct {
		source  string
		wantErr bool
	}
	scenarios := map[string]scenario{
		"correct file requested":   {source: imageURL},
		"incorrect file requested": {source: "https://cdn.example.com/incorrect.raw", wantErr: true},
		"invalid scheme":           {source: "xyz://", wantErr: true},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			stubFs := newDownloaderStubFs(t, "someVersion", "", "")
			downloader := &Downloader{httpc: client, fs: stubFs}

			var progressOut bytes.Buffer
			downloadErr := downloader.downloadWithProgress(context.Background(), &progressOut, false, sc.source, "someVersion.raw")
			if sc.wantErr {
				assert.Error(t, downloadErr)
				return
			}
			assert.NoError(t, downloadErr)

			// The downloaded bytes must have been written to the destination file.
			contents, readErr := stubFs.ReadFile("someVersion.raw")
			assert.NoError(t, readErr)
			assert.Equal(t, rawImage, string(contents))
		})
	}
}

func TestDownload(t *testing.T) {
	rawImage := "raw image"
	cwd, err := os.Getwd()
	assert.NoError(t, err)
	wantDestination := path.Join(cwd, "someVersion.raw")
	client := newTestClient(func(req *http.Request) *http.Response {
		if req.URL.String() == "https://cdn.example.com/image.raw" {
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewBufferString(rawImage)),
				Header:     make(http.Header),
			}
		}
		return &http.Response{
			StatusCode: http.StatusNotFound,
			Body:       io.NopCloser(bytes.NewBufferString("Not found.")),
			Header:     make(http.Header),
		}
	})

	testCases := map[string]struct {
		source       string
		destination  string
		overrideFile string
		wantErr      bool
	}{
		"correct file requested": {
			source: "https://cdn.example.com/image.raw",
		},
		"file url": {
			source:       "file:///override.raw",
			overrideFile: "override image",
		},
		"file exists": {
			source:      "https://cdn.example.com/image.raw",
			destination: "already exists",
		},
		"incorrect file requested": {
			source:  "https://cdn.example.com/incorrect.raw",
			wantErr: true,
		},
		"invalid scheme": {
			source:  "xyz://",
			wantErr: true,
		},
		"invalid URL": {
			source:  "\x00",
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			fs := newDownloaderStubFs(t, cwd+"/someVersion", "", tc.destination)
			if tc.overrideFile != "" {
				must(t, fs.WriteFile("/override.raw", []byte(tc.overrideFile), os.ModePerm))
			}
			downloader := &Downloader{
				httpc: client,
				fs:    fs,
			}
			var outBuffer bytes.Buffer
			gotDestination, err := downloader.Download(context.Background(), &outBuffer, false, tc.source, "someVersion")
			if tc.wantErr {
				assert.Error(err)
				return
			}
			assert.NoError(err)
			if tc.overrideFile == "" {
				assert.Equal(wantDestination, gotDestination)
			} else {
				assert.Equal("/override.raw", gotDestination)
			}
			out, err := fs.ReadFile(gotDestination)
			assert.NoError(err)
			switch {
			case tc.overrideFile != "":
				assert.Equal(tc.overrideFile, string(out))
			case tc.destination != "":
				assert.Equal(tc.destination, string(out))
			default:
				assert.Equal(rawImage, string(out))
			}
		})
	}
}

// newDownloaderStubFs returns an in-memory filesystem for downloader tests.
// If partfile is non-empty, it is written to version+".raw.part"; if
// destination is non-empty, it is written to version+".raw". Any write error
// is passed to must.
func newDownloaderStubFs(t *testing.T, version, partfile, destination string) *afero.Afero {
	// Attribute setup failures to the calling test rather than to this helper.
	t.Helper()

	fs := afero.NewMemMapFs()
	if partfile != "" {
		must(t, afero.WriteFile(fs, version+".raw.part", []byte(partfile), os.ModePerm))
	}
	if destination != "" {
		must(t, afero.WriteFile(fs, version+".raw", []byte(destination), os.ModePerm))
	}
	return &afero.Afero{Fs: fs}
}