mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-12-10 13:40:57 -05:00
AB#2260 Refactor disk-mapper recovery (#82)
* Refactor disk-mapper recovery * Adapt constellation recover command to use new disk-mapper recovery API * Fix Cilium connectivity on rebooting nodes (#89) * Lower CoreDNS reschedule timeout to 10 seconds (#93) Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
a7b20b2a11
commit
8cb155d5c5
40 changed files with 1600 additions and 1130 deletions
134
disk-mapper/internal/recoveryserver/server.go
Normal file
134
disk-mapper/internal/recoveryserver/server.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package recoveryserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
|
||||
type RecoveryServer struct {
|
||||
mux sync.Mutex
|
||||
|
||||
diskUUID string
|
||||
stateDiskKey []byte
|
||||
measurementSecret []byte
|
||||
grpcServer server
|
||||
|
||||
log *logger.Logger
|
||||
|
||||
recoverproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
// New returns a new RecoveryServer.
|
||||
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
|
||||
server := &RecoveryServer{
|
||||
log: log,
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
grpc.Creds(atlscredentials.New(issuer, nil)),
|
||||
log.Named("gRPC").GetServerStreamInterceptor(),
|
||||
)
|
||||
recoverproto.RegisterAPIServer(grpcServer, server)
|
||||
|
||||
server.grpcServer = grpcServer
|
||||
return server
|
||||
}
|
||||
|
||||
// Serve starts the recovery server.
|
||||
// It blocks until a recover request call is successful.
|
||||
// The server will shut down when the call is successful and the keys are returned.
|
||||
// Additionally, the server can be shutdown by canceling the context.
|
||||
func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskUUID string) (diskKey, measurementSecret []byte, err error) {
|
||||
s.log.Infof("Starting RecoveryServer")
|
||||
s.diskUUID = diskUUID
|
||||
recoveryDone := make(chan struct{}, 1)
|
||||
var serveErr error
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
defer wg.Wait()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
serveErr = s.grpcServer.Serve(listener)
|
||||
recoveryDone <- struct{}{}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Infof("Context canceled, shutting down server")
|
||||
s.grpcServer.GracefulStop()
|
||||
return nil, nil, ctx.Err()
|
||||
case <-recoveryDone:
|
||||
if serveErr != nil {
|
||||
return nil, nil, serveErr
|
||||
}
|
||||
return s.stateDiskKey, s.measurementSecret, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
|
||||
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.log.Infof("Received recover call")
|
||||
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
}
|
||||
|
||||
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
|
||||
if !ok {
|
||||
s.log.Errorf("Received invalid first message: not a measurement secret")
|
||||
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
|
||||
}
|
||||
|
||||
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
|
||||
return status.Error(codes.Internal, "failed to send response")
|
||||
}
|
||||
|
||||
msg, err = stream.Recv()
|
||||
if err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Failed to receive disk key")
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
}
|
||||
|
||||
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
|
||||
if !ok {
|
||||
s.log.Errorf("Received invalid second message: not a state disk key")
|
||||
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
|
||||
}
|
||||
|
||||
s.stateDiskKey = stateDiskKey.StateDiskKey
|
||||
s.measurementSecret = measurementSecret.MeasurementSecret
|
||||
s.log.Infof("Received state disk key and measurement secret, shutting down server")
|
||||
|
||||
go s.grpcServer.GracefulStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
type server interface {
|
||||
Serve(net.Listener) error
|
||||
GracefulStop()
|
||||
}
|
||||
194
disk-mapper/internal/recoveryserver/server_test.go
Normal file
194
disk-mapper/internal/recoveryserver/server_test.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package recoveryserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/internal/atls"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
log := logger.NewTest(t)
|
||||
uuid := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
listener := dialer.GetListener("192.0.2.1:1234")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Serve method returns when context is canceled
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _, err := server.Serve(ctx, listener, uuid)
|
||||
assert.ErrorIs(err, context.Canceled)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
wg.Wait()
|
||||
|
||||
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
dialer = testdialer.NewBufconnDialer()
|
||||
listener = dialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
// Serve method returns without error when the server is shut down
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _, err := server.Serve(context.Background(), listener, uuid)
|
||||
assert.NoError(err)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
server.grpcServer.GracefulStop()
|
||||
wg.Wait()
|
||||
|
||||
// Serve method returns an error when serving is unsuccessful
|
||||
_, _, err := server.Serve(context.Background(), listener, uuid)
|
||||
assert.Error(err)
|
||||
}
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
initialMsg message
|
||||
keyMsg message
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"first message is not a measurement secret": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"second message is not a state disk key": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
serverUUID := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
listener := netDialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
var diskKey, measurementSecret []byte
|
||||
var serveErr error
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
serveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
diskKey, measurementSecret, serveErr = server.Serve(serveCtx, listener, serverUUID)
|
||||
}()
|
||||
|
||||
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
|
||||
require.NoError(err)
|
||||
|
||||
// Send initial message
|
||||
err = client.Send(tc.initialMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
|
||||
// Receive uuid
|
||||
uuid, err := client.Recv()
|
||||
if tc.initialMsg.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(serverUUID, uuid.DiskUuid)
|
||||
|
||||
// Send key message
|
||||
err = client.Send(tc.keyMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
|
||||
_, err = client.Recv()
|
||||
if tc.keyMsg.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.ErrorIs(io.EOF, err)
|
||||
|
||||
wg.Wait()
|
||||
assert.NoError(serveErr)
|
||||
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
|
||||
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type message struct {
|
||||
recoverMsg *recoverproto.RecoverMessage
|
||||
wantErr bool
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue