diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index a8b4982a5..d03d17287 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -142,8 +142,8 @@ func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto. endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.CoordinatorPort)), req: req, } - retryer := retry.NewIntervalRetryer(doer, 30*time.Second) - if err := retryer.Do(ctx); err != nil { + retrier := retry.NewIntervalRetrier(doer, 30*time.Second) + if err := retrier.Do(ctx); err != nil { return nil, err } return doer.resp, nil diff --git a/internal/grpc/retry/retry.go b/internal/grpc/retry/retry.go index 1d7e934d8..f7f04f4a9 100644 --- a/internal/grpc/retry/retry.go +++ b/internal/grpc/retry/retry.go @@ -10,21 +10,24 @@ import ( "k8s.io/utils/clock" ) -type IntervalRetryer struct { +// IntervalRetrier is retries a grpc call with an interval. +type IntervalRetrier struct { interval time.Duration doer Doer clock clock.WithTicker } -func NewIntervalRetryer(doer Doer, interval time.Duration) *IntervalRetryer { - return &IntervalRetryer{ +// NewIntervalRetrier returns a new IntervalRetrier. +func NewIntervalRetrier(doer Doer, interval time.Duration) *IntervalRetrier { + return &IntervalRetrier{ interval: interval, doer: doer, clock: clock.RealClock{}, } } -func (r *IntervalRetryer) Do(ctx context.Context) error { +// Do retries performing a grpc call until it succeeds, returns a permanent error or the context is cancelled. +func (r *IntervalRetrier) Do(ctx context.Context) error { ticker := r.clock.NewTicker(r.interval) defer ticker.Stop() @@ -39,14 +42,16 @@ func (r *IntervalRetryer) Do(ctx context.Context) error { } select { - case <-ctx.Done(): // TODO(katexochen): is this necessary? + case <-ctx.Done(): return ctx.Err() case <-ticker.C(): } } } -func (r *IntervalRetryer) serviceIsUnavailable(err error) bool { +// serviceIsUnavailable checks if the error is a grpc status with code Unavailable. +// In the special case of an authentication handshake failure, false is returned to prevent further retries. +func (r *IntervalRetrier) serviceIsUnavailable(err error) bool { statusErr, ok := status.FromError(err) if !ok { return false @@ -55,9 +60,12 @@ func (r *IntervalRetryer) serviceIsUnavailable(err error) bool { return false } // ideally we would check the error type directly, but grpc only provides a string - return strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`) + return !strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`) } type Doer interface { + // Do performs a grpc operation. + // + // It should return a grpc status with code Unavailable error to signal a transient fault. Do(ctx context.Context) error } diff --git a/internal/grpc/retry/retry_test.go b/internal/grpc/retry/retry_test.go new file mode 100644 index 000000000..344e11116 --- /dev/null +++ b/internal/grpc/retry/retry_test.go @@ -0,0 +1,125 @@ +package retry + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + testclock "k8s.io/utils/clock/testing" +) + +func TestDo(t *testing.T) { + testCases := map[string]struct { + cancel bool + errors []error + wantErr error + }{ + "no error": { + errors: []error{ + nil, + }, + }, + "permanent error": { + errors: []error{ + errors.New("error"), + }, + wantErr: errors.New("error"), + }, + "service unavailable then success": { + errors: []error{ + status.Error(codes.Unavailable, "error"), + nil, + }, + }, + "service unavailable then permanent error": { + errors: []error{ + status.Error(codes.Unavailable, "error"), + errors.New("error"), + }, + wantErr: errors.New("error"), + }, + "cancellation works": { + cancel: true, + errors: []error{ + status.Error(codes.Unavailable, "error"), + }, + wantErr: context.Canceled, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + doer := newStubDoer() + clock := testclock.NewFakeClock(time.Now()) + retrier := IntervalRetrier{ + doer: doer, + clock: clock, + } + retrierResult := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { retrierResult <- retrier.Do(ctx) }() + + if tc.cancel { + cancel() + } + + for _, err := range tc.errors { + doer.errC <- err + clock.Step(retrier.interval) + } + + assert.Equal(tc.wantErr, <-retrierResult) + }) + } +} + +func TestServiceIsUnavailable(t *testing.T) { + testCases := map[string]struct { + err error + wantUnavailable bool + }{ + "nil": {}, + "not status error": { + err: errors.New("error"), + }, + "not unavailable": { + err: status.Error(codes.Internal, "error"), + }, + "unavailable error with authentication handshake failure": { + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed`), + }, + "normal unavailable error": { + err: status.Error(codes.Unavailable, "error"), + wantUnavailable: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + retrier := IntervalRetrier{} + assert.Equal(tc.wantUnavailable, retrier.serviceIsUnavailable(tc.err)) + }) + } +} + +type stubDoer struct { + errC chan error +} + +func newStubDoer() *stubDoer { + return &stubDoer{ + errC: make(chan error), + } +} + +func (d *stubDoer) Do(_ context.Context) error { + return <-d.errC +}