/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package poller

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/stretchr/testify/assert"
	testclock "k8s.io/utils/clock/testing"
)

func TestResult(t *testing.T) {
	testCases := map[string]struct {
		done       bool
		pollErr    error
		resultErr  error
		wantErr    bool
		wantResult int
	}{
		"result called before poller is done": {
			wantErr: true,
		},
		"result returns error": {
			done:      true,
			resultErr: errors.New("result error"),
			wantErr:   true,
		},
		"result succeeds": {
			done: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			poller := New[int](&stubPoller[int]{
				result:    &tc.wantResult,
				done:      tc.done,
				pollErr:   tc.pollErr,
				resultErr: tc.resultErr,
			})
			_, firstErr := poller.Result(context.Background())
			if tc.wantErr {
				assert.Error(firstErr)
				// calling Result again should return the same error
				_, secondErr := poller.Result(context.Background())
				assert.Equal(firstErr, secondErr)
				return
			}
			assert.NoError(firstErr)
			// calling Result again should still not return an error
			_, secondErr := poller.Result(context.Background())
			assert.NoError(secondErr)
		})
	}
}

func TestPollUntilDone(t *testing.T) {
	testCases := map[string]struct {
		messages   []message
		maxBackoff time.Duration
		resultErr  error
		wantErr    bool
		wantResult int
	}{
		"poll succeeds on first try": {
			messages: []message{
				{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
				{done: to.Ptr(true)}, // Result() will call Done() after the last poll
			},
			wantResult: 1,
		},
		"poll succeeds on fourth try": {
			messages: []message{
				{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
				{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: 2 * time.Second},
				{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: 4 * time.Second},
				{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
				{done: to.Ptr(true)}, // Result() will call Done() after the last poll
			},
			wantResult: 1,
		},
		"max backoff reached": {
			messages: []message{
				{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
				{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
				{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
				{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
				{done: to.Ptr(true)}, // Result() will call Done() after the last poll
			},
			maxBackoff: time.Second,
			wantResult: 1,
		},
		"poll errors": {
			messages: []message{
				{pollErr: to.Ptr(errors.New("poll error"))},
			},
			wantErr: true,
		},
		"result errors": {
			messages: []message{
				{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
				{done: to.Ptr(true)}, // Result() will call Done() after the last poll
			},
			resultErr: errors.New("result error"),
			wantErr:   true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			doneC := make(chan bool)
			pollC := make(chan error)
			poller := New[int](&fakePoller[int]{
				result:    &tc.wantResult,
				resultErr: tc.resultErr,
				doneC:     doneC,
				pollC:     pollC,
			})
			clock := testclock.NewFakeClock(time.Now())
			var gotResult int
			var gotErr error
			wg := sync.WaitGroup{}
			wg.Add(1)
			go func() {
				defer wg.Done()
				gotResult, gotErr = poller.PollUntilDone(context.Background(), &PollUntilDoneOptions{
					MaxBackoff: tc.maxBackoff,
					Clock:      clock,
				})
			}()

			for _, msg := range tc.messages {
				if msg.pollErr != nil {
					pollC <- *msg.pollErr
				}
				if msg.done != nil {
					doneC <- *msg.done
				}
				clock.Step(msg.backoff)
			}

			wg.Wait()
			if tc.wantErr {
				assert.Error(gotErr)
				return
			}
			assert.NoError(gotErr)
			assert.Equal(tc.wantResult, gotResult)
		})
	}
}

type stubPoller[T any] struct {
	result    *T
	done      bool
	pollErr   error
	resultErr error
}

func (s *stubPoller[T]) Poll(_ context.Context) error {
	return s.pollErr
}

func (s *stubPoller[T]) Done() bool {
	return s.done
}

func (s *stubPoller[T]) Result(_ context.Context, out *T) error {
	*out = *s.result
	return s.resultErr
}

type message struct {
	pollErr *error
	done    *bool
	backoff time.Duration
}

type fakePoller[T any] struct {
	result    *T
	resultErr error

	doneC chan bool
	pollC chan error
}

func (s *fakePoller[T]) Poll(_ context.Context) error {
	return <-s.pollC
}

func (s *fakePoller[T]) Done() bool {
	return <-s.doneC
}

func (s *fakePoller[T]) Result(_ context.Context, out *T) error {
	*out = *s.result
	return s.resultErr
}