diff --git a/bootstrapper/internal/kubernetes/k8sapi/install.go b/bootstrapper/internal/kubernetes/k8sapi/install.go index 7a56a9ed1..72888fc26 100644 --- a/bootstrapper/internal/kubernetes/k8sapi/install.go +++ b/bootstrapper/internal/kubernetes/k8sapi/install.go @@ -11,22 +11,37 @@ import ( "net/http" "os" "path" + "syscall" + "time" + "github.com/edgelesssys/constellation/internal/retry" "github.com/spf13/afero" "golang.org/x/text/transform" + "k8s.io/utils/clock" +) + +const ( + // determines the period after which retryDownloadToTempDir will retry a download. + downloadInterval = 10 * time.Millisecond ) // osInstaller installs binary components of supported kubernetes versions. type osInstaller struct { fs *afero.Afero hClient httpClient + // clock is needed for testing purposes + clock clock.WithTicker + // retriable is the function used to check if an error is retriable. Needed for testing. + retriable func(error) bool } // newOSInstaller creates a new osInstaller. func newOSInstaller() *osInstaller { return &osInstaller{ - fs: &afero.Afero{Fs: afero.NewOsFs()}, - hClient: &http.Client{}, + fs: &afero.Afero{Fs: afero.NewOsFs()}, + hClient: &http.Client{}, + clock: clock.RealClock{}, + retriable: connectionResetErr, } } @@ -36,7 +51,7 @@ func (i *osInstaller) Install( ctx context.Context, sourceURL string, destinations []string, perm fs.FileMode, extract bool, transforms ...transform.Transformer, ) error { - tempPath, err := i.downloadToTempDir(ctx, sourceURL, transforms...) + tempPath, err := i.retryDownloadToTempDir(ctx, sourceURL, transforms...) 
if err != nil { return err } @@ -130,13 +145,37 @@ func (i *osInstaller) extractArchive(archivePath, prefix string, perm fs.FileMod } } +func (i *osInstaller) retryDownloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, someError error) { + doer := downloadDoer{ + url: url, + transforms: transforms, + downloader: i, + } + + // Retries are canceled as soon as the context is canceled. + // We need to call NewIntervalRetrier with a clock argument so that the tests can fake the clock by changing the osInstaller clock. + retrier := retry.NewIntervalRetrier(&doer, downloadInterval, i.retriable, i.clock) + if err := retrier.Do(ctx); err != nil { + return "", fmt.Errorf("retrying downloadToTempDir: %w", err) + } + + return doer.path, nil +} + // downloadToTempDir downloads a file to a temporary location, applying transform on-the-fly. -func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error) { +func (i *osInstaller) downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (fileName string, retErr error) { out, err := afero.TempFile(i.fs, "", "") if err != nil { return "", fmt.Errorf("creating destination temp file: %w", err) } + // Remove the created file if an error occurs. 
+ defer func() { + if retErr != nil { + _ = i.fs.Remove(fileName) + } + }() defer out.Close() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return "", fmt.Errorf("request to download %q: %w", url, err) @@ -186,6 +225,27 @@ func (i *osInstaller) copy(oldname, newname string, perm fs.FileMode) (err error return nil } +type downloadDoer struct { + url string + transforms []transform.Transformer + downloader downloader + path string +} + +type downloader interface { + downloadToTempDir(ctx context.Context, url string, transforms ...transform.Transformer) (string, error) +} + +func (d *downloadDoer) Do(ctx context.Context) error { + path, err := d.downloader.downloadToTempDir(ctx, d.url, d.transforms...) + d.path = path + return err +} + +func connectionResetErr(err error) bool { + return errors.Is(err, syscall.ECONNRESET) +} + // verifyTarPath checks if a tar path is valid (must not contain ".." as path element). func verifyTarPath(pat string) error { n := len(pat) diff --git a/bootstrapper/internal/kubernetes/k8sapi/install_test.go b/bootstrapper/internal/kubernetes/k8sapi/install_test.go index 495f70ad8..230572ef0 100644 --- a/bootstrapper/internal/kubernetes/k8sapi/install_test.go +++ b/bootstrapper/internal/kubernetes/k8sapi/install_test.go @@ -12,7 +12,9 @@ import ( "net/http" "net/http/httptest" "path" + "sync" "testing" + "time" "github.com/hashicorp/go-multierror" "github.com/icholy/replace" @@ -21,6 +23,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/text/transform" "google.golang.org/grpc/test/bufconn" + testclock "k8s.io/utils/clock/testing" ) func TestInstall(t *testing.T) { @@ -75,10 +78,14 @@ func TestInstall(t *testing.T) { }, } + // This test was written before retriability was added to Install. It makes sense to test Install as if it wouldn't retry requests. 
inst := osInstaller{ - fs: &afero.Afero{Fs: afero.NewMemMapFs()}, - hClient: &hClient, + fs: &afero.Afero{Fs: afero.NewMemMapFs()}, + hClient: &hClient, + clock: testclock.NewFakeClock(time.Time{}), + retriable: func(err error) bool { return false }, } + err := inst.Install(context.Background(), "http://server/path", []string{tc.destination}, fs.ModePerm, tc.extract, tc.transforms...) if tc.wantErr { assert.Error(err) @@ -238,6 +245,91 @@ func TestExtractArchive(t *testing.T) { } } +func TestRetryDownloadToTempDir(t *testing.T) { + testCases := map[string]struct { + responses []int + cancelCtx bool + wantErr bool + wantFile []byte + }{ + "Succeed on third try": { + responses: []int{500, 500, 200}, + wantFile: []byte("file-content"), + }, + "Cancel after second try": { + responses: []int{500, 500}, + cancelCtx: true, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + // control the server's responses through stateCh + stateCh := make(chan int) + server := newHTTPBufconnServerWithState(stateCh, tc.wantFile) + defer server.Close() + + hClient := http.Client{ + Transport: &http.Transport{ + DialContext: server.DialContext, + Dial: server.Dial, + DialTLSContext: server.DialContext, + DialTLS: server.Dial, + }, + } + + afs := afero.NewMemMapFs() + + // control download retries through FakeClock clock + clock := testclock.NewFakeClock(time.Now()) + inst := osInstaller{ + fs: &afero.Afero{Fs: afs}, + hClient: &hClient, + clock: clock, + retriable: func(error) bool { return true }, + } + + // abort retryDownloadToTempDir in some test cases by using the context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wg := sync.WaitGroup{} + var downloadErr error + var path string + wg.Add(1) + go func() { + defer wg.Done() + path, downloadErr = inst.retryDownloadToTempDir(ctx, "http://server/path", []transform.Transformer{}...) 
+ }() + + // control the server's responses through stateCh. + for _, resp := range tc.responses { + stateCh <- resp + clock.Step(downloadInterval) + } + if tc.cancelCtx { + cancel() + } + + wg.Wait() + + if tc.wantErr { + assert.Error(downloadErr) + return + } + + require.NoError(downloadErr) + content, err := inst.fs.ReadFile(path) + assert.NoError(err) + assert.Equal(tc.wantFile, content) + }) + } +} + func TestDownloadToTempDir(t *testing.T) { testCases := map[string]struct { server httpBufconnServer @@ -484,6 +576,21 @@ func newHTTPBufconnServerWithBody(body []byte) httpBufconnServer { }) } +func newHTTPBufconnServerWithState(state chan int, body []byte) httpBufconnServer { + return newHTTPBufconnServer(func(w http.ResponseWriter, r *http.Request) { + switch <-state { + case 500: + w.WriteHeader(500) + case 200: + if _, err := w.Write(body); err != nil { + panic(err) + } + default: + w.WriteHeader(402) + } + }) +} + func createTarGz(contents []byte, path string) []byte { tgzWriter := newTarGzWriter() defer func() { _ = tgzWriter.Close() }() diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index 386dcd9e2..d77b4e0e5 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -24,7 +24,8 @@ import ( "github.com/edgelesssys/constellation/internal/deploy/ssh" "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/grpc/dialer" - "github.com/edgelesssys/constellation/internal/grpc/retry" + grpcRetry "github.com/edgelesssys/constellation/internal/grpc/retry" + "github.com/edgelesssys/constellation/internal/retry" "github.com/edgelesssys/constellation/internal/state" kms "github.com/edgelesssys/constellation/kms/setup" "github.com/spf13/afero" @@ -141,7 +142,7 @@ func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto. 
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)), req: req, } - retrier := retry.NewIntervalRetrier(doer, 30*time.Second) + retrier := retry.NewIntervalRetrier(doer, 30*time.Second, grpcRetry.ServiceIsUnavailable) if err := retrier.Do(ctx); err != nil { return nil, err } diff --git a/internal/grpc/retry/retry.go b/internal/grpc/retry/retry.go index 8a1c61dfb..fb63a4519 100644 --- a/internal/grpc/retry/retry.go +++ b/internal/grpc/retry/retry.go @@ -1,58 +1,16 @@ package retry import ( - "context" "errors" "strings" - "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "k8s.io/utils/clock" ) -// IntervalRetrier is retries a grpc call with an interval. -type IntervalRetrier struct { - interval time.Duration - doer Doer - clock clock.WithTicker -} - -// NewIntervalRetrier returns a new IntervalRetrier. -func NewIntervalRetrier(doer Doer, interval time.Duration) *IntervalRetrier { - return &IntervalRetrier{ - interval: interval, - doer: doer, - clock: clock.RealClock{}, - } -} - -// Do retries performing a grpc call until it succeeds, returns a permanent error or the context is cancelled. -func (r *IntervalRetrier) Do(ctx context.Context) error { - ticker := r.clock.NewTicker(r.interval) - defer ticker.Stop() - - for { - err := r.doer.Do(ctx) - if err == nil { - return nil - } - - if !r.serviceIsUnavailable(err) { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C(): - } - } -} - -// serviceIsUnavailable checks if the error is a grpc status with code Unavailable. +// ServiceIsUnavailable checks if the error is a grpc status with code Unavailable. // In the special case of an authentication handshake failure, false is returned to prevent further retries. 
-func (r *IntervalRetrier) serviceIsUnavailable(err error) bool { +func ServiceIsUnavailable(err error) bool { // taken from google.golang.org/grpc/status.FromError var targetErr interface { GRPCStatus() *status.Status @@ -75,10 +33,3 @@ func (r *IntervalRetrier) serviceIsUnavailable(err error) bool { // ideally we would check the error type directly, but grpc only provides a string return !strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`) } - -type Doer interface { - // Do performs a grpc operation. - // - // It should return a grpc status with code Unavailable error to signal a transient fault. - Do(ctx context.Context) error -} diff --git a/internal/grpc/retry/retry_test.go b/internal/grpc/retry/retry_test.go index df2087896..9cc947ae8 100644 --- a/internal/grpc/retry/retry_test.go +++ b/internal/grpc/retry/retry_test.go @@ -1,86 +1,15 @@ package retry import ( - "context" "errors" "fmt" "testing" - "time" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - testclock "k8s.io/utils/clock/testing" ) -func TestDo(t *testing.T) { - testCases := map[string]struct { - cancel bool - errors []error - wantErr error - }{ - "no error": { - errors: []error{ - nil, - }, - }, - "permanent error": { - errors: []error{ - errors.New("error"), - }, - wantErr: errors.New("error"), - }, - "service unavailable then success": { - errors: []error{ - status.Error(codes.Unavailable, "error"), - nil, - }, - }, - "service unavailable then permanent error": { - errors: []error{ - status.Error(codes.Unavailable, "error"), - errors.New("error"), - }, - wantErr: errors.New("error"), - }, - "cancellation works": { - cancel: true, - errors: []error{ - status.Error(codes.Unavailable, "error"), - }, - wantErr: context.Canceled, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - doer := newStubDoer() - clock := 
testclock.NewFakeClock(time.Now()) - retrier := IntervalRetrier{ - doer: doer, - clock: clock, - } - retrierResult := make(chan error, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { retrierResult <- retrier.Do(ctx) }() - - if tc.cancel { - cancel() - } - - for _, err := range tc.errors { - doer.errC <- err - clock.Step(retrier.interval) - } - - assert.Equal(tc.wantErr, <-retrierResult) - }) - } -} - func TestServiceIsUnavailable(t *testing.T) { testCases := map[string]struct { err error @@ -112,22 +41,7 @@ func TestServiceIsUnavailable(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) - retrier := IntervalRetrier{} - assert.Equal(tc.wantUnavailable, retrier.serviceIsUnavailable(tc.err)) + assert.Equal(tc.wantUnavailable, ServiceIsUnavailable(tc.err)) }) } } - -type stubDoer struct { - errC chan error -} - -func newStubDoer() *stubDoer { - return &stubDoer{ - errC: make(chan error), - } -} - -func (d *stubDoer) Do(_ context.Context) error { - return <-d.errC -} diff --git a/internal/retry/retry.go b/internal/retry/retry.go new file mode 100644 index 000000000..1d2d044cd --- /dev/null +++ b/internal/retry/retry.go @@ -0,0 +1,61 @@ +package retry + +import ( + "context" + "time" + + "k8s.io/utils/clock" +) + +// IntervalRetrier retries a call with an interval. The call is defined in the Doer property. +type IntervalRetrier struct { + interval time.Duration + doer Doer + clock clock.WithTicker + retriable func(error) bool +} + +// NewIntervalRetrier returns a new IntervalRetrier. The optional clock is used for testing. 
+func NewIntervalRetrier(doer Doer, interval time.Duration, retriable func(error) bool, optClock ...clock.WithTicker) *IntervalRetrier { + var clock clock.WithTicker = clock.RealClock{} + if len(optClock) > 0 { + clock = optClock[0] + } + + return &IntervalRetrier{ + interval: interval, + doer: doer, + clock: clock, + retriable: retriable, + } +} + +// Do retries performing a call until it succeeds, returns a permanent error or the context is cancelled. +func (r *IntervalRetrier) Do(ctx context.Context) error { + ticker := r.clock.NewTicker(r.interval) + defer ticker.Stop() + + for { + err := r.doer.Do(ctx) + if err == nil { + return nil + } + + if !r.retriable(err) { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C(): + } + } +} + +type Doer interface { + // Do performs an operation. + // + // It should return an error that can be checked for retriability. + Do(ctx context.Context) error +} diff --git a/internal/retry/retry_test.go b/internal/retry/retry_test.go new file mode 100644 index 000000000..dd5a88477 --- /dev/null +++ b/internal/retry/retry_test.go @@ -0,0 +1,98 @@ +package retry + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + testclock "k8s.io/utils/clock/testing" +) + +func TestDo(t *testing.T) { + testCases := map[string]struct { + cancel bool + errors []error + wantErr error + }{ + "no error": { + errors: []error{ + nil, + }, + }, + "permanent error": { + errors: []error{ + errors.New("error"), + }, + wantErr: errors.New("error"), + }, + "service unavailable then success": { + errors: []error{ + errors.New("retry me"), + nil, + }, + }, + "service unavailable then permanent error": { + errors: []error{ + errors.New("retry me"), + errors.New("error"), + }, + wantErr: errors.New("error"), + }, + "cancellation works": { + cancel: true, + errors: []error{ + errors.New("retry me"), + }, + wantErr: context.Canceled, + }, + } + + for name, tc := range testCases { + 
t.Run(name, func(t *testing.T) { + assert := assert.New(t) + doer := newStubDoer() + clock := testclock.NewFakeClock(time.Now()) + retrier := IntervalRetrier{ + doer: doer, + clock: clock, + retriable: isRetriable, + } + retrierResult := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { retrierResult <- retrier.Do(ctx) }() + + if tc.cancel { + cancel() + } + + for _, err := range tc.errors { + doer.errC <- err + clock.Step(retrier.interval) + } + + assert.Equal(tc.wantErr, <-retrierResult) + }) + } +} + +type stubDoer struct { + errC chan error +} + +func newStubDoer() *stubDoer { + return &stubDoer{ + errC: make(chan error), + } +} + +func (d *stubDoer) Do(_ context.Context) error { + return <-d.errC +} + +func isRetriable(err error) bool { + return err.Error() == "retry me" +}