AB#2181: retry k8s downloads (#286)

Generalize retrier:
* Generalize Do to use a supplied 'retriable' function
* Make clock an optional argument in NewIntervalRetrier
* Move grpc/retrier to interal package
* Update existing unittests to not use retry feature

Add retryDownloadToTempDir:
* Wrap downloadToTempDir with retrier.
* Retry if TCP connection is reset.
* Abort by canceling the context.
* Use a mock server in the unit test that serves responses
depending on the state received through a state channel.

Co-authored-by: katexochen <49727155+katexochen@users.noreply.github.com>
This commit is contained in:
Otto Bittner 2022-07-21 15:20:12 +02:00 committed by GitHub
parent 741384158a
commit c743398a23
7 changed files with 338 additions and 146 deletions

View File

@ -11,22 +11,37 @@ import (
"net/http"
"os"
"path"
"syscall"
"time"
"github.com/edgelesssys/constellation/internal/retry"
"github.com/spf13/afero"
"golang.org/x/text/transform"
"k8s.io/utils/clock"
)
const (
// determines the period after which retryDownloadToTempDir will retry a download.
downloadInterval = 10 * time.Millisecond
)
// osInstaller installs binary components of supported kubernetes versions.
type osInstaller struct {
fs *afero.Afero
hClient httpClient
// clock is needed for testing purposes
clock clock.WithTicker
// retriable is the function used to check if an error is retriable. Needed for testing.
retriable func(error) bool
}
// newOSInstaller creates a new osInstaller.
func newOSInstaller() *osInstaller {
return &osInstaller{
fs: &afero.Afero{Fs: afero.NewOsFs()},
hClient: &http.Client{},
fs: &afero.Afero{Fs: afero.NewOsFs()},
hClient: &http.Client{},
clock: clock.RealClock{},
retriable: connectionResetErr,
}
}
@ -36,7 +51,7 @@ func (i *osInstaller) Install(
ctx context.Context, sourceURL string, destinations []string, perm fs.FileMode,
extract bool, transforms ...transform.Transformer,
) error {
tempPath, err := i.downloadToTempDir(ctx, sourceURL, transforms...)
tempPath, err := i.retryDownloadToTempDir(ctx, sourceURL, transforms...)
if err != nil {
return err
}
@ -130,13 +145,37 @@ func (i *osInstaller) extractArchive(archivePath, prefix string, perm fs.FileMod
}
}
func (i *osInstaller) retryDownloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, someError error) {
doer := downloadDoer{
url: url,
transforms: transforms,
downloader: i,
}
// Retries are cancled as soon as the context is canceled.
// We need to call NewIntervalRetrier with a clock argument so that the tests can fake the clock by changing the osInstaller clock.
retrier := retry.NewIntervalRetrier(&doer, downloadInterval, i.retriable, i.clock)
if err := retrier.Do(ctx); err != nil {
return "", fmt.Errorf("retrying downloadToTempDir: %w", err)
}
return doer.path, nil
}
// downloadToTempDir downloads a file to a temporary location, applying transform on-the-fly.
func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error) {
func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, retErr error) {
out, err := afero.TempFile(i.fs, "", "")
if err != nil {
return "", fmt.Errorf("creating destination temp file: %w", err)
}
// Remove the created file if an error occurs.
defer func() {
if retErr != nil {
_ = i.fs.Remove(fileName)
}
}()
defer out.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("request to download %q: %w", url, err)
@ -186,6 +225,27 @@ func (i *osInstaller) copy(oldname, newname string, perm fs.FileMode) (err error
return nil
}
type downloadDoer struct {
url string
transforms []transform.Transformer
downloader downloader
path string
}
type downloader interface {
downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error)
}
func (d *downloadDoer) Do(ctx context.Context) error {
path, err := d.downloader.downloadToTempDir(ctx, d.url, d.transforms...)
d.path = path
return err
}
func connectionResetErr(err error) bool {
return errors.Is(err, syscall.ECONNRESET)
}
// verifyTarPath checks if a tar path is valid (must not contain ".." as path element).
func verifyTarPath(pat string) error {
n := len(pat)

View File

@ -12,7 +12,9 @@ import (
"net/http"
"net/http/httptest"
"path"
"sync"
"testing"
"time"
"github.com/hashicorp/go-multierror"
"github.com/icholy/replace"
@ -21,6 +23,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/text/transform"
"google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing"
)
func TestInstall(t *testing.T) {
@ -75,10 +78,14 @@ func TestInstall(t *testing.T) {
},
}
// This test was written before retriability was added to Install. It makes sense to test Install as if it wouldn't retry requests.
inst := osInstaller{
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
hClient: &hClient,
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
hClient: &hClient,
clock: testclock.NewFakeClock(time.Time{}),
retriable: func(err error) bool { return false },
}
err := inst.Install(context.Background(), "http://server/path", []string{tc.destination}, fs.ModePerm, tc.extract, tc.transforms...)
if tc.wantErr {
assert.Error(err)
@ -238,6 +245,91 @@ func TestExtractArchive(t *testing.T) {
}
}
func TestRetryDownloadToTempDir(t *testing.T) {
testCases := map[string]struct {
responses []int
cancelCtx bool
wantErr bool
wantFile []byte
}{
"Succeed on third try": {
responses: []int{500, 500, 200},
wantFile: []byte("file-content"),
},
"Cancel after second try": {
responses: []int{500, 500},
cancelCtx: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
// control the server's responses through stateCh
stateCh := make(chan int)
server := newHTTPBufconnServerWithState(stateCh, tc.wantFile)
defer server.Close()
hClient := http.Client{
Transport: &http.Transport{
DialContext: server.DialContext,
Dial: server.Dial,
DialTLSContext: server.DialContext,
DialTLS: server.Dial,
},
}
afs := afero.NewMemMapFs()
// control download retries through FakeClock clock
clock := testclock.NewFakeClock(time.Now())
inst := osInstaller{
fs: &afero.Afero{Fs: afs},
hClient: &hClient,
clock: clock,
retriable: func(error) bool { return true },
}
// abort retryDownloadToTempDir in some test cases by using the context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg := sync.WaitGroup{}
var downloadErr error
var path string
wg.Add(1)
go func() {
defer wg.Done()
path, downloadErr = inst.retryDownloadToTempDir(ctx, "http://server/path", []transform.Transformer{}...)
}()
// control the server's responses through stateCh.
for _, resp := range tc.responses {
stateCh <- resp
clock.Step(downloadInterval)
}
if tc.cancelCtx {
cancel()
}
wg.Wait()
if tc.wantErr {
assert.Error(downloadErr)
return
}
require.NoError(downloadErr)
content, err := inst.fs.ReadFile(path)
assert.NoError(err)
assert.Equal(tc.wantFile, content)
})
}
}
func TestDownloadToTempDir(t *testing.T) {
testCases := map[string]struct {
server httpBufconnServer
@ -484,6 +576,21 @@ func newHTTPBufconnServerWithBody(body []byte) httpBufconnServer {
})
}
func newHTTPBufconnServerWithState(state chan int, body []byte) httpBufconnServer {
return newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) {
switch <-state {
case 500:
w.WriteHeader(500)
case 200:
if _, err := w.Write(body); err != nil {
panic(err)
}
default:
w.WriteHeader(402)
}
})
}
func createTarGz(contents []byte, path string) []byte {
tgzWriter := newTarGzWriter()
defer func() { _ = tgzWriter.Close() }()

View File

@ -24,7 +24,8 @@ import (
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/retry"
grpcRetry "github.com/edgelesssys/constellation/internal/grpc/retry"
"github.com/edgelesssys/constellation/internal/retry"
"github.com/edgelesssys/constellation/internal/state"
kms "github.com/edgelesssys/constellation/kms/setup"
"github.com/spf13/afero"
@ -141,7 +142,7 @@ func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
req: req,
}
retrier := retry.NewIntervalRetrier(doer, 30*time.Second)
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, grpcRetry.ServiceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return nil, err
}

View File

@ -1,58 +1,16 @@
package retry
import (
"context"
"errors"
"strings"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/utils/clock"
)
// IntervalRetrier is retries a grpc call with an interval.
type IntervalRetrier struct {
interval time.Duration
doer Doer
clock clock.WithTicker
}
// NewIntervalRetrier returns a new IntervalRetrier.
func NewIntervalRetrier(doer Doer, interval time.Duration) *IntervalRetrier {
return &IntervalRetrier{
interval: interval,
doer: doer,
clock: clock.RealClock{},
}
}
// Do retries performing a grpc call until it succeeds, returns a permanent error or the context is cancelled.
func (r *IntervalRetrier) Do(ctx context.Context) error {
ticker := r.clock.NewTicker(r.interval)
defer ticker.Stop()
for {
err := r.doer.Do(ctx)
if err == nil {
return nil
}
if !r.serviceIsUnavailable(err) {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C():
}
}
}
// serviceIsUnavailable checks if the error is a grpc status with code Unavailable.
// ServiceIsUnavailable checks if the error is a grpc status with code Unavailable.
// In the special case of an authentication handshake failure, false is returned to prevent further retries.
func (r *IntervalRetrier) serviceIsUnavailable(err error) bool {
func ServiceIsUnavailable(err error) bool {
// taken from google.golang.org/grpc/status.FromError
var targetErr interface {
GRPCStatus() *status.Status
@ -75,10 +33,3 @@ func (r *IntervalRetrier) serviceIsUnavailable(err error) bool {
// ideally we would check the error type directly, but grpc only provides a string
return !strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`)
}
type Doer interface {
// Do performs a grpc operation.
//
// It should return a grpc status with code Unavailable error to signal a transient fault.
Do(ctx context.Context) error
}

View File

@ -1,86 +1,15 @@
package retry
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
testclock "k8s.io/utils/clock/testing"
)
func TestDo(t *testing.T) {
testCases := map[string]struct {
cancel bool
errors []error
wantErr error
}{
"no error": {
errors: []error{
nil,
},
},
"permanent error": {
errors: []error{
errors.New("error"),
},
wantErr: errors.New("error"),
},
"service unavailable then success": {
errors: []error{
status.Error(codes.Unavailable, "error"),
nil,
},
},
"service unavailable then permanent error": {
errors: []error{
status.Error(codes.Unavailable, "error"),
errors.New("error"),
},
wantErr: errors.New("error"),
},
"cancellation works": {
cancel: true,
errors: []error{
status.Error(codes.Unavailable, "error"),
},
wantErr: context.Canceled,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
doer := newStubDoer()
clock := testclock.NewFakeClock(time.Now())
retrier := IntervalRetrier{
doer: doer,
clock: clock,
}
retrierResult := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { retrierResult <- retrier.Do(ctx) }()
if tc.cancel {
cancel()
}
for _, err := range tc.errors {
doer.errC <- err
clock.Step(retrier.interval)
}
assert.Equal(tc.wantErr, <-retrierResult)
})
}
}
func TestServiceIsUnavailable(t *testing.T) {
testCases := map[string]struct {
err error
@ -112,22 +41,7 @@ func TestServiceIsUnavailable(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
retrier := IntervalRetrier{}
assert.Equal(tc.wantUnavailable, retrier.serviceIsUnavailable(tc.err))
assert.Equal(tc.wantUnavailable, ServiceIsUnavailable(tc.err))
})
}
}
type stubDoer struct {
errC chan error
}
func newStubDoer() *stubDoer {
return &stubDoer{
errC: make(chan error),
}
}
func (d *stubDoer) Do(_ context.Context) error {
return <-d.errC
}

61
internal/retry/retry.go Normal file
View File

@ -0,0 +1,61 @@
package retry
import (
"context"
"time"
"k8s.io/utils/clock"
)
// IntervalRetrier retries a call with an interval. The call is defined in the Doer property.
type IntervalRetrier struct {
interval time.Duration
doer Doer
clock clock.WithTicker
retriable func(error) bool
}
// NewIntervalRetrier returns a new IntervalRetrier. The optional clock is used for testing.
func NewIntervalRetrier(doer Doer, interval time.Duration, retriable func(error) bool, optClock ...clock.WithTicker) *IntervalRetrier {
var clock clock.WithTicker = clock.RealClock{}
if len(optClock) > 0 {
clock = optClock[0]
}
return &IntervalRetrier{
interval: interval,
doer: doer,
clock: clock,
retriable: retriable,
}
}
// Do retries performing a call until it succeeds, returns a permanent error or the context is cancelled.
func (r *IntervalRetrier) Do(ctx context.Context) error {
ticker := r.clock.NewTicker(r.interval)
defer ticker.Stop()
for {
err := r.doer.Do(ctx)
if err == nil {
return nil
}
if !r.retriable(err) {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C():
}
}
}
type Doer interface {
// Do performs an operation.
//
// It should return an error that can be checked for retriability.
Do(ctx context.Context) error
}

View File

@ -0,0 +1,98 @@
package retry
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
testclock "k8s.io/utils/clock/testing"
)
func TestDo(t *testing.T) {
testCases := map[string]struct {
cancel bool
errors []error
wantErr error
}{
"no error": {
errors: []error{
nil,
},
},
"permanent error": {
errors: []error{
errors.New("error"),
},
wantErr: errors.New("error"),
},
"service unavailable then success": {
errors: []error{
errors.New("retry me"),
nil,
},
},
"service unavailable then permanent error": {
errors: []error{
errors.New("retry me"),
errors.New("error"),
},
wantErr: errors.New("error"),
},
"cancellation works": {
cancel: true,
errors: []error{
errors.New("retry me"),
},
wantErr: context.Canceled,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
doer := newStubDoer()
clock := testclock.NewFakeClock(time.Now())
retrier := IntervalRetrier{
doer: doer,
clock: clock,
retriable: isRetriable,
}
retrierResult := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { retrierResult <- retrier.Do(ctx) }()
if tc.cancel {
cancel()
}
for _, err := range tc.errors {
doer.errC <- err
clock.Step(retrier.interval)
}
assert.Equal(tc.wantErr, <-retrierResult)
})
}
}
type stubDoer struct {
errC chan error
}
func newStubDoer() *stubDoer {
return &stubDoer{
errC: make(chan error),
}
}
func (d *stubDoer) Do(_ context.Context) error {
return <-d.errC
}
func isRetriable(err error) bool {
return err.Error() == "retry me"
}