/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package cmd

import (
	"bytes"
	"errors"
	"io"
	"testing"

	"github.com/spf13/cobra"
	"github.com/stretchr/testify/assert"
)

func TestAskToConfirm(t *testing.T) {
	// errAborted is returned by the test command when the user declines.
	errAborted := errors.New("user aborted")

	// newCmd builds a fresh command for every subtest so that the streams and
	// args configured by one subtest cannot leak into another one.
	newCmd := func() *cobra.Command {
		return &cobra.Command{
			Use:  "test",
			Args: cobra.NoArgs,
			RunE: func(cmd *cobra.Command, _ []string) error {
				ok, err := askToConfirm(cmd, "777")
				if err != nil {
					return err
				}
				if !ok {
					return errAborted
				}
				return nil
			},
		}
	}

	testCases := map[string]struct {
		input   string
		wantErr error
	}{
		"user confirms":                       {"y\n", nil},
		"user confirms long":                  {"yes\n", nil},
		"user disagrees":                      {"n\n", errAborted},
		"user disagrees long":                 {"no\n", errAborted},
		"user is first unsure, but agrees":    {"what?\ny\n", nil},
		"user is first unsure, but disagrees": {"wait.\nn\n", errAborted},
		"repeated invalid input":              {"h\nb\nq\n", ErrInvalidInput},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			cmd := newCmd()
			out := &bytes.Buffer{}
			cmd.SetOut(out)
			cmd.SetErr(&bytes.Buffer{})
			cmd.SetIn(bytes.NewBufferString(tc.input))
			// Pass explicit empty args: by default cobra reads os.Args[1:],
			// which under `go test` typically holds test flags and would
			// trip cobra.NoArgs.
			cmd.SetArgs([]string{})

			err := cmd.Execute()
			assert.ErrorIs(err, tc.wantErr)

			// The prompt text handed to askToConfirm is expected to be shown
			// on the command's output stream.
			output, err := io.ReadAll(out)
			assert.NoError(err)
			assert.Contains(string(output), "777")
		})
	}
}