mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-06-22 05:04:25 -04:00
AB#2181: retry k8s downloads (#286)
Generalize retrier: * Generalize Do to use a supplied 'retriable' function * Make clock an optional argument in NewIntervalRetrier * Move grpc/retrier to internal package * Update existing unit tests to not use retry feature Add retryDownloadToTempDir: * Wrap downloadToTempDir with retrier. * Retry if TCP connection is reset. * Abort by canceling the context. * Use a mock server in the unit test that serves responses depending on the state received through a state channel. Co-authored-by: katexochen <49727155+katexochen@users.noreply.github.com>
This commit is contained in:
parent
741384158a
commit
c743398a23
7 changed files with 338 additions and 146 deletions
|
@ -11,15 +11,28 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/internal/retry"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"golang.org/x/text/transform"
|
"golang.org/x/text/transform"
|
||||||
|
"k8s.io/utils/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// determines the period after which retryDownloadToTempDir will retry a download.
|
||||||
|
downloadInterval = 10 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
// osInstaller installs binary components of supported kubernetes versions.
|
// osInstaller installs binary components of supported kubernetes versions.
|
||||||
type osInstaller struct {
|
type osInstaller struct {
|
||||||
fs *afero.Afero
|
fs *afero.Afero
|
||||||
hClient httpClient
|
hClient httpClient
|
||||||
|
// clock is needed for testing purposes
|
||||||
|
clock clock.WithTicker
|
||||||
|
// retriable is the function used to check if an error is retriable. Needed for testing.
|
||||||
|
retriable func(error) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// newOSInstaller creates a new osInstaller.
|
// newOSInstaller creates a new osInstaller.
|
||||||
|
@ -27,6 +40,8 @@ func newOSInstaller() *osInstaller {
|
||||||
return &osInstaller{
|
return &osInstaller{
|
||||||
fs: &afero.Afero{Fs: afero.NewOsFs()},
|
fs: &afero.Afero{Fs: afero.NewOsFs()},
|
||||||
hClient: &http.Client{},
|
hClient: &http.Client{},
|
||||||
|
clock: clock.RealClock{},
|
||||||
|
retriable: connectionResetErr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,7 +51,7 @@ func (i *osInstaller) Install(
|
||||||
ctx context.Context, sourceURL string, destinations []string, perm fs.FileMode,
|
ctx context.Context, sourceURL string, destinations []string, perm fs.FileMode,
|
||||||
extract bool, transforms ...transform.Transformer,
|
extract bool, transforms ...transform.Transformer,
|
||||||
) error {
|
) error {
|
||||||
tempPath, err := i.downloadToTempDir(ctx, sourceURL, transforms...)
|
tempPath, err := i.retryDownloadToTempDir(ctx, sourceURL, transforms...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -130,13 +145,37 @@ func (i *osInstaller) extractArchive(archivePath, prefix string, perm fs.FileMod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *osInstaller) retryDownloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, someError error) {
|
||||||
|
doer := downloadDoer{
|
||||||
|
url: url,
|
||||||
|
transforms: transforms,
|
||||||
|
downloader: i,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retries are canceled as soon as the context is canceled.
|
||||||
|
// We need to call NewIntervalRetrier with a clock argument so that the tests can fake the clock by changing the osInstaller clock.
|
||||||
|
retrier := retry.NewIntervalRetrier(&doer, downloadInterval, i.retriable, i.clock)
|
||||||
|
if err := retrier.Do(ctx); err != nil {
|
||||||
|
return "", fmt.Errorf("retrying downloadToTempDir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return doer.path, nil
|
||||||
|
}
|
||||||
|
|
||||||
// downloadToTempDir downloads a file to a temporary location, applying transform on-the-fly.
|
// downloadToTempDir downloads a file to a temporary location, applying transform on-the-fly.
|
||||||
func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error) {
|
func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, retErr error) {
|
||||||
out, err := afero.TempFile(i.fs, "", "")
|
out, err := afero.TempFile(i.fs, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("creating destination temp file: %w", err)
|
return "", fmt.Errorf("creating destination temp file: %w", err)
|
||||||
}
|
}
|
||||||
|
// Remove the created file if an error occurs.
|
||||||
|
defer func() {
|
||||||
|
if retErr != nil {
|
||||||
|
_ = i.fs.Remove(fileName)
|
||||||
|
}
|
||||||
|
}()
|
||||||
defer out.Close()
|
defer out.Close()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("request to download %q: %w", url, err)
|
return "", fmt.Errorf("request to download %q: %w", url, err)
|
||||||
|
@ -186,6 +225,27 @@ func (i *osInstaller) copy(oldname, newname string, perm fs.FileMode) (err error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type downloadDoer struct {
|
||||||
|
url string
|
||||||
|
transforms []transform.Transformer
|
||||||
|
downloader downloader
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type downloader interface {
|
||||||
|
downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *downloadDoer) Do(ctx context.Context) error {
|
||||||
|
path, err := d.downloader.downloadToTempDir(ctx, d.url, d.transforms...)
|
||||||
|
d.path = path
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionResetErr(err error) bool {
|
||||||
|
return errors.Is(err, syscall.ECONNRESET)
|
||||||
|
}
|
||||||
|
|
||||||
// verifyTarPath checks if a tar path is valid (must not contain ".." as path element).
|
// verifyTarPath checks if a tar path is valid (must not contain ".." as path element).
|
||||||
func verifyTarPath(pat string) error {
|
func verifyTarPath(pat string) error {
|
||||||
n := len(pat)
|
n := len(pat)
|
||||||
|
|
|
@ -12,7 +12,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"path"
|
"path"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/icholy/replace"
|
"github.com/icholy/replace"
|
||||||
|
@ -21,6 +23,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/text/transform"
|
"golang.org/x/text/transform"
|
||||||
"google.golang.org/grpc/test/bufconn"
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
testclock "k8s.io/utils/clock/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestInstall(t *testing.T) {
|
func TestInstall(t *testing.T) {
|
||||||
|
@ -75,10 +78,14 @@ func TestInstall(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test was written before retriability was added to Install. It makes sense to test Install as if it wouldn't retry requests.
|
||||||
inst := osInstaller{
|
inst := osInstaller{
|
||||||
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
|
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
|
||||||
hClient: &hClient,
|
hClient: &hClient,
|
||||||
|
clock: testclock.NewFakeClock(time.Time{}),
|
||||||
|
retriable: func(err error) bool { return false },
|
||||||
}
|
}
|
||||||
|
|
||||||
err := inst.Install(context.Background(), "http://server/path", []string{tc.destination}, fs.ModePerm, tc.extract, tc.transforms...)
|
err := inst.Install(context.Background(), "http://server/path", []string{tc.destination}, fs.ModePerm, tc.extract, tc.transforms...)
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
|
@ -238,6 +245,91 @@ func TestExtractArchive(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRetryDownloadToTempDir(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
responses []int
|
||||||
|
cancelCtx bool
|
||||||
|
wantErr bool
|
||||||
|
wantFile []byte
|
||||||
|
}{
|
||||||
|
"Succeed on third try": {
|
||||||
|
responses: []int{500, 500, 200},
|
||||||
|
wantFile: []byte("file-content"),
|
||||||
|
},
|
||||||
|
"Cancel after second try": {
|
||||||
|
responses: []int{500, 500},
|
||||||
|
cancelCtx: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
// control the server's responses through stateCh
|
||||||
|
stateCh := make(chan int)
|
||||||
|
server := newHTTPBufconnServerWithState(stateCh, tc.wantFile)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
hClient := http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: server.DialContext,
|
||||||
|
Dial: server.Dial,
|
||||||
|
DialTLSContext: server.DialContext,
|
||||||
|
DialTLS: server.Dial,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
afs := afero.NewMemMapFs()
|
||||||
|
|
||||||
|
// control download retries through FakeClock clock
|
||||||
|
clock := testclock.NewFakeClock(time.Now())
|
||||||
|
inst := osInstaller{
|
||||||
|
fs: &afero.Afero{Fs: afs},
|
||||||
|
hClient: &hClient,
|
||||||
|
clock: clock,
|
||||||
|
retriable: func(error) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
|
// abort retryDownloadToTempDir in some test cases by using the context
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
var downloadErr error
|
||||||
|
var path string
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
path, downloadErr = inst.retryDownloadToTempDir(ctx, "http://server/path", []transform.Transformer{}...)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// control the server's responses through stateCh.
|
||||||
|
for _, resp := range tc.responses {
|
||||||
|
stateCh <- resp
|
||||||
|
clock.Step(downloadInterval)
|
||||||
|
}
|
||||||
|
if tc.cancelCtx {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(downloadErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(downloadErr)
|
||||||
|
content, err := inst.fs.ReadFile(path)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal(tc.wantFile, content)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDownloadToTempDir(t *testing.T) {
|
func TestDownloadToTempDir(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
server httpBufconnServer
|
server httpBufconnServer
|
||||||
|
@ -484,6 +576,21 @@ func newHTTPBufconnServerWithBody(body []byte) httpBufconnServer {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newHTTPBufconnServerWithState(state chan int, body []byte) httpBufconnServer {
|
||||||
|
return newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch <-state {
|
||||||
|
case 500:
|
||||||
|
w.WriteHeader(500)
|
||||||
|
case 200:
|
||||||
|
if _, err := w.Write(body); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
w.WriteHeader(402)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func createTarGz(contents []byte, path string) []byte {
|
func createTarGz(contents []byte, path string) []byte {
|
||||||
tgzWriter := newTarGzWriter()
|
tgzWriter := newTarGzWriter()
|
||||||
defer func() { _ = tgzWriter.Close() }()
|
defer func() { _ = tgzWriter.Close() }()
|
||||||
|
|
|
@ -24,7 +24,8 @@ import (
|
||||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||||
"github.com/edgelesssys/constellation/internal/file"
|
"github.com/edgelesssys/constellation/internal/file"
|
||||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||||
"github.com/edgelesssys/constellation/internal/grpc/retry"
|
grpcRetry "github.com/edgelesssys/constellation/internal/grpc/retry"
|
||||||
|
"github.com/edgelesssys/constellation/internal/retry"
|
||||||
"github.com/edgelesssys/constellation/internal/state"
|
"github.com/edgelesssys/constellation/internal/state"
|
||||||
kms "github.com/edgelesssys/constellation/kms/setup"
|
kms "github.com/edgelesssys/constellation/kms/setup"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
|
@ -141,7 +142,7 @@ func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.
|
||||||
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
|
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
|
||||||
req: req,
|
req: req,
|
||||||
}
|
}
|
||||||
retrier := retry.NewIntervalRetrier(doer, 30*time.Second)
|
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, grpcRetry.ServiceIsUnavailable)
|
||||||
if err := retrier.Do(ctx); err != nil {
|
if err := retrier.Do(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,58 +1,16 @@
|
||||||
package retry
|
package retry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"k8s.io/utils/clock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// IntervalRetrier is retries a grpc call with an interval.
|
// ServiceIsUnavailable checks if the error is a grpc status with code Unavailable.
|
||||||
type IntervalRetrier struct {
|
|
||||||
interval time.Duration
|
|
||||||
doer Doer
|
|
||||||
clock clock.WithTicker
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewIntervalRetrier returns a new IntervalRetrier.
|
|
||||||
func NewIntervalRetrier(doer Doer, interval time.Duration) *IntervalRetrier {
|
|
||||||
return &IntervalRetrier{
|
|
||||||
interval: interval,
|
|
||||||
doer: doer,
|
|
||||||
clock: clock.RealClock{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do retries performing a grpc call until it succeeds, returns a permanent error or the context is cancelled.
|
|
||||||
func (r *IntervalRetrier) Do(ctx context.Context) error {
|
|
||||||
ticker := r.clock.NewTicker(r.interval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
err := r.doer.Do(ctx)
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !r.serviceIsUnavailable(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-ticker.C():
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// serviceIsUnavailable checks if the error is a grpc status with code Unavailable.
|
|
||||||
// In the special case of an authentication handshake failure, false is returned to prevent further retries.
|
// In the special case of an authentication handshake failure, false is returned to prevent further retries.
|
||||||
func (r *IntervalRetrier) serviceIsUnavailable(err error) bool {
|
func ServiceIsUnavailable(err error) bool {
|
||||||
// taken from google.golang.org/grpc/status.FromError
|
// taken from google.golang.org/grpc/status.FromError
|
||||||
var targetErr interface {
|
var targetErr interface {
|
||||||
GRPCStatus() *status.Status
|
GRPCStatus() *status.Status
|
||||||
|
@ -75,10 +33,3 @@ func (r *IntervalRetrier) serviceIsUnavailable(err error) bool {
|
||||||
// ideally we would check the error type directly, but grpc only provides a string
|
// ideally we would check the error type directly, but grpc only provides a string
|
||||||
return !strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`)
|
return !strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Doer interface {
|
|
||||||
// Do performs a grpc operation.
|
|
||||||
//
|
|
||||||
// It should return a grpc status with code Unavailable error to signal a transient fault.
|
|
||||||
Do(ctx context.Context) error
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,86 +1,15 @@
|
||||||
package retry
|
package retry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
testclock "k8s.io/utils/clock/testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDo(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
|
||||||
cancel bool
|
|
||||||
errors []error
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
"no error": {
|
|
||||||
errors: []error{
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"permanent error": {
|
|
||||||
errors: []error{
|
|
||||||
errors.New("error"),
|
|
||||||
},
|
|
||||||
wantErr: errors.New("error"),
|
|
||||||
},
|
|
||||||
"service unavailable then success": {
|
|
||||||
errors: []error{
|
|
||||||
status.Error(codes.Unavailable, "error"),
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"service unavailable then permanent error": {
|
|
||||||
errors: []error{
|
|
||||||
status.Error(codes.Unavailable, "error"),
|
|
||||||
errors.New("error"),
|
|
||||||
},
|
|
||||||
wantErr: errors.New("error"),
|
|
||||||
},
|
|
||||||
"cancellation works": {
|
|
||||||
cancel: true,
|
|
||||||
errors: []error{
|
|
||||||
status.Error(codes.Unavailable, "error"),
|
|
||||||
},
|
|
||||||
wantErr: context.Canceled,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
doer := newStubDoer()
|
|
||||||
clock := testclock.NewFakeClock(time.Now())
|
|
||||||
retrier := IntervalRetrier{
|
|
||||||
doer: doer,
|
|
||||||
clock: clock,
|
|
||||||
}
|
|
||||||
retrierResult := make(chan error, 1)
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
go func() { retrierResult <- retrier.Do(ctx) }()
|
|
||||||
|
|
||||||
if tc.cancel {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, err := range tc.errors {
|
|
||||||
doer.errC <- err
|
|
||||||
clock.Step(retrier.interval)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(tc.wantErr, <-retrierResult)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServiceIsUnavailable(t *testing.T) {
|
func TestServiceIsUnavailable(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
err error
|
err error
|
||||||
|
@ -112,22 +41,7 @@ func TestServiceIsUnavailable(t *testing.T) {
|
||||||
for name, tc := range testCases {
|
for name, tc := range testCases {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
retrier := IntervalRetrier{}
|
assert.Equal(tc.wantUnavailable, ServiceIsUnavailable(tc.err))
|
||||||
assert.Equal(tc.wantUnavailable, retrier.serviceIsUnavailable(tc.err))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubDoer struct {
|
|
||||||
errC chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStubDoer() *stubDoer {
|
|
||||||
return &stubDoer{
|
|
||||||
errC: make(chan error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *stubDoer) Do(_ context.Context) error {
|
|
||||||
return <-d.errC
|
|
||||||
}
|
|
||||||
|
|
61
internal/retry/retry.go
Normal file
61
internal/retry/retry.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/utils/clock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IntervalRetrier retries a call with an interval. The call is defined in the Doer property.
|
||||||
|
type IntervalRetrier struct {
|
||||||
|
interval time.Duration
|
||||||
|
doer Doer
|
||||||
|
clock clock.WithTicker
|
||||||
|
retriable func(error) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIntervalRetrier returns a new IntervalRetrier. The optional clock is used for testing.
|
||||||
|
func NewIntervalRetrier(doer Doer, interval time.Duration, retriable func(error) bool, optClock ...clock.WithTicker) *IntervalRetrier {
|
||||||
|
var clock clock.WithTicker = clock.RealClock{}
|
||||||
|
if len(optClock) > 0 {
|
||||||
|
clock = optClock[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &IntervalRetrier{
|
||||||
|
interval: interval,
|
||||||
|
doer: doer,
|
||||||
|
clock: clock,
|
||||||
|
retriable: retriable,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do retries performing a call until it succeeds, returns a permanent error or the context is cancelled.
|
||||||
|
func (r *IntervalRetrier) Do(ctx context.Context) error {
|
||||||
|
ticker := r.clock.NewTicker(r.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
err := r.doer.Do(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.retriable(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Doer interface {
|
||||||
|
// Do performs an operation.
|
||||||
|
//
|
||||||
|
// It should return an error that can be checked for retriability.
|
||||||
|
Do(ctx context.Context) error
|
||||||
|
}
|
98
internal/retry/retry_test.go
Normal file
98
internal/retry/retry_test.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
testclock "k8s.io/utils/clock/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDo(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
cancel bool
|
||||||
|
errors []error
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
"no error": {
|
||||||
|
errors: []error{
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"permanent error": {
|
||||||
|
errors: []error{
|
||||||
|
errors.New("error"),
|
||||||
|
},
|
||||||
|
wantErr: errors.New("error"),
|
||||||
|
},
|
||||||
|
"service unavailable then success": {
|
||||||
|
errors: []error{
|
||||||
|
errors.New("retry me"),
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"service unavailable then permanent error": {
|
||||||
|
errors: []error{
|
||||||
|
errors.New("retry me"),
|
||||||
|
errors.New("error"),
|
||||||
|
},
|
||||||
|
wantErr: errors.New("error"),
|
||||||
|
},
|
||||||
|
"cancellation works": {
|
||||||
|
cancel: true,
|
||||||
|
errors: []error{
|
||||||
|
errors.New("retry me"),
|
||||||
|
},
|
||||||
|
wantErr: context.Canceled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
doer := newStubDoer()
|
||||||
|
clock := testclock.NewFakeClock(time.Now())
|
||||||
|
retrier := IntervalRetrier{
|
||||||
|
doer: doer,
|
||||||
|
clock: clock,
|
||||||
|
retriable: isRetriable,
|
||||||
|
}
|
||||||
|
retrierResult := make(chan error, 1)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() { retrierResult <- retrier.Do(ctx) }()
|
||||||
|
|
||||||
|
if tc.cancel {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, err := range tc.errors {
|
||||||
|
doer.errC <- err
|
||||||
|
clock.Step(retrier.interval)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(tc.wantErr, <-retrierResult)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubDoer struct {
|
||||||
|
errC chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStubDoer() *stubDoer {
|
||||||
|
return &stubDoer{
|
||||||
|
errC: make(chan error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stubDoer) Do(_ context.Context) error {
|
||||||
|
return <-d.errC
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRetriable(err error) bool {
|
||||||
|
return err.Error() == "retry me"
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue