constellation/internal/retry/retry_test.go

114 lines
1.9 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package retry
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
testclock "k8s.io/utils/clock/testing"
)
// TestMain runs every test in this package through goleak.VerifyTestMain.
// Per goleak's documentation, that runs the tests and then fails the run if any
// goroutine started during the tests is still alive. This guards against
// retrier goroutines that never return.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDo(t *testing.T) {
testCases := map[string]struct {
cancel bool
errors []error
wantErr error
}{
"no error": {
errors: []error{
nil,
},
},
"permanent error": {
errors: []error{
errors.New("error"),
},
wantErr: errors.New("error"),
},
"service unavailable then success": {
errors: []error{
errors.New("retry me"),
nil,
},
},
"service unavailable then permanent error": {
errors: []error{
errors.New("retry me"),
errors.New("error"),
},
wantErr: errors.New("error"),
},
"cancellation works": {
cancel: true,
errors: []error{
errors.New("retry me"),
},
wantErr: context.Canceled,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
doer := newStubDoer()
clock := testclock.NewFakeClock(time.Now())
retrier := IntervalRetrier{
doer: doer,
clock: clock,
retriable: isRetriable,
}
retrierResult := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { retrierResult <- retrier.Do(ctx) }()
for _, err := range tc.errors {
doer.errC <- err
clock.Step(retrier.interval)
}
if tc.cancel {
cancel()
}
assert.Equal(tc.wantErr, <-retrierResult)
})
}
}
// stubDoer is the test double plugged into IntervalRetrier as its doer. Each
// call to Do blocks until the test supplies a result on errC or the context
// is done.
type stubDoer struct {
	errC chan error // results handed to Do, one value per call
}

// newStubDoer returns a stubDoer whose result channel is unbuffered. Every send
// from the test therefore rendezvouses with a pending Do call.
func newStubDoer() *stubDoer {
	d := new(stubDoer)
	d.errC = make(chan error)
	return d
}

// Do returns the next value received on errC. If ctx is done first, it
// returns ctx.Err() instead.
func (d *stubDoer) Do(ctx context.Context) error {
	var err error
	select {
	case err = <-d.errC:
	case <-ctx.Done():
		err = ctx.Err()
	}
	return err
}
// isRetriable reports whether err should be retried. The test cases mark
// transient failures with an error whose message is exactly "retry me"; every
// other error is treated as permanent.
//
// A nil error is never retriable. The guard avoids a nil-pointer panic from
// calling Error on a nil interface.
func isRetriable(err error) bool {
	return err != nil && err.Error() == "retry me"
}