feat: use SSH host certificates (#3786)

This commit is contained in:
miampf 2025-07-01 12:47:04 +02:00 committed by GitHub
parent 95f17a6d06
commit 7ea5c41f9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 706 additions and 117 deletions

View file

@ -26,11 +26,13 @@ import (
"io"
"log/slog"
"net"
"os"
"strings"
"sync"
"time"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/addresses"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation"
@ -153,35 +155,23 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
s.kmsURI = req.KmsUri
if err := bcrypt.CompareHashAndPassword(s.initSecretHash, req.InitSecret); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "invalid init secret %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "invalid init secret %s", err)))
}
cloudKms, err := kmssetup.KMS(stream.Context(), req.StorageUri, req.KmsUri)
if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "creating kms client: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "creating kms client: %s", err)))
}
// generate values for cluster attestation
clusterID, err := deriveMeasurementValues(stream.Context(), req.MeasurementSalt, cloudKms)
if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "deriving measurement values: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "deriving measurement values: %s", err)))
}
nodeLockAcquired, err := s.nodeLock.TryLockOnce(clusterID)
if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "locking node: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "locking node: %s", err)))
}
if !nodeLockAcquired {
// The join client seems to already have a connection to an
@ -208,10 +198,7 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
}()
if err := s.setupDisk(stream.Context(), cloudKms); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)))
}
state := nodestate.NodeState{
@ -219,32 +206,67 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
MeasurementSalt: req.MeasurementSalt,
}
if err := state.ToFile(s.fileHandler); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "persisting node state: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "persisting node state: %s", err)))
}
// Derive the emergency ssh CA key
key, err := cloudKms.GetDEK(stream.Context(), crypto.DEKPrefix+constants.SSHCAKeySuffix, ed25519.SeedSize)
if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "retrieving DEK for key derivation: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "retrieving DEK for key derivation: %s", err)))
}
ca, err := crypto.GenerateEmergencySSHCAKey(key)
if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating emergency SSH CA key: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating emergency SSH CA key: %s", err)))
}
if err := s.fileHandler.Write(constants.SSHCAKeyPath, ssh.MarshalAuthorizedKey(ca.PublicKey()), file.OptMkdirAll); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh CA pubkey: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh CA pubkey: %s", err)))
}
interfaces, err := net.Interfaces()
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "getting network interfaces: %s", err)))
}
// Needed since go doesn't implicitly convert slices of structs to slices of interfaces
interfacesForFunc := make([]addresses.NetInterface, len(interfaces))
for i := range interfaces {
interfacesForFunc[i] = &interfaces[i]
}
principalList, err := addresses.GetMachineNetworkAddresses(interfacesForFunc)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to get network addresses: %s", err)))
}
hostname, err := os.Hostname()
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to get hostname: %s", err)))
}
principalList = append(principalList, hostname)
principalList = append(principalList, req.ApiserverCertSans...)
hostKeyContent, err := s.fileHandler.Read(constants.SSHHostKeyPath)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to read host SSH key: %s", err)))
}
hostPrivateKey, err := ssh.ParsePrivateKey(hostKeyContent)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to parse host SSH key: %s", err)))
}
hostKeyPubSSH := hostPrivateKey.PublicKey()
hostCertificate, err := crypto.GenerateSSHHostCertificate(principalList, hostKeyPubSSH, ca)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating SSH host certificate: %s", err)))
}
if err := s.fileHandler.Write(constants.SSHAdditionalPrincipalsPath, []byte(strings.Join(req.ApiserverCertSans, ",")), file.OptMkdirAll); err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing list of public ssh principals: %s", err)))
}
if err := s.fileHandler.Write(constants.SSHHostCertificatePath, ssh.MarshalAuthorizedKey(hostCertificate), file.OptMkdirAll); err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh host certificate: %s", err)))
}
clusterName := req.ClusterName
@ -261,10 +283,7 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
req.ServiceCidr,
)
if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "initializing cluster: %s", err)); e != nil {
err = errors.Join(err, e)
}
return err
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "initializing cluster: %s", err)))
}
log.Info("Init succeeded")