mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-07-30 18:48:39 -04:00
feat: use SSH host certificates (#3786)
This commit is contained in:
parent
95f17a6d06
commit
7ea5c41f9b
34 changed files with 706 additions and 117 deletions
|
@ -26,11 +26,13 @@ import (
|
|||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
|
||||
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/addresses"
|
||||
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald"
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation"
|
||||
|
@ -153,35 +155,23 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
|
|||
s.kmsURI = req.KmsUri
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword(s.initSecretHash, req.InitSecret); err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "invalid init secret %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "invalid init secret %s", err)))
|
||||
}
|
||||
|
||||
cloudKms, err := kmssetup.KMS(stream.Context(), req.StorageUri, req.KmsUri)
|
||||
if err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "creating kms client: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "creating kms client: %s", err)))
|
||||
}
|
||||
|
||||
// generate values for cluster attestation
|
||||
clusterID, err := deriveMeasurementValues(stream.Context(), req.MeasurementSalt, cloudKms)
|
||||
if err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "deriving measurement values: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "deriving measurement values: %s", err)))
|
||||
}
|
||||
|
||||
nodeLockAcquired, err := s.nodeLock.TryLockOnce(clusterID)
|
||||
if err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "locking node: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "locking node: %s", err)))
|
||||
}
|
||||
if !nodeLockAcquired {
|
||||
// The join client seems to already have a connection to an
|
||||
|
@ -208,10 +198,7 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
|
|||
}()
|
||||
|
||||
if err := s.setupDisk(stream.Context(), cloudKms); err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)))
|
||||
}
|
||||
|
||||
state := nodestate.NodeState{
|
||||
|
@ -219,32 +206,67 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
|
|||
MeasurementSalt: req.MeasurementSalt,
|
||||
}
|
||||
if err := state.ToFile(s.fileHandler); err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "persisting node state: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "persisting node state: %s", err)))
|
||||
}
|
||||
|
||||
// Derive the emergency ssh CA key
|
||||
key, err := cloudKms.GetDEK(stream.Context(), crypto.DEKPrefix+constants.SSHCAKeySuffix, ed25519.SeedSize)
|
||||
if err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "retrieving DEK for key derivation: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "retrieving DEK for key derivation: %s", err)))
|
||||
}
|
||||
ca, err := crypto.GenerateEmergencySSHCAKey(key)
|
||||
if err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating emergency SSH CA key: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating emergency SSH CA key: %s", err)))
|
||||
}
|
||||
if err := s.fileHandler.Write(constants.SSHCAKeyPath, ssh.MarshalAuthorizedKey(ca.PublicKey()), file.OptMkdirAll); err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh CA pubkey: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh CA pubkey: %s", err)))
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "getting network interfaces: %s", err)))
|
||||
}
|
||||
// Needed since go doesn't implicitly convert slices of structs to slices of interfaces
|
||||
interfacesForFunc := make([]addresses.NetInterface, len(interfaces))
|
||||
for i := range interfaces {
|
||||
interfacesForFunc[i] = &interfaces[i]
|
||||
}
|
||||
|
||||
principalList, err := addresses.GetMachineNetworkAddresses(interfacesForFunc)
|
||||
if err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to get network addresses: %s", err)))
|
||||
}
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to get hostname: %s", err)))
|
||||
}
|
||||
|
||||
principalList = append(principalList, hostname)
|
||||
principalList = append(principalList, req.ApiserverCertSans...)
|
||||
|
||||
hostKeyContent, err := s.fileHandler.Read(constants.SSHHostKeyPath)
|
||||
if err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to read host SSH key: %s", err)))
|
||||
}
|
||||
|
||||
hostPrivateKey, err := ssh.ParsePrivateKey(hostKeyContent)
|
||||
if err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to parse host SSH key: %s", err)))
|
||||
}
|
||||
|
||||
hostKeyPubSSH := hostPrivateKey.PublicKey()
|
||||
|
||||
hostCertificate, err := crypto.GenerateSSHHostCertificate(principalList, hostKeyPubSSH, ca)
|
||||
if err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating SSH host certificate: %s", err)))
|
||||
}
|
||||
|
||||
if err := s.fileHandler.Write(constants.SSHAdditionalPrincipalsPath, []byte(strings.Join(req.ApiserverCertSans, ",")), file.OptMkdirAll); err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing list of public ssh principals: %s", err)))
|
||||
}
|
||||
|
||||
if err := s.fileHandler.Write(constants.SSHHostCertificatePath, ssh.MarshalAuthorizedKey(hostCertificate), file.OptMkdirAll); err != nil {
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh host certificate: %s", err)))
|
||||
}
|
||||
|
||||
clusterName := req.ClusterName
|
||||
|
@ -261,10 +283,7 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
|
|||
req.ServiceCidr,
|
||||
)
|
||||
if err != nil {
|
||||
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "initializing cluster: %s", err)); e != nil {
|
||||
err = errors.Join(err, e)
|
||||
}
|
||||
return err
|
||||
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "initializing cluster: %s", err)))
|
||||
}
|
||||
|
||||
log.Info("Init succeeded")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue