feat: use SSH host certificates (#3786)

This commit is contained in:
miampf 2025-07-01 12:47:04 +02:00 committed by GitHub
parent 95f17a6d06
commit 7ea5c41f9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 706 additions and 117 deletions

View file

@ -23,19 +23,21 @@ runs:
lb="$(terraform output -raw loadbalancer_address)" lb="$(terraform output -raw loadbalancer_address)"
popd popd
lb_ip="$(gethostip $lb | awk '{print $2}')"
echo "Resolved ip of load balancer: $lb_ip"
# write ssh config # write ssh config
cat > ssh_config <<EOF cat > ssh_config <<EOF
Host $lb Host $lb_ip
ProxyJump none ProxyJump none
Host * Host *
StrictHostKeyChecking no
UserKnownHostsFile=/dev/null
IdentityFile ./access-key IdentityFile ./access-key
PreferredAuthentications publickey PreferredAuthentications publickey
CertificateFile=constellation_cert.pub CertificateFile=constellation_cert.pub
UserKnownHostsFile=./known_hosts
User root User root
ProxyJump $lb ProxyJump $lb_ip
EOF EOF
for i in {1..26}; do for i in {1..26}; do

View file

@ -150,7 +150,9 @@ runs:
- name: Setup bazel - name: Setup bazel
uses: ./.github/actions/setup_bazel_nix uses: ./.github/actions/setup_bazel_nix
with: with:
nixTools: terraform nixTools: |
terraform
syslinux
- name: Log in to the Container registry - name: Log in to the Container registry
uses: ./.github/actions/container_registry_login uses: ./.github/actions/container_registry_login

View file

@ -0,0 +1,26 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")
go_library(
name = "interfaces",
srcs = ["interfaces.go"],
importpath = "github.com/edgelesssys/constellation/v2/bootstrapper/internal/interfaces",
visibility = ["//bootstrapper:__subpackages__"],
)
go_library(
name = "addresses",
srcs = ["addresses.go"],
importpath = "github.com/edgelesssys/constellation/v2/bootstrapper/internal/addresses",
visibility = ["//bootstrapper:__subpackages__"],
)
go_test(
name = "addresses_test",
srcs = ["addresses_test.go"],
deps = [
":addresses",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)

View file

@ -0,0 +1,45 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package addresses
import (
"net"
)
// GetMachineNetworkAddresses retrieves all network interface addresses.
func GetMachineNetworkAddresses(interfaces []NetInterface) ([]string, error) {
var addresses []string
for _, i := range interfaces {
addrs, err := i.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip.IsLoopback() {
continue
}
addresses = append(addresses, ip.String())
}
}
return addresses, nil
}
// NetInterface represents a network interface used to get network addresses.
type NetInterface interface {
Addrs() ([]net.Addr, error)
}

View file

@ -0,0 +1,67 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package addresses_test
import (
"errors"
"net"
"testing"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/addresses"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetMachineNetworkAddresses(t *testing.T) {
_, someAddr, err := net.ParseCIDR("10.9.0.1/24")
require.NoError(t, err)
testCases := map[string]struct {
interfaces []addresses.NetInterface
wantErr bool
}{
"successful": {
interfaces: []addresses.NetInterface{
&mockNetInterface{
addrs: []net.Addr{
someAddr,
},
},
},
},
"unsuccessful": {
interfaces: []addresses.NetInterface{
&mockNetInterface{addrs: nil, err: errors.New("someError")},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
addrs, err := addresses.GetMachineNetworkAddresses(tc.interfaces)
if tc.wantErr {
assert.Error(err)
} else {
assert.Equal([]string{"10.9.0.0"}, addrs)
assert.NoError(err)
}
})
}
}
type mockNetInterface struct {
addrs []net.Addr
err error
}
func (m *mockNetInterface) Addrs() ([]net.Addr, error) {
return m.addrs, m.err
}

View file

@ -8,6 +8,7 @@ go_library(
visibility = ["//bootstrapper:__subpackages__"], visibility = ["//bootstrapper:__subpackages__"],
deps = [ deps = [
"//bootstrapper/initproto", "//bootstrapper/initproto",
"//bootstrapper/internal/addresses",
"//bootstrapper/internal/journald", "//bootstrapper/internal/journald",
"//internal/atls", "//internal/atls",
"//internal/attestation", "//internal/attestation",
@ -43,6 +44,7 @@ go_test(
"//bootstrapper/initproto", "//bootstrapper/initproto",
"//internal/atls", "//internal/atls",
"//internal/attestation/variant", "//internal/attestation/variant",
"//internal/constants",
"//internal/crypto/testvector", "//internal/crypto/testvector",
"//internal/file", "//internal/file",
"//internal/kms/setup", "//internal/kms/setup",
@ -54,6 +56,7 @@ go_test(
"@com_github_stretchr_testify//require", "@com_github_stretchr_testify//require",
"@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//:grpc",
"@org_golang_x_crypto//bcrypt", "@org_golang_x_crypto//bcrypt",
"@org_golang_x_crypto//ssh",
"@org_uber_go_goleak//:goleak", "@org_uber_go_goleak//:goleak",
], ],
) )

View file

@ -26,11 +26,13 @@ import (
"io" "io"
"log/slog" "log/slog"
"net" "net"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/addresses"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/attestation"
@ -153,35 +155,23 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
s.kmsURI = req.KmsUri s.kmsURI = req.KmsUri
if err := bcrypt.CompareHashAndPassword(s.initSecretHash, req.InitSecret); err != nil { if err := bcrypt.CompareHashAndPassword(s.initSecretHash, req.InitSecret); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "invalid init secret %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "invalid init secret %s", err)))
err = errors.Join(err, e)
}
return err
} }
cloudKms, err := kmssetup.KMS(stream.Context(), req.StorageUri, req.KmsUri) cloudKms, err := kmssetup.KMS(stream.Context(), req.StorageUri, req.KmsUri)
if err != nil { if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "creating kms client: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "creating kms client: %s", err)))
err = errors.Join(err, e)
}
return err
} }
// generate values for cluster attestation // generate values for cluster attestation
clusterID, err := deriveMeasurementValues(stream.Context(), req.MeasurementSalt, cloudKms) clusterID, err := deriveMeasurementValues(stream.Context(), req.MeasurementSalt, cloudKms)
if err != nil { if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "deriving measurement values: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "deriving measurement values: %s", err)))
err = errors.Join(err, e)
}
return err
} }
nodeLockAcquired, err := s.nodeLock.TryLockOnce(clusterID) nodeLockAcquired, err := s.nodeLock.TryLockOnce(clusterID)
if err != nil { if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "locking node: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "locking node: %s", err)))
err = errors.Join(err, e)
}
return err
} }
if !nodeLockAcquired { if !nodeLockAcquired {
// The join client seems to already have a connection to an // The join client seems to already have a connection to an
@ -208,10 +198,7 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
}() }()
if err := s.setupDisk(stream.Context(), cloudKms); err != nil { if err := s.setupDisk(stream.Context(), cloudKms); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)))
err = errors.Join(err, e)
}
return err
} }
state := nodestate.NodeState{ state := nodestate.NodeState{
@ -219,32 +206,67 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
MeasurementSalt: req.MeasurementSalt, MeasurementSalt: req.MeasurementSalt,
} }
if err := state.ToFile(s.fileHandler); err != nil { if err := state.ToFile(s.fileHandler); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "persisting node state: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "persisting node state: %s", err)))
err = errors.Join(err, e)
}
return err
} }
// Derive the emergency ssh CA key // Derive the emergency ssh CA key
key, err := cloudKms.GetDEK(stream.Context(), crypto.DEKPrefix+constants.SSHCAKeySuffix, ed25519.SeedSize) key, err := cloudKms.GetDEK(stream.Context(), crypto.DEKPrefix+constants.SSHCAKeySuffix, ed25519.SeedSize)
if err != nil { if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "retrieving DEK for key derivation: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "retrieving DEK for key derivation: %s", err)))
err = errors.Join(err, e)
}
return err
} }
ca, err := crypto.GenerateEmergencySSHCAKey(key) ca, err := crypto.GenerateEmergencySSHCAKey(key)
if err != nil { if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating emergency SSH CA key: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating emergency SSH CA key: %s", err)))
err = errors.Join(err, e)
}
return err
} }
if err := s.fileHandler.Write(constants.SSHCAKeyPath, ssh.MarshalAuthorizedKey(ca.PublicKey()), file.OptMkdirAll); err != nil { if err := s.fileHandler.Write(constants.SSHCAKeyPath, ssh.MarshalAuthorizedKey(ca.PublicKey()), file.OptMkdirAll); err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh CA pubkey: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh CA pubkey: %s", err)))
err = errors.Join(err, e)
} }
return err
interfaces, err := net.Interfaces()
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "getting network interfaces: %s", err)))
}
// Needed since go doesn't implicitly convert slices of structs to slices of interfaces
interfacesForFunc := make([]addresses.NetInterface, len(interfaces))
for i := range interfaces {
interfacesForFunc[i] = &interfaces[i]
}
principalList, err := addresses.GetMachineNetworkAddresses(interfacesForFunc)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to get network addresses: %s", err)))
}
hostname, err := os.Hostname()
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to get hostname: %s", err)))
}
principalList = append(principalList, hostname)
principalList = append(principalList, req.ApiserverCertSans...)
hostKeyContent, err := s.fileHandler.Read(constants.SSHHostKeyPath)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to read host SSH key: %s", err)))
}
hostPrivateKey, err := ssh.ParsePrivateKey(hostKeyContent)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "failed to parse host SSH key: %s", err)))
}
hostKeyPubSSH := hostPrivateKey.PublicKey()
hostCertificate, err := crypto.GenerateSSHHostCertificate(principalList, hostKeyPubSSH, ca)
if err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "generating SSH host certificate: %s", err)))
}
if err := s.fileHandler.Write(constants.SSHAdditionalPrincipalsPath, []byte(strings.Join(req.ApiserverCertSans, ",")), file.OptMkdirAll); err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing list of public ssh principals: %s", err)))
}
if err := s.fileHandler.Write(constants.SSHHostCertificatePath, ssh.MarshalAuthorizedKey(hostCertificate), file.OptMkdirAll); err != nil {
return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "writing ssh host certificate: %s", err)))
} }
clusterName := req.ClusterName clusterName := req.ClusterName
@ -261,10 +283,7 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
req.ServiceCidr, req.ServiceCidr,
) )
if err != nil { if err != nil {
if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "initializing cluster: %s", err)); e != nil { return errors.Join(err, s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "initializing cluster: %s", err)))
err = errors.Join(err, e)
}
return err
} }
log.Info("Init succeeded") log.Info("Init succeeded")

View file

@ -9,9 +9,12 @@ package initserver
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/ed25519"
"encoding/pem"
"errors" "errors"
"io" "io"
"net" "net"
"os"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -20,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto/testvector" "github.com/edgelesssys/constellation/v2/internal/crypto/testvector"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup" kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
@ -31,6 +35,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -100,6 +105,19 @@ func TestInit(t *testing.T) {
masterSecret := uri.MasterSecret{Key: []byte("secret"), Salt: []byte("salt")} masterSecret := uri.MasterSecret{Key: []byte("secret"), Salt: []byte("salt")}
_, privkey, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
pemHostKey, err := ssh.MarshalPrivateKey(privkey, "")
require.NoError(t, err)
fsWithHostKey := afero.NewMemMapFs()
hostKeyFile, err := fsWithHostKey.Create(constants.SSHHostKeyPath)
require.NoError(t, err)
_, err = hostKeyFile.Write(pem.EncodeToMemory(pemHostKey))
require.NoError(t, err)
require.NoError(t, hostKeyFile.Close())
readOnlyFSWithHostKey := afero.NewReadOnlyFs(fsWithHostKey)
testCases := map[string]struct { testCases := map[string]struct {
nodeLock *fakeLock nodeLock *fakeLock
initializer ClusterInitializer initializer ClusterInitializer
@ -109,6 +127,7 @@ func TestInit(t *testing.T) {
stream stubStream stream stubStream
logCollector stubJournaldCollector logCollector stubJournaldCollector
initSecretHash []byte initSecretHash []byte
hostkeyDoesntExist bool
wantErr bool wantErr bool
wantShutdown bool wantShutdown bool
}{ }{
@ -174,7 +193,7 @@ func TestInit(t *testing.T) {
nodeLock: newFakeLock(), nodeLock: newFakeLock(),
initializer: &stubClusterInitializer{}, initializer: &stubClusterInitializer{},
disk: &stubDisk{}, disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), fileHandler: file.NewHandler(readOnlyFSWithHostKey),
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: uri.NoStoreURI}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: uri.NoStoreURI},
stream: stubStream{}, stream: stubStream{},
logCollector: stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte{})}}, logCollector: stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte{})}},
@ -205,11 +224,31 @@ func TestInit(t *testing.T) {
logCollector: stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte{})}}, logCollector: stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte{})}},
wantErr: true, wantErr: true,
}, },
"host key doesn't exist": {
nodeLock: newFakeLock(),
initializer: &stubClusterInitializer{},
disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
initSecretHash: initSecretHash,
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: uri.NoStoreURI},
stream: stubStream{},
logCollector: stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte{})}},
hostkeyDoesntExist: true,
wantShutdown: true,
wantErr: true,
},
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
if _, err := tc.fileHandler.Stat(constants.SSHHostKeyPath); errors.Is(err, os.ErrNotExist) {
if !tc.hostkeyDoesntExist {
require.NoError(tc.fileHandler.Write(constants.SSHHostKeyPath, pem.EncodeToMemory(pemHostKey), file.OptMkdirAll))
}
}
serveStopper := newStubServeStopper() serveStopper := newStubServeStopper()
server := &Server{ server := &Server{

View file

@ -7,6 +7,7 @@ go_library(
importpath = "github.com/edgelesssys/constellation/v2/bootstrapper/internal/joinclient", importpath = "github.com/edgelesssys/constellation/v2/bootstrapper/internal/joinclient",
visibility = ["//bootstrapper:__subpackages__"], visibility = ["//bootstrapper:__subpackages__"],
deps = [ deps = [
"//bootstrapper/internal/addresses",
"//bootstrapper/internal/certificate", "//bootstrapper/internal/certificate",
"//internal/attestation", "//internal/attestation",
"//internal/cloud/metadata", "//internal/cloud/metadata",
@ -21,6 +22,7 @@ go_library(
"@io_k8s_kubernetes//cmd/kubeadm/app/constants", "@io_k8s_kubernetes//cmd/kubeadm/app/constants",
"@io_k8s_utils//clock", "@io_k8s_utils//clock",
"@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//:grpc",
"@org_golang_x_crypto//ssh",
], ],
) )
@ -35,6 +37,7 @@ go_test(
deps = [ deps = [
"//internal/cloud/metadata", "//internal/cloud/metadata",
"//internal/constants", "//internal/constants",
"//internal/crypto",
"//internal/file", "//internal/file",
"//internal/grpc/atlscredentials", "//internal/grpc/atlscredentials",
"//internal/grpc/dialer", "//internal/grpc/dialer",
@ -49,6 +52,7 @@ go_test(
"@io_k8s_kubernetes//cmd/kubeadm/app/apis/kubeadm/v1beta3", "@io_k8s_kubernetes//cmd/kubeadm/app/apis/kubeadm/v1beta3",
"@io_k8s_utils//clock/testing", "@io_k8s_utils//clock/testing",
"@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//:grpc",
"@org_golang_x_crypto//ssh",
"@org_uber_go_goleak//:goleak", "@org_uber_go_goleak//:goleak",
], ],
) )

View file

@ -23,10 +23,12 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"time" "time"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/addresses"
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/certificate" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/certificate"
"github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata" "github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
@ -37,6 +39,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/versions/components" "github.com/edgelesssys/constellation/v2/internal/versions/components"
"github.com/edgelesssys/constellation/v2/joinservice/joinproto" "github.com/edgelesssys/constellation/v2/joinservice/joinproto"
"github.com/spf13/afero" "github.com/spf13/afero"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc" "google.golang.org/grpc"
kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
kubeconstants "k8s.io/kubernetes/cmd/kubeadm/app/constants" kubeconstants "k8s.io/kubernetes/cmd/kubeadm/app/constants"
@ -209,6 +212,42 @@ func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinprot
return nil, nil, err return nil, nil, err
} }
interfaces, err := net.Interfaces()
if err != nil {
c.log.With(slog.Any("error", err)).Error("Failed to get network interfaces")
return nil, nil, err
}
// Needed since go doesn't implicitly convert slices of structs to slices of interfaces
interfacesForFunc := make([]addresses.NetInterface, len(interfaces))
for i := range interfaces {
interfacesForFunc[i] = &interfaces[i]
}
principalList, err := addresses.GetMachineNetworkAddresses(interfacesForFunc)
if err != nil {
c.log.With(slog.Any("error", err)).Error("Failed to get network addresses")
return nil, nil, err
}
hostname, err := os.Hostname()
if err != nil {
c.log.With(slog.Any("error", err)).Error("Failed to get hostname")
return nil, nil, err
}
principalList = append(principalList, hostname)
hostKeyData, err := c.fileHandler.Read(constants.SSHHostKeyPath)
if err != nil {
c.log.With(slog.Any("error", err)).Error("Failed to read SSH host key file")
return nil, nil, err
}
hostKey, err := ssh.ParsePrivateKey(hostKeyData)
if err != nil {
c.log.With(slog.Any("error", err)).Error("Failed to parse SSH host key file")
return nil, nil, err
}
hostKeyPubSSH := hostKey.PublicKey()
conn, err := c.dialer.Dial(serviceEndpoint) conn, err := c.dialer.Dial(serviceEndpoint)
if err != nil { if err != nil {
c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable") c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable")
@ -221,6 +260,8 @@ func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinprot
DiskUuid: c.diskUUID, DiskUuid: c.diskUUID,
CertificateRequest: certificateRequest, CertificateRequest: certificateRequest,
IsControlPlane: c.role == role.ControlPlane, IsControlPlane: c.role == role.ControlPlane,
HostPublicKey: hostKeyPubSSH.Marshal(),
HostCertificatePrincipals: principalList,
} }
ticket, err = protoClient.IssueJoinTicket(ctx, req) ticket, err = protoClient.IssueJoinTicket(ctx, req)
if err != nil { if err != nil {
@ -275,6 +316,10 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse,
return fmt.Errorf("writing ssh ca key: %w", err) return fmt.Errorf("writing ssh ca key: %w", err)
} }
if err := c.fileHandler.Write(constants.SSHHostCertificatePath, ticket.HostCertificate, file.OptMkdirAll); err != nil {
return fmt.Errorf("writing ssh host certificate: %w", err)
}
state := nodestate.NodeState{ state := nodestate.NodeState{
Role: c.role, Role: c.role,
MeasurementSalt: ticket.MeasurementSalt, MeasurementSalt: ticket.MeasurementSalt,

View file

@ -8,7 +8,11 @@ package joinclient
import ( import (
"context" "context"
"crypto/ed25519"
"encoding/pem"
"errors"
"net" "net"
"os"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
@ -16,6 +20,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata" "github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
@ -28,6 +33,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc" "google.golang.org/grpc"
kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
testclock "k8s.io/utils/clock/testing" testclock "k8s.io/utils/clock/testing"
@ -53,6 +59,59 @@ func TestClient(t *testing.T) {
caDerivationKey := make([]byte, 256) caDerivationKey := make([]byte, 256)
respCaKey := &joinproto.IssueJoinTicketResponse{AuthorizedCaPublicKey: caDerivationKey} respCaKey := &joinproto.IssueJoinTicketResponse{AuthorizedCaPublicKey: caDerivationKey}
// TODO: fix test since keys are generated with systemd service
makeIssueJoinTicketAnswerWithValidCert := func(t *testing.T, originalAnswer issueJoinTicketAnswer, fh file.Handler) issueJoinTicketAnswer {
require := require.New(t)
sshKeyBytes, err := fh.Read(constants.SSHHostKeyPath)
require.NoError(err)
sshKey, err := ssh.ParsePrivateKey(sshKeyBytes)
require.NoError(err)
_, randomCAKey, err := ed25519.GenerateKey(nil)
require.NoError(err)
randomCA, err := ssh.NewSignerFromSigner(randomCAKey)
require.NoError(err)
cert, err := crypto.GenerateSSHHostCertificate([]string{"asdf"}, sshKey.PublicKey(), randomCA)
require.NoError(err)
certBytes := ssh.MarshalAuthorizedKey(cert)
if originalAnswer.resp == nil {
originalAnswer.resp = &joinproto.IssueJoinTicketResponse{HostCertificate: certBytes}
} else {
originalAnswer.resp.HostCertificate = certBytes
}
return originalAnswer
}
makeIssueJoinTicketAnswerWithInvalidCert := func(t *testing.T, originalAnswer issueJoinTicketAnswer) issueJoinTicketAnswer {
require := require.New(t)
_, randomCAKey, err := ed25519.GenerateKey(nil)
require.NoError(err)
randomCA, err := ssh.NewSignerFromSigner(randomCAKey)
require.NoError(err)
randomKey, _, err := ed25519.GenerateKey(nil)
require.NoError(err)
randomSSHKey, err := ssh.NewPublicKey(randomKey)
require.NoError(err)
cert, err := crypto.GenerateSSHHostCertificate([]string{"asdf"}, randomSSHKey, randomCA)
require.NoError(err)
certBytes := ssh.MarshalAuthorizedKey(cert)
if originalAnswer.resp == nil {
originalAnswer.resp = &joinproto.IssueJoinTicketResponse{HostCertificate: certBytes}
} else {
originalAnswer.resp.HostCertificate = certBytes
}
return originalAnswer
}
testCases := map[string]struct { testCases := map[string]struct {
role role.Role role role.Role
clusterJoiner *stubClusterJoiner clusterJoiner *stubClusterJoiner
@ -62,6 +121,8 @@ func TestClient(t *testing.T) {
wantLock bool wantLock bool
wantJoin bool wantJoin bool
wantNumJoins int wantNumJoins int
wantNotMatchingCert bool
wantCertNotExisting bool
}{ }{
"on worker: metadata self: errors occur": { "on worker: metadata self: errors occur": {
role: role.Worker, role: role.Worker,
@ -79,6 +140,23 @@ func TestClient(t *testing.T) {
wantJoin: true, wantJoin: true,
wantLock: true, wantLock: true,
}, },
"on worker: SSH host cert not matching": {
role: role.Worker,
apiAnswers: []any{
selfAnswer{err: assert.AnError},
selfAnswer{err: assert.AnError},
selfAnswer{err: assert.AnError},
selfAnswer{instance: workerSelf},
listAnswer{instances: peers},
issueJoinTicketAnswer{resp: respCaKey},
},
clusterJoiner: &stubClusterJoiner{},
nodeLock: newFakeLock(),
disk: &stubDisk{},
wantJoin: true,
wantLock: true,
wantNotMatchingCert: true,
},
"on worker: metadata self: invalid answer": { "on worker: metadata self: invalid answer": {
role: role.Worker, role: role.Worker,
apiAnswers: []any{ apiAnswers: []any{
@ -199,29 +277,39 @@ func TestClient(t *testing.T) {
nodeLock: lockedLock, nodeLock: lockedLock,
disk: &stubDisk{}, disk: &stubDisk{},
wantLock: true, wantLock: true,
wantCertNotExisting: true,
}, },
"on control plane: disk open fails": { "on control plane: disk open fails": {
role: role.ControlPlane, role: role.ControlPlane,
clusterJoiner: &stubClusterJoiner{}, clusterJoiner: &stubClusterJoiner{},
nodeLock: newFakeLock(), nodeLock: newFakeLock(),
disk: &stubDisk{openErr: assert.AnError}, disk: &stubDisk{openErr: assert.AnError},
wantCertNotExisting: true,
}, },
"on control plane: disk uuid fails": { "on control plane: disk uuid fails": {
role: role.ControlPlane, role: role.ControlPlane,
clusterJoiner: &stubClusterJoiner{}, clusterJoiner: &stubClusterJoiner{},
nodeLock: newFakeLock(), nodeLock: newFakeLock(),
disk: &stubDisk{uuidErr: assert.AnError}, disk: &stubDisk{uuidErr: assert.AnError},
wantCertNotExisting: true,
}, },
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
clock := testclock.NewFakeClock(time.Now()) clock := testclock.NewFakeClock(time.Now())
metadataAPI := newStubMetadataAPI() metadataAPI := newStubMetadataAPI()
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
_, hostKey, err := ed25519.GenerateKey(nil)
require.NoError(err)
hostKeyPEM, err := ssh.MarshalPrivateKey(hostKey, "hostkey")
require.NoError(err)
require.NoError(fileHandler.Write(constants.SSHHostKeyPath, pem.EncodeToMemory(hostKeyPEM), file.OptMkdirAll))
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer) dialer := dialer.New(nil, nil, netDialer)
@ -259,13 +347,43 @@ func TestClient(t *testing.T) {
case listAnswer: case listAnswer:
metadataAPI.listAnswerC <- a metadataAPI.listAnswerC <- a
case issueJoinTicketAnswer: case issueJoinTicketAnswer:
joinserviceAPI.issueJoinTicketAnswerC <- a var answer issueJoinTicketAnswer
if tc.wantNotMatchingCert {
answer = makeIssueJoinTicketAnswerWithInvalidCert(t, a)
} else {
answer = makeIssueJoinTicketAnswerWithValidCert(t, a, fileHandler)
}
joinserviceAPI.issueJoinTicketAnswerC <- answer
} }
clock.Step(time.Second) clock.Step(time.Second)
} }
client.Stop() client.Stop()
if !tc.wantCertNotExisting {
hostCertBytes, err := fileHandler.Read(constants.SSHHostCertificatePath)
require.NoError(err)
hostKeyBytes, err := fileHandler.Read(constants.SSHHostKeyPath)
require.NoError(err)
hostCertKey, _, _, _, err := ssh.ParseAuthorizedKey(hostCertBytes)
require.NoError(err)
hostCert, ok := hostCertKey.(*ssh.Certificate)
require.True(ok)
hostKey, err := ssh.ParsePrivateKey(hostKeyBytes)
require.NoError(err)
if !tc.wantNotMatchingCert {
assert.Equal(hostKey.PublicKey().Marshal(), hostCert.Key.Marshal())
} else {
assert.NotEqual(hostKey.PublicKey().Marshal(), hostCert.Key.Marshal())
}
} else {
_, err := fileHandler.Stat(constants.SSHHostCertificatePath)
require.True(errors.Is(err, os.ErrNotExist))
}
if tc.wantJoin { if tc.wantJoin {
assert.Greater(tc.clusterJoiner.joinClusterCalled, 0) assert.Greater(tc.clusterJoiner.joinClusterCalled, 0)
} else { } else {

View file

@ -74,7 +74,12 @@ func writeCertificateForKey(cmd *cobra.Command, keyPath string, fh file.Handler,
return fmt.Errorf("generating SSH emergency CA key: %s", err) return fmt.Errorf("generating SSH emergency CA key: %s", err)
} }
debugLogger.Debug("SSH CA KEY generated", "public-key", string(ssh.MarshalAuthorizedKey(ca.PublicKey()))) marshalledKey := string(ssh.MarshalAuthorizedKey(ca.PublicKey()))
debugLogger.Debug("SSH CA KEY generated", "public-key", marshalledKey)
knownHostsContent := fmt.Sprintf("@cert-authority * %s", marshalledKey)
if err := fh.Write("./known_hosts", []byte(knownHostsContent), file.OptMkdirAll); err != nil {
return fmt.Errorf("writing known hosts file: %w", err)
}
keyBuffer, err := fh.Read(keyPath) keyBuffer, err := fh.Read(keyPath)
if err != nil { if err != nil {

View file

@ -177,7 +177,7 @@ Emergency SSH access to nodes can be useful to diagnose issues or download impor
3. Now you can connect to any Constellation node using your certificate and your private key. 3. Now you can connect to any Constellation node using your certificate and your private key.
```bash ```bash
ssh -o CertificateFile=constellation_cert.pub -i <your private key> root@<ip of constellation node> ssh -o CertificateFile=constellation_cert.pub -o UserKnownHostsFile=./known_hosts -i <your private key> root@<ip of constellation node>
``` ```
Normally, you don't have access to the Constellation nodes since they reside in a private network. Normally, you don't have access to the Constellation nodes since they reside in a private network.
@ -185,16 +185,18 @@ Emergency SSH access to nodes can be useful to diagnose issues or download impor
For this, use something along the following SSH client configuration: For this, use something along the following SSH client configuration:
```text ```text
Host <LB domain name> Host <LB public IP>
ProxyJump none ProxyJump none
Host * Host *
IdentityFile <your private key> IdentityFile <your private key>
PreferredAuthentications publickey PreferredAuthentications publickey
CertificateFile=constellation_cert.pub CertificateFile=constellation_cert.pub
UserKnownHostsFile=./known_hosts
User root User root
ProxyJump <LB domain name> ProxyJump <LB public IP>
``` ```
With this configuration you can connect to a Constellation node using `ssh -F <this config> <private node IP>`. With this configuration you can connect to a Constellation node using `ssh -F <this config> <private node IP>`.
You can obtain the private node IP and the domain name of the load balancer using your CSP's web UI. You can obtain the private node IP and the public IP of the load balancer using your CSP's web UI. Note that if
you use the load balancers domain name, ssh host certificate verification doesn't work, so using the public IP is recommended.

View file

@ -10,4 +10,3 @@ enable measurements.service
enable export_constellation_debug.service enable export_constellation_debug.service
enable systemd-timesyncd enable systemd-timesyncd
enable udev-trigger.service enable udev-trigger.service
enable create-host-ssh-key.service

View file

@ -1,7 +1,8 @@
[Unit] [Unit]
Description=Constellation Bootstrapper Description=Constellation Bootstrapper
Wants=network-online.target Wants=network-online.target
After=network-online.target configure-constel-csp.service Requires=sshd-keygen.target
After=network-online.target configure-constel-csp.service sshd-keygen.target
After=export_constellation_debug.service After=export_constellation_debug.service
[Service] [Service]

View file

@ -1,10 +0,0 @@
[Unit]
Description=Create a host SSH key
Before=network-pre.target
[Service]
Type=oneshot
ExecStart=/bin/bash -c "mkdir -p /run/ssh; ssh-keygen -t ecdsa -q -N '' -f /run/ssh/ssh_host_ecdsa_key"
[Install]
WantedBy=network-pre.target

View file

@ -1,4 +1,5 @@
HostKey /run/ssh/ssh_host_ecdsa_key HostKey /var/run/state/ssh/ssh_host_ed25519_key
TrustedUserCAKeys /run/ssh/ssh_ca.pub HostCertificate /var/run/state/ssh/ssh_host_cert.pub
TrustedUserCAKeys /var/run/state/ssh/ssh_ca.pub
PasswordAuthentication no PasswordAuthentication no
ChallengeResponseAuthentication no ChallengeResponseAuthentication no

View file

@ -0,0 +1,3 @@
[Unit]
ConditionFileNotEmpty=|!/var/run/state/ssh/ssh_host_%i_key
Before=constellation-bootstrapper.service

View file

@ -0,0 +1,3 @@
[Unit]
Wants=sshd-keygen@ed25519.service
PartOf=sshd.service

View file

@ -0,0 +1,44 @@
#!/usr/bin/bash
# Taken from the original openssh-server package and slightly modified
set -x
# Create the host keys for the OpenSSH server.
KEYTYPE=$1
case $KEYTYPE in
"dsa") ;& # disabled in FIPS
"ed25519")
FIPS=/proc/sys/crypto/fips_enabled
if [[ -r $FIPS && $(cat $FIPS) == "1" ]]; then
exit 0
fi
;;
"rsa") ;; # always ok
"ecdsa") ;;
*) # wrong argument
exit 12 ;;
esac
mkdir -p /var/run/state/ssh
KEY=/var/run/state/ssh/ssh_host_${KEYTYPE}_key
KEYGEN=/usr/bin/ssh-keygen
if [[ ! -x $KEYGEN ]]; then
exit 13
fi
# remove old keys
rm -f "$KEY"{,.pub}
# create new keys
if ! $KEYGEN -q -t "$KEYTYPE" -f "$KEY" -C '' -N '' >&/dev/null; then
exit 1
fi
# sanitize permissions
/usr/bin/chmod 600 "$KEY"
/usr/bin/chmod 644 "$KEY".pub
if [[ -x /usr/sbin/restorecon ]]; then
/usr/sbin/restorecon "$KEY"{,.pub}
fi
exit 0

View file

@ -45,7 +45,13 @@ const (
// SSHCAKeySuffix is the suffix used together with the DEKPrefix to derive an SSH CA key for emergency ssh access. // SSHCAKeySuffix is the suffix used together with the DEKPrefix to derive an SSH CA key for emergency ssh access.
SSHCAKeySuffix = "ca_emergency_ssh" SSHCAKeySuffix = "ca_emergency_ssh"
// SSHCAKeyPath is the path to the emergency SSH CA key on the node. // SSHCAKeyPath is the path to the emergency SSH CA key on the node.
SSHCAKeyPath = "/run/ssh/ssh_ca.pub" SSHCAKeyPath = "/var/run/state/ssh/ssh_ca.pub"
// SSHHostKeyPath is the path to the SSH host key of the node.
SSHHostKeyPath = "/var/run/state/ssh/ssh_host_ed25519_key"
// SSHHostCertificatePath is the path to the SSH host certificate.
SSHHostCertificatePath = "/var/run/state/ssh/ssh_host_cert.pub"
// SSHAdditionalPrincipalsPath stores additional principals (like the public IP of the load balancer) that get added to all host certificates.
SSHAdditionalPrincipalsPath = "/var/run/state/ssh/additional_principals.txt"
// //
// Ports. // Ports.

View file

@ -53,6 +53,8 @@ spec:
- mountPath: /var/secrets/google - mountPath: /var/secrets/google
name: gcekey name: gcekey
readOnly: true readOnly: true
- mountPath: /var/run/state/ssh
name: ssh
ports: ports:
- containerPort: {{ .Values.joinServicePort }} - containerPort: {{ .Values.joinServicePort }}
name: tcp name: tcp
@ -74,4 +76,7 @@ spec:
- name: kubeadm - name: kubeadm
hostPath: hostPath:
path: /etc/kubernetes path: /etc/kubernetes
- name: ssh
hostPath:
path: /var/run/state/ssh
updateStrategy: {} updateStrategy: {}

View file

@ -53,6 +53,8 @@ spec:
- mountPath: /var/secrets/google - mountPath: /var/secrets/google
name: gcekey name: gcekey
readOnly: true readOnly: true
- mountPath: /var/run/state/ssh
name: ssh
ports: ports:
- containerPort: 9090 - containerPort: 9090
name: tcp name: tcp
@ -74,4 +76,7 @@ spec:
- name: kubeadm - name: kubeadm
hostPath: hostPath:
path: /etc/kubernetes path: /etc/kubernetes
- name: ssh
hostPath:
path: /var/run/state/ssh
updateStrategy: {} updateStrategy: {}

View file

@ -53,6 +53,8 @@ spec:
- mountPath: /var/secrets/google - mountPath: /var/secrets/google
name: gcekey name: gcekey
readOnly: true readOnly: true
- mountPath: /var/run/state/ssh
name: ssh
ports: ports:
- containerPort: 9090 - containerPort: 9090
name: tcp name: tcp
@ -74,4 +76,7 @@ spec:
- name: kubeadm - name: kubeadm
hostPath: hostPath:
path: /etc/kubernetes path: /etc/kubernetes
- name: ssh
hostPath:
path: /var/run/state/ssh
updateStrategy: {} updateStrategy: {}

View file

@ -53,6 +53,8 @@ spec:
- mountPath: /var/secrets/google - mountPath: /var/secrets/google
name: gcekey name: gcekey
readOnly: true readOnly: true
- mountPath: /var/run/state/ssh
name: ssh
ports: ports:
- containerPort: 9090 - containerPort: 9090
name: tcp name: tcp
@ -74,4 +76,7 @@ spec:
- name: kubeadm - name: kubeadm
hostPath: hostPath:
path: /etc/kubernetes path: /etc/kubernetes
- name: ssh
hostPath:
path: /var/run/state/ssh
updateStrategy: {} updateStrategy: {}

View file

@ -53,6 +53,8 @@ spec:
- mountPath: /var/secrets/google - mountPath: /var/secrets/google
name: gcekey name: gcekey
readOnly: true readOnly: true
- mountPath: /var/run/state/ssh
name: ssh
ports: ports:
- containerPort: 9090 - containerPort: 9090
name: tcp name: tcp
@ -74,4 +76,7 @@ spec:
- name: kubeadm - name: kubeadm
hostPath: hostPath:
path: /etc/kubernetes path: /etc/kubernetes
- name: ssh
hostPath:
path: /var/run/state/ssh
updateStrategy: {} updateStrategy: {}

View file

@ -53,6 +53,8 @@ spec:
- mountPath: /var/secrets/google - mountPath: /var/secrets/google
name: gcekey name: gcekey
readOnly: true readOnly: true
- mountPath: /var/run/state/ssh
name: ssh
ports: ports:
- containerPort: 9090 - containerPort: 9090
name: tcp name: tcp
@ -74,4 +76,7 @@ spec:
- name: kubeadm - name: kubeadm
hostPath: hostPath:
path: /etc/kubernetes path: /etc/kubernetes
- name: ssh
hostPath:
path: /var/run/state/ssh
updateStrategy: {} updateStrategy: {}

View file

@ -17,6 +17,7 @@ import (
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
"time"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -77,6 +78,28 @@ func GenerateEmergencySSHCAKey(seed []byte) (ssh.Signer, error) {
return ca, nil return ca, nil
} }
// GenerateSSHHostCertificate takes a given public key and CA to generate a host certificate.
func GenerateSSHHostCertificate(principals []string, publicKey ssh.PublicKey, ca ssh.Signer) (*ssh.Certificate, error) {
certificate := ssh.Certificate{
CertType: ssh.HostCert,
ValidPrincipals: principals,
ValidAfter: uint64(time.Now().Unix()),
ValidBefore: ssh.CertTimeInfinity,
Reserved: []byte{},
Key: publicKey,
KeyId: principals[0],
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
if err := certificate.SignCert(rand.Reader, ca); err != nil {
return nil, err
}
return &certificate, nil
}
// PemToX509Cert takes a list of PEM-encoded certificates, parses the first one and returns it // PemToX509Cert takes a list of PEM-encoded certificates, parses the first one and returns it
// as an x.509 certificate. // as an x.509 certificate.
func PemToX509Cert(raw []byte) (*x509.Certificate, error) { func PemToX509Cert(raw []byte) (*x509.Certificate, error) {

View file

@ -116,6 +116,7 @@ func main() {
keyServiceClient, keyServiceClient,
kubeClient, kubeClient,
log.WithGroup("server"), log.WithGroup("server"),
file.NewHandler(afero.NewOsFs()),
) )
if err != nil { if err != nil {
log.With(slog.Any("error", err)).Error("Failed to create server") log.With(slog.Any("error", err)).Error("Failed to create server")

View file

@ -10,6 +10,7 @@ go_library(
"//internal/attestation", "//internal/attestation",
"//internal/constants", "//internal/constants",
"//internal/crypto", "//internal/crypto",
"//internal/file",
"//internal/grpc/grpclog", "//internal/grpc/grpclog",
"//internal/logger", "//internal/logger",
"//internal/versions/components", "//internal/versions/components",
@ -30,12 +31,15 @@ go_test(
deps = [ deps = [
"//internal/attestation", "//internal/attestation",
"//internal/constants", "//internal/constants",
"//internal/file",
"//internal/logger", "//internal/logger",
"//internal/versions/components", "//internal/versions/components",
"//joinservice/joinproto", "//joinservice/joinproto",
"@com_github_spf13_afero//:afero",
"@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require", "@com_github_stretchr_testify//require",
"@io_k8s_kubernetes//cmd/kubeadm/app/apis/kubeadm/v1beta3", "@io_k8s_kubernetes//cmd/kubeadm/app/apis/kubeadm/v1beta3",
"@org_golang_x_crypto//ssh",
"@org_uber_go_goleak//:goleak", "@org_uber_go_goleak//:goleak",
], ],
) )

View file

@ -13,11 +13,13 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"strings"
"time" "time"
"github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog" "github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/versions/components" "github.com/edgelesssys/constellation/v2/internal/versions/components"
@ -40,6 +42,7 @@ type Server struct {
dataKeyGetter dataKeyGetter dataKeyGetter dataKeyGetter
ca certificateAuthority ca certificateAuthority
kubeClient kubeClient kubeClient kubeClient
fileHandler file.Handler
joinproto.UnimplementedAPIServer joinproto.UnimplementedAPIServer
} }
@ -47,6 +50,7 @@ type Server struct {
func New( func New(
measurementSalt []byte, ca certificateAuthority, measurementSalt []byte, ca certificateAuthority,
joinTokenGetter joinTokenGetter, dataKeyGetter dataKeyGetter, kubeClient kubeClient, log *slog.Logger, joinTokenGetter joinTokenGetter, dataKeyGetter dataKeyGetter, kubeClient kubeClient, log *slog.Logger,
fileHandler file.Handler,
) (*Server, error) { ) (*Server, error) {
return &Server{ return &Server{
measurementSalt: measurementSalt, measurementSalt: measurementSalt,
@ -55,6 +59,7 @@ func New(
dataKeyGetter: dataKeyGetter, dataKeyGetter: dataKeyGetter,
ca: ca, ca: ca,
kubeClient: kubeClient, kubeClient: kubeClient,
fileHandler: fileHandler,
}, nil }, nil
} }
@ -114,6 +119,25 @@ func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTi
return nil, status.Errorf(codes.Internal, "generating ssh emergency CA key: %s", err) return nil, status.Errorf(codes.Internal, "generating ssh emergency CA key: %s", err)
} }
principalList := req.HostCertificatePrincipals
additionalPrincipals, err := s.fileHandler.Read(constants.SSHAdditionalPrincipalsPath)
if err != nil {
log.With(slog.Any("error", err)).Error("Failed to read additional principals file")
return nil, status.Errorf(codes.Internal, "reading additional principals file: %s", err)
}
principalList = append(principalList, strings.Split(string(additionalPrincipals), ",")...)
publicKey, err := ssh.ParsePublicKey(req.HostPublicKey)
if err != nil {
log.With(slog.Any("error", err)).Error("Failed to parse host public key")
return nil, status.Errorf(codes.Internal, "unmarshalling host public key: %s", err)
}
hostCertificate, err := crypto.GenerateSSHHostCertificate(principalList, publicKey, ca)
if err != nil {
log.With(slog.Any("error", err)).Error("Failed to generate and sign SSH host key")
return nil, status.Errorf(codes.Internal, "generating and signing SSH host key: %s", err)
}
log.Info("Creating Kubernetes join token") log.Info("Creating Kubernetes join token")
kubeArgs, err := s.joinTokenGetter.GetJoinToken(constants.KubernetesJoinTokenTTL) kubeArgs, err := s.joinTokenGetter.GetJoinToken(constants.KubernetesJoinTokenTTL)
if err != nil { if err != nil {
@ -182,6 +206,7 @@ func (s *Server) IssueJoinTicket(ctx context.Context, req *joinproto.IssueJoinTi
ControlPlaneFiles: controlPlaneFiles, ControlPlaneFiles: controlPlaneFiles,
KubernetesComponents: components, KubernetesComponents: components,
AuthorizedCaPublicKey: ssh.MarshalAuthorizedKey(ca.PublicKey()), AuthorizedCaPublicKey: ssh.MarshalAuthorizedKey(ca.PublicKey()),
HostCertificate: ssh.MarshalAuthorizedKey(hostCertificate),
}, nil }, nil
} }

View file

@ -15,12 +15,15 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/versions/components" "github.com/edgelesssys/constellation/v2/internal/versions/components"
"github.com/edgelesssys/constellation/v2/joinservice/joinproto" "github.com/edgelesssys/constellation/v2/joinservice/joinproto"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"golang.org/x/crypto/ssh"
kubeadmv1 "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3" kubeadmv1 "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
) )
@ -36,6 +39,11 @@ func TestIssueJoinTicket(t *testing.T) {
measurementSecret := []byte{0x7, 0x8, 0x9} measurementSecret := []byte{0x7, 0x8, 0x9}
uuid := "uuid" uuid := "uuid"
pubkey, _, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
hostSSHPubKey, err := ssh.NewPublicKey(pubkey)
require.NoError(t, err)
testJoinToken := &kubeadmv1.BootstrapTokenDiscovery{ testJoinToken := &kubeadmv1.BootstrapTokenDiscovery{
APIServerEndpoint: "192.0.2.1", APIServerEndpoint: "192.0.2.1",
CACertHashes: []string{"hash"}, CACertHashes: []string{"hash"},
@ -58,6 +66,8 @@ func TestIssueJoinTicket(t *testing.T) {
ca stubCA ca stubCA
kubeClient stubKubeClient kubeClient stubKubeClient
missingComponentsReferenceFile bool missingComponentsReferenceFile bool
missingAdditionalPrincipalsFile bool
missingSSHHostKey bool
wantErr bool wantErr bool
}{ }{
"worker node": { "worker node": {
@ -179,6 +189,30 @@ func TestIssueJoinTicket(t *testing.T) {
kubeClient: stubKubeClient{getComponentsVal: clusterComponents, getK8sComponentsRefFromNodeVersionCRDVal: "k8s-components-ref"}, kubeClient: stubKubeClient{getComponentsVal: clusterComponents, getK8sComponentsRefFromNodeVersionCRDVal: "k8s-components-ref"},
wantErr: true, wantErr: true,
}, },
"Additional principals file is missing": {
kubeadm: stubTokenGetter{token: testJoinToken},
kms: stubKeyGetter{dataKeys: map[string][]byte{
uuid: testKey,
attestation.MeasurementSecretContext: measurementSecret,
constants.SSHCAKeySuffix: testCaKey,
}},
ca: stubCA{cert: testCert, nodeName: "node"},
kubeClient: stubKubeClient{getComponentsVal: clusterComponents, getK8sComponentsRefFromNodeVersionCRDVal: "k8s-components-ref"},
missingAdditionalPrincipalsFile: true,
wantErr: true,
},
"Host pubkey is missing": {
kubeadm: stubTokenGetter{token: testJoinToken},
kms: stubKeyGetter{dataKeys: map[string][]byte{
uuid: testKey,
attestation.MeasurementSecretContext: measurementSecret,
constants.SSHCAKeySuffix: testCaKey,
}},
ca: stubCA{cert: testCert, nodeName: "node"},
kubeClient: stubKubeClient{getComponentsVal: clusterComponents, getK8sComponentsRefFromNodeVersionCRDVal: "k8s-components-ref"},
missingSSHHostKey: true,
wantErr: true,
},
} }
for name, tc := range testCases { for name, tc := range testCases {
@ -188,6 +222,11 @@ func TestIssueJoinTicket(t *testing.T) {
salt := []byte{0xA, 0xB, 0xC} salt := []byte{0xA, 0xB, 0xC}
fh := file.NewHandler(afero.NewMemMapFs())
if !tc.missingAdditionalPrincipalsFile {
require.NoError(fh.Write(constants.SSHAdditionalPrincipalsPath, []byte("*"), file.OptMkdirAll))
}
api := Server{ api := Server{
measurementSalt: salt, measurementSalt: salt,
ca: tc.ca, ca: tc.ca,
@ -195,11 +234,20 @@ func TestIssueJoinTicket(t *testing.T) {
dataKeyGetter: tc.kms, dataKeyGetter: tc.kms,
kubeClient: &tc.kubeClient, kubeClient: &tc.kubeClient,
log: logger.NewTest(t), log: logger.NewTest(t),
fileHandler: fh,
}
var keyToSend []byte
if tc.missingSSHHostKey {
keyToSend = nil
} else {
keyToSend = hostSSHPubKey.Marshal()
} }
req := &joinproto.IssueJoinTicketRequest{ req := &joinproto.IssueJoinTicketRequest{
DiskUuid: "uuid", DiskUuid: "uuid",
IsControlPlane: tc.isControlPlane, IsControlPlane: tc.isControlPlane,
HostPublicKey: keyToSend,
} }
resp, err := api.IssueJoinTicket(t.Context(), req) resp, err := api.IssueJoinTicket(t.Context(), req)
if tc.wantErr { if tc.wantErr {
@ -260,6 +308,7 @@ func TestIssueRejoinTicker(t *testing.T) {
joinTokenGetter: stubTokenGetter{}, joinTokenGetter: stubTokenGetter{},
dataKeyGetter: tc.keyGetter, dataKeyGetter: tc.keyGetter,
log: logger.NewTest(t), log: logger.NewTest(t),
fileHandler: file.NewHandler(afero.NewMemMapFs()),
} }
req := &joinproto.IssueRejoinTicketRequest{ req := &joinproto.IssueRejoinTicketRequest{

View file

@ -31,6 +31,8 @@ type IssueJoinTicketRequest struct {
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"` DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
CertificateRequest []byte `protobuf:"bytes,2,opt,name=certificate_request,json=certificateRequest,proto3" json:"certificate_request,omitempty"` CertificateRequest []byte `protobuf:"bytes,2,opt,name=certificate_request,json=certificateRequest,proto3" json:"certificate_request,omitempty"`
IsControlPlane bool `protobuf:"varint,3,opt,name=is_control_plane,json=isControlPlane,proto3" json:"is_control_plane,omitempty"` IsControlPlane bool `protobuf:"varint,3,opt,name=is_control_plane,json=isControlPlane,proto3" json:"is_control_plane,omitempty"`
HostPublicKey []byte `protobuf:"bytes,4,opt,name=host_public_key,json=hostPublicKey,proto3" json:"host_public_key,omitempty"`
HostCertificatePrincipals []string `protobuf:"bytes,5,rep,name=host_certificate_principals,json=hostCertificatePrincipals,proto3" json:"host_certificate_principals,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@ -86,6 +88,20 @@ func (x *IssueJoinTicketRequest) GetIsControlPlane() bool {
return false return false
} }
func (x *IssueJoinTicketRequest) GetHostPublicKey() []byte {
if x != nil {
return x.HostPublicKey
}
return nil
}
func (x *IssueJoinTicketRequest) GetHostCertificatePrincipals() []string {
if x != nil {
return x.HostCertificatePrincipals
}
return nil
}
type IssueJoinTicketResponse struct { type IssueJoinTicketResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"` StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3" json:"state_disk_key,omitempty"`
@ -99,6 +115,7 @@ type IssueJoinTicketResponse struct {
KubernetesVersion string `protobuf:"bytes,9,opt,name=kubernetes_version,json=kubernetesVersion,proto3" json:"kubernetes_version,omitempty"` KubernetesVersion string `protobuf:"bytes,9,opt,name=kubernetes_version,json=kubernetesVersion,proto3" json:"kubernetes_version,omitempty"`
KubernetesComponents []*components.Component `protobuf:"bytes,10,rep,name=kubernetes_components,json=kubernetesComponents,proto3" json:"kubernetes_components,omitempty"` KubernetesComponents []*components.Component `protobuf:"bytes,10,rep,name=kubernetes_components,json=kubernetesComponents,proto3" json:"kubernetes_components,omitempty"`
AuthorizedCaPublicKey []byte `protobuf:"bytes,11,opt,name=authorized_ca_public_key,json=authorizedCaPublicKey,proto3" json:"authorized_ca_public_key,omitempty"` AuthorizedCaPublicKey []byte `protobuf:"bytes,11,opt,name=authorized_ca_public_key,json=authorizedCaPublicKey,proto3" json:"authorized_ca_public_key,omitempty"`
HostCertificate []byte `protobuf:"bytes,12,opt,name=host_certificate,json=hostCertificate,proto3" json:"host_certificate,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@ -210,6 +227,13 @@ func (x *IssueJoinTicketResponse) GetAuthorizedCaPublicKey() []byte {
return nil return nil
} }
func (x *IssueJoinTicketResponse) GetHostCertificate() []byte {
if x != nil {
return x.HostCertificate
}
return nil
}
type ControlPlaneCertOrKey struct { type ControlPlaneCertOrKey struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
@ -362,11 +386,13 @@ var File_joinservice_joinproto_join_proto protoreflect.FileDescriptor
const file_joinservice_joinproto_join_proto_rawDesc = "" + const file_joinservice_joinproto_join_proto_rawDesc = "" +
"\n" + "\n" +
" joinservice/joinproto/join.proto\x12\x04join\x1a-internal/versions/components/components.proto\"\x90\x01\n" + " joinservice/joinproto/join.proto\x12\x04join\x1a-internal/versions/components/components.proto\"\xf8\x01\n" +
"\x16IssueJoinTicketRequest\x12\x1b\n" + "\x16IssueJoinTicketRequest\x12\x1b\n" +
"\tdisk_uuid\x18\x01 \x01(\tR\bdiskUuid\x12/\n" + "\tdisk_uuid\x18\x01 \x01(\tR\bdiskUuid\x12/\n" +
"\x13certificate_request\x18\x02 \x01(\fR\x12certificateRequest\x12(\n" + "\x13certificate_request\x18\x02 \x01(\fR\x12certificateRequest\x12(\n" +
"\x10is_control_plane\x18\x03 \x01(\bR\x0eisControlPlane\"\xc7\x04\n" + "\x10is_control_plane\x18\x03 \x01(\bR\x0eisControlPlane\x12&\n" +
"\x0fhost_public_key\x18\x04 \x01(\fR\rhostPublicKey\x12>\n" +
"\x1bhost_certificate_principals\x18\x05 \x03(\tR\x19hostCertificatePrincipals\"\xf2\x04\n" +
"\x17IssueJoinTicketResponse\x12$\n" + "\x17IssueJoinTicketResponse\x12$\n" +
"\x0estate_disk_key\x18\x01 \x01(\fR\fstateDiskKey\x12)\n" + "\x0estate_disk_key\x18\x01 \x01(\fR\fstateDiskKey\x12)\n" +
"\x10measurement_salt\x18\x02 \x01(\fR\x0fmeasurementSalt\x12-\n" + "\x10measurement_salt\x18\x02 \x01(\fR\x0fmeasurementSalt\x12-\n" +
@ -379,7 +405,8 @@ const file_joinservice_joinproto_join_proto_rawDesc = "" +
"\x12kubernetes_version\x18\t \x01(\tR\x11kubernetesVersion\x12J\n" + "\x12kubernetes_version\x18\t \x01(\tR\x11kubernetesVersion\x12J\n" +
"\x15kubernetes_components\x18\n" + "\x15kubernetes_components\x18\n" +
" \x03(\v2\x15.components.ComponentR\x14kubernetesComponents\x127\n" + " \x03(\v2\x15.components.ComponentR\x14kubernetesComponents\x127\n" +
"\x18authorized_ca_public_key\x18\v \x01(\fR\x15authorizedCaPublicKey\"C\n" + "\x18authorized_ca_public_key\x18\v \x01(\fR\x15authorizedCaPublicKey\x12)\n" +
"\x10host_certificate\x18\f \x01(\fR\x0fhostCertificate\"C\n" +
"\x19control_plane_cert_or_key\x12\x12\n" + "\x19control_plane_cert_or_key\x12\x12\n" +
"\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" +
"\x04data\x18\x02 \x01(\fR\x04data\"7\n" + "\x04data\x18\x02 \x01(\fR\x04data\"7\n" +

View file

@ -20,6 +20,10 @@ message IssueJoinTicketRequest {
bytes certificate_request = 2; bytes certificate_request = 2;
// is_control_plane indicates whether the node is a control-plane node. // is_control_plane indicates whether the node is a control-plane node.
bool is_control_plane = 3; bool is_control_plane = 3;
// host_public_key is the public host key that should be signed.
bytes host_public_key = 4;
// host_certificate_principals are principals that should be added to the host certificate.
repeated string host_certificate_principals = 5;
} }
message IssueJoinTicketResponse { message IssueJoinTicketResponse {
@ -47,6 +51,8 @@ message IssueJoinTicketResponse {
repeated components.Component kubernetes_components = 10; repeated components.Component kubernetes_components = 10;
// authorized_ca_public_key is an ssh ca key that can be used to connect to a node in case of an emergency. // authorized_ca_public_key is an ssh ca key that can be used to connect to a node in case of an emergency.
bytes authorized_ca_public_key = 11; bytes authorized_ca_public_key = 11;
// host_certificate is the certificate that can be used to verify a nodes host key.
bytes host_certificate = 12;
} }
message control_plane_cert_or_key { message control_plane_cert_or_key {