Fix stateservice test and increase speed

This commit is contained in:
katexochen 2022-06-01 10:14:36 +02:00 committed by Paul Meyer
parent d43ee053ed
commit ed5f64dc0a
5 changed files with 210 additions and 119 deletions

View file

@ -7,7 +7,8 @@ Files and source code for mounting persistent state disks
Integration test is available in `state/test/integration_test.go`. Integration test is available in `state/test/integration_test.go`.
The integration test requires root privileges since it uses dm-crypt. The integration test requires root privileges since it uses dm-crypt.
Build and run the test: Build and run the test:
```bash ```bash
go test -c ./state/test/ go test -c -tags=integration ./state/test/
sudo ./test.test sudo ./test.test
``` ```

View file

@ -96,7 +96,7 @@ func main() {
log.Named("setupManager"), log.Named("setupManager"),
*csp, *csp,
afero.Afero{Fs: afero.NewOsFs()}, afero.Afero{Fs: afero.NewOsFs()},
keyservice.New(log.Named("keyService"), issuer, metadata, 20*time.Second), // try to request a key every 20 seconds keyservice.New(log.Named("keyService"), issuer, metadata, 20*time.Second, 20*time.Second), // try to request a key every 20 seconds
mapper, mapper,
setup.DiskMounter{}, setup.DiskMounter{},
vtpm.OpenVTPM, vtpm.OpenVTPM,

View file

@ -3,6 +3,7 @@ package keyservice
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
@ -19,6 +20,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"k8s.io/utils/clock"
) )
// KeyAPI is the interface called by control-plane or an admin during restart of a node. // KeyAPI is the interface called by control-plane or an admin during restart of a node.
@ -31,18 +33,24 @@ type KeyAPI struct {
key []byte key []byte
measurementSecret []byte measurementSecret []byte
keyReceived chan struct{} keyReceived chan struct{}
clock clock.WithTicker
timeout time.Duration timeout time.Duration
interval time.Duration
keyproto.UnimplementedAPIServer keyproto.UnimplementedAPIServer
} }
// New initializes a KeyAPI with the given parameters. // New initializes a KeyAPI with the given parameters.
func New(log *logger.Logger, issuer QuoteIssuer, metadata metadata.InstanceLister, timeout time.Duration) *KeyAPI { func New(log *logger.Logger, issuer QuoteIssuer, metadata metadata.InstanceLister, timeout time.Duration, interval time.Duration) *KeyAPI {
return &KeyAPI{ return &KeyAPI{
log: log, log: log,
metadata: metadata, metadata: metadata,
issuer: issuer, issuer: issuer,
keyReceived: make(chan struct{}, 1), keyReceived: make(chan struct{}, 1),
clock: clock.RealClock{},
timeout: timeout, timeout: timeout,
interval: interval,
} }
} }
@ -99,67 +107,80 @@ func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) {
// we do not perform attestation, since the restarting node does not need to care about notifying the correct node // we do not perform attestation, since the restarting node does not need to care about notifying the correct node
// if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start // if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start
creds := atlscredentials.New(a.issuer, nil) creds := atlscredentials.New(a.issuer, nil)
// set up for the select statement to immediately request a key, skipping the initial delay caused by using a ticker
firstReq := make(chan struct{}, 1)
firstReq <- struct{}{}
ticker := time.NewTicker(a.timeout) ticker := a.clock.NewTicker(a.interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
endpoints, err := a.getJoinServiceEndpoints()
if err != nil {
a.log.With(zap.Error(err)).Errorf("Failed to get JoinService endpoints")
} else {
a.log.Infof("Received list with JoinService endpoints: %v", endpoints)
for _, endpoint := range endpoints {
a.requestKey(endpoint, uuid, creds, opts...)
}
}
select { select {
case <-a.keyReceived:
// return if a key was received // return if a key was received
// a key can be send by // a key can be send by
// - a control-plane node, after the request rpc was received // - a control-plane node, after the request rpc was received
// - by a Constellation admin, at any time this loop is running on a node during boot // - by a Constellation admin, at any time this loop is running on a node during boot
case <-a.keyReceived:
return return
case <-ticker.C: case <-ticker.C():
a.requestKey(uuid, creds, opts...)
case <-firstReq:
a.requestKey(uuid, creds, opts...)
} }
} }
} }
func (a *KeyAPI) requestKey(uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) { func (a *KeyAPI) getJoinServiceEndpoints() ([]string, error) {
// list available control-plane nodes
endpoints, _ := metadata.JoinServiceEndpoints(context.Background(), a.metadata)
a.log.With(zap.Strings("endpoints", endpoints)).Infof("Sending a key request to available control-plane nodes")
// notify all available control-plane nodes to send a key to the node
// any errors encountered here will be ignored, and the calls retried after a timeout
for _, endpoint := range endpoints {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout) ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel() defer cancel()
return metadata.JoinServiceEndpoints(ctx, a.metadata)
}
// request rejoin ticket from JoinService func (a *KeyAPI) requestKey(endpoint, uuid string, credentials credentials.TransportCredentials, opts ...grpc.DialOption) {
conn, err := grpc.DialContext(ctx, endpoint, append(opts, grpc.WithTransportCredentials(credentials))...) opts = append(opts, grpc.WithTransportCredentials(credentials))
a.log.With(zap.String("endpoint", endpoint)).Infof("Requesting rejoin ticket")
rejoinTicket, err := a.requestRejoinTicket(endpoint, uuid, opts...)
if err != nil { if err != nil {
continue a.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Errorf("Failed to request rejoin ticket")
return
}
a.log.With(zap.String("endpoint", endpoint)).Infof("Pushing key to own server")
if err := a.pushKeyToOwnServer(rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, opts...); err != nil {
a.log.With(zap.Error(err), zap.String("endpoint", a.listenAddr)).Errorf("Failed to push key to own server")
return
}
}
func (a *KeyAPI) requestRejoinTicket(endpoint, uuid string, opts ...grpc.DialOption) (*joinproto.IssueRejoinTicketResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
conn, err := grpc.DialContext(ctx, endpoint, opts...)
if err != nil {
return nil, fmt.Errorf("dialing gRPC: %w", err)
} }
defer conn.Close() defer conn.Close()
client := joinproto.NewAPIClient(conn) client := joinproto.NewAPIClient(conn)
response, err := client.IssueRejoinTicket(ctx, &joinproto.IssueRejoinTicketRequest{DiskUuid: uuid}) req := &joinproto.IssueRejoinTicketRequest{DiskUuid: uuid}
if err != nil { return client.IssueRejoinTicket(ctx, req)
a.log.With(zap.Error(err), zap.String("endpoint", endpoint)).Warnf("Failed to request key") }
continue
}
// push key to own gRPC server func (a *KeyAPI) pushKeyToOwnServer(stateDiskKey, measurementSecret []byte, opts ...grpc.DialOption) error {
pushKeyConn, err := grpc.DialContext(ctx, a.listenAddr, append(opts, grpc.WithTransportCredentials(credentials))...) ctx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()
conn, err := grpc.DialContext(ctx, a.listenAddr, opts...)
if err != nil { if err != nil {
continue return fmt.Errorf("dialing gRPC: %w", err)
}
defer pushKeyConn.Close()
pushKeyClient := keyproto.NewAPIClient(pushKeyConn)
if _, err := pushKeyClient.PushStateDiskKey(
ctx,
&keyproto.PushStateDiskKeyRequest{StateDiskKey: response.StateDiskKey, MeasurementSecret: response.MeasurementSecret},
); err != nil {
a.log.With(zap.Error(err), zap.String("endpoint", a.listenAddr)).Errorf("Failed to push key")
continue
}
} }
defer conn.Close()
client := keyproto.NewAPIClient(conn)
req := &keyproto.PushStateDiskKeyRequest{StateDiskKey: stateDiskKey, MeasurementSecret: measurementSecret}
_, err = client.PushStateDiskKey(ctx, req)
return err
} }
// QuoteValidator validates quotes. // QuoteValidator validates quotes.

View file

@ -13,23 +13,22 @@ import (
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/logger" "github.com/edgelesssys/constellation/internal/logger"
"github.com/edgelesssys/constellation/internal/oid" "github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/kms/kmsproto" "github.com/edgelesssys/constellation/joinservice/joinproto"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, goleak.VerifyTestMain(m)
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
} }
func TestRequestKeyLoop(t *testing.T) { func TestRequestKeyLoop(t *testing.T) {
clockstep := struct{}{}
someErr := errors.New("failed")
defaultInstance := metadata.InstanceMetadata{ defaultInstance := metadata.InstanceMetadata{
Name: "test-instance", Name: "test-instance",
ProviderID: "/test/provider", ProviderID: "/test/provider",
@ -38,81 +37,102 @@ func TestRequestKeyLoop(t *testing.T) {
} }
testCases := map[string]struct { testCases := map[string]struct {
server *stubAPIServer answers []any
wantCalls int
listResponse []metadata.InstanceMetadata
dontStartServer bool
}{ }{
"success": { "success": {
server: &stubAPIServer{requestStateDiskKeyResp: &kmsproto.GetDataKeyResponse{}}, answers: []any{
listResponse: []metadata.InstanceMetadata{defaultInstance}, listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
}, },
"no error if server throws an error": {
server: &stubAPIServer{
requestStateDiskKeyResp: &kmsproto.GetDataKeyResponse{},
requestStateDiskKeyErr: errors.New("error"),
}, },
listResponse: []metadata.InstanceMetadata{defaultInstance}, "recover metadata list error": {
answers: []any{
listAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
}, },
"no error if the server can not be reached": {
server: &stubAPIServer{requestStateDiskKeyResp: &kmsproto.GetDataKeyResponse{}},
listResponse: []metadata.InstanceMetadata{defaultInstance},
dontStartServer: true,
}, },
"no error if no endpoint is available": { "recover issue rejoin ticket error": {
server: &stubAPIServer{requestStateDiskKeyResp: &kmsproto.GetDataKeyResponse{}}, answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
}, },
"works for multiple endpoints": {
server: &stubAPIServer{requestStateDiskKeyResp: &kmsproto.GetDataKeyResponse{}},
listResponse: []metadata.InstanceMetadata{
defaultInstance,
{
Name: "test-instance-2",
ProviderID: "/test/provider",
Role: role.ControlPlane,
PrivateIPs: []string{"192.0.2.2"},
}, },
"recover push key error": {
answers: []any{
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{err: someErr},
clockstep,
listAnswer{listResponse: []metadata.InstanceMetadata{defaultInstance}},
issueRejoinTicketAnswer{stateDiskKey: []byte{0x1}, measurementSecret: []byte{0x2}},
pushStateDiskKeyAnswer{},
}, },
}, },
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
require := require.New(t) metadataServer := newStubMetadataServer()
joinServer := newStubJoinAPIServer()
keyServer := newStubKeyAPIServer()
keyReceived := make(chan struct{}, 1) listener := bufconn.Listen(1024)
listener := bufconn.Listen(1)
defer listener.Close() defer listener.Close()
creds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil) creds := atlscredentials.New(atls.NewFakeIssuer(oid.Dummy{}), nil)
s := grpc.NewServer(grpc.Creds(creds)) grpcServer := grpc.NewServer(grpc.Creds(creds))
kmsproto.RegisterAPIServer(s, tc.server) joinproto.RegisterAPIServer(grpcServer, joinServer)
keyproto.RegisterAPIServer(grpcServer, keyServer)
if !tc.dontStartServer { go grpcServer.Serve(listener)
go func() { require.NoError(s.Serve(listener)) }() defer grpcServer.GracefulStop()
}
clock := testclock.NewFakeClock(time.Now())
keyReceived := make(chan struct{}, 1)
keyWaiter := &KeyAPI{ keyWaiter := &KeyAPI{
listenAddr: "192.0.2.1:30090",
log: logger.NewTest(t), log: logger.NewTest(t),
metadata: stubMetadata{listResponse: tc.listResponse}, metadata: metadataServer,
keyReceived: keyReceived, keyReceived: keyReceived,
timeout: 500 * time.Millisecond, clock: clock,
timeout: 1 * time.Second,
interval: 1 * time.Second,
} }
grpcOpts := []grpc.DialOption{
// notify the API a key was received after 1 second
go func() {
time.Sleep(1 * time.Second)
keyReceived <- struct{}{}
}()
keyWaiter.requestKeyLoop(
"1234",
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.DialContext(ctx) return listener.DialContext(ctx)
}), }),
) }
s.Stop() // Start the request loop under tests.
done := make(chan struct{})
go func() {
defer close(done)
keyWaiter.requestKeyLoop("1234", grpcOpts...)
}()
// Play test case answers.
for _, answ := range tc.answers {
switch answ := answ.(type) {
case listAnswer:
metadataServer.listAnswerC <- answ
case issueRejoinTicketAnswer:
joinServer.issueRejoinTicketAnswerC <- answ
case pushStateDiskKeyAnswer:
keyServer.pushStateDiskKeyAnswerC <- answ
default:
clock.Step(time.Second)
}
}
// Stop the request loop.
keyReceived <- struct{}{}
}) })
} }
} }
@ -164,27 +184,75 @@ func TestPushStateDiskKey(t *testing.T) {
} }
func TestResetKey(t *testing.T) { func TestResetKey(t *testing.T) {
api := New(logger.NewTest(t), nil, nil, time.Second) api := New(logger.NewTest(t), nil, nil, time.Second, time.Millisecond)
api.key = []byte{0x1, 0x2, 0x3} api.key = []byte{0x1, 0x2, 0x3}
api.ResetKey() api.ResetKey()
assert.Nil(t, api.key) assert.Nil(t, api.key)
} }
type stubAPIServer struct { type stubMetadataServer struct {
requestStateDiskKeyResp *kmsproto.GetDataKeyResponse listAnswerC chan listAnswer
requestStateDiskKeyErr error
kmsproto.UnimplementedAPIServer
} }
func (s *stubAPIServer) GetDataKey(ctx context.Context, req *kmsproto.GetDataKeyRequest) (*kmsproto.GetDataKeyResponse, error) { func newStubMetadataServer() *stubMetadataServer {
return s.requestStateDiskKeyResp, s.requestStateDiskKeyErr return &stubMetadataServer{
listAnswerC: make(chan listAnswer),
}
} }
type stubMetadata struct { func (s *stubMetadataServer) List(context.Context) ([]metadata.InstanceMetadata, error) {
answer := <-s.listAnswerC
return answer.listResponse, answer.err
}
type listAnswer struct {
listResponse []metadata.InstanceMetadata listResponse []metadata.InstanceMetadata
err error
} }
func (s stubMetadata) List(ctx context.Context) ([]metadata.InstanceMetadata, error) { type stubJoinAPIServer struct {
return s.listResponse, nil issueRejoinTicketAnswerC chan issueRejoinTicketAnswer
joinproto.UnimplementedAPIServer
}
func newStubJoinAPIServer() *stubJoinAPIServer {
return &stubJoinAPIServer{
issueRejoinTicketAnswerC: make(chan issueRejoinTicketAnswer),
}
}
func (s *stubJoinAPIServer) IssueRejoinTicket(context.Context, *joinproto.IssueRejoinTicketRequest) (*joinproto.IssueRejoinTicketResponse, error) {
answer := <-s.issueRejoinTicketAnswerC
resp := &joinproto.IssueRejoinTicketResponse{
StateDiskKey: answer.stateDiskKey,
MeasurementSecret: answer.measurementSecret,
}
return resp, answer.err
}
type issueRejoinTicketAnswer struct {
stateDiskKey []byte
measurementSecret []byte
err error
}
type stubKeyAPIServer struct {
pushStateDiskKeyAnswerC chan pushStateDiskKeyAnswer
keyproto.UnimplementedAPIServer
}
func newStubKeyAPIServer() *stubKeyAPIServer {
return &stubKeyAPIServer{
pushStateDiskKeyAnswerC: make(chan pushStateDiskKeyAnswer),
}
}
func (s *stubKeyAPIServer) PushStateDiskKey(context.Context, *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
answer := <-s.pushStateDiskKeyAnswerC
return &keyproto.PushStateDiskKeyResponse{}, answer.err
}
type pushStateDiskKeyAnswer struct {
err error
} }

View file

@ -99,6 +99,7 @@ func TestKeyAPI(t *testing.T) {
atls.NewFakeIssuer(oid.Dummy{}), atls.NewFakeIssuer(oid.Dummy{}),
&fakeMetadataAPI{}, &fakeMetadataAPI{},
20*time.Second, 20*time.Second,
time.Second,
) )
// send a key to the server // send a key to the server