/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package retry

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.uber.org/goleak"
	testclock "k8s.io/utils/clock/testing"
)

// TestMain runs the package's tests through goleak, which fails the run if
// goroutines started by the tests (such as the goroutine running retrier.Do
// in TestDo) are still alive once the tests have finished.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

// TestDo checks that IntervalRetrier.Do keeps calling its doer while the doer
// returns retriable errors, and stops on success, on a non-retriable error, or
// when its context is cancelled.
func TestDo(t *testing.T) {
	// Use an explicit, non-zero interval so each clock.Step genuinely advances
	// the fake clock past the retrier's wait. Leaving the field at its zero
	// value made every Step(0) a no-op time advance, so the wait between
	// attempts was never really exercised.
	// NOTE(review): assumes IntervalRetrier waits on r.clock for r.interval
	// between attempts — confirm against the retrier implementation.
	const interval = time.Second

	testCases := map[string]struct {
		cancel  bool
		errors  []error
		wantErr error
	}{
		"no error": {
			errors: []error{
				nil,
			},
		},
		"permanent error": {
			errors: []error{
				errors.New("error"),
			},
			wantErr: errors.New("error"),
		},
		"service unavailable then success": {
			errors: []error{
				errors.New("retry me"),
				nil,
			},
		},
		"multiple retries then success": {
			errors: []error{
				errors.New("retry me"),
				errors.New("retry me"),
				nil,
			},
		},
		"service unavailable then permanent error": {
			errors: []error{
				errors.New("retry me"),
				errors.New("error"),
			},
			wantErr: errors.New("error"),
		},
		"cancellation works": {
			cancel: true,
			errors: []error{
				errors.New("retry me"),
			},
			wantErr: context.Canceled,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			doer := newStubDoer()
			clock := testclock.NewFakeClock(time.Now())
			retrier := IntervalRetrier{
				interval:  interval,
				doer:      doer,
				clock:     clock,
				retriable: isRetriable,
			}
			retrierResult := make(chan error, 1)
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			go func() { retrierResult <- retrier.Do(ctx) }()
			for _, err := range tc.errors {
				// errC is unbuffered, so this send completes only once a
				// doer.Do call has received it; the step then lets the
				// retrier's wait elapse before the next attempt.
				doer.errC <- err
				clock.Step(retrier.interval)
			}

			if tc.cancel {
				cancel()
			}

			assert.Equal(tc.wantErr, <-retrierResult)
		})
	}
}

// stubDoer is a test double for the doer driven by IntervalRetrier. The test
// feeds it one result per attempt through errC.
type stubDoer struct {
	errC chan error
}

// newStubDoer returns a stubDoer whose result channel is unbuffered, so every
// send from the test is handed to exactly one Do call.
func newStubDoer() *stubDoer {
	d := &stubDoer{}
	d.errC = make(chan error)
	return d
}

// Do blocks until either a result arrives on errC, which it returns, or ctx
// is done, in which case it returns ctx.Err().
func (d *stubDoer) Do(ctx context.Context) error {
	var result error
	select {
	case result = <-d.errC:
	case <-ctx.Done():
		result = ctx.Err()
	}
	return result
}

// isRetriable reports whether err is the synthetic retriable error used by
// TestDo, i.e. an error whose message is exactly "retry me". A nil error is
// never retriable; checking it first avoids a nil-interface panic from
// calling err.Error().
func isRetriable(err error) bool {
	return err != nil && err.Error() == "retry me"
}