mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-05-02 06:16:08 -04:00
AB#2305 Fix missing atls verifier in init call (#352)
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
aee3f2afa2
commit
8f5f84deb5
9 changed files with 184 additions and 70 deletions
|
@ -13,12 +13,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/bootstrapper/initproto"
|
||||
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -131,7 +135,9 @@ func TestInitialize(t *testing.T) {
|
|||
require := require.New(t)
|
||||
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := dialer.New(nil, nil, netDialer)
|
||||
newDialer := func(*cloudcmd.Validator) *dialer.Dialer {
|
||||
return dialer.New(nil, nil, netDialer)
|
||||
}
|
||||
serverCreds := atlscredentials.New(nil, nil)
|
||||
initServer := grpc.NewServer(grpc.Creds(serverCreds))
|
||||
initproto.RegisterAPIServer(initServer, tc.initServerAPI)
|
||||
|
@ -145,7 +151,7 @@ func TestInitialize(t *testing.T) {
|
|||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
||||
cmd.Flags().String("config", "", "") // register persistent flag manually
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||
|
@ -156,7 +162,7 @@ func TestInitialize(t *testing.T) {
|
|||
defer cancel()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
err := initialize(cmd, dialer, &tc.serviceAccountCreator, fileHandler)
|
||||
err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
|
@ -364,6 +370,122 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAttestation(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
initServerAPI := &stubInitServer{initResp: &initproto.InitResponse{
|
||||
Kubeconfig: []byte("kubeconfig"),
|
||||
OwnerId: []byte("ownerID"),
|
||||
ClusterId: []byte("clusterID"),
|
||||
}}
|
||||
existingState := state.ConstellationState{
|
||||
CloudProvider: "QEMU",
|
||||
BootstrapperHost: "192.0.2.1",
|
||||
QEMUWorkers: cloudtypes.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
QEMUControlPlane: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
}
|
||||
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
newDialer := func(v *cloudcmd.Validator) *dialer.Dialer {
|
||||
validator := &testValidator{
|
||||
Getter: oid.QEMU{},
|
||||
pcrs: v.PCRS(),
|
||||
}
|
||||
return dialer.New(nil, validator, netDialer)
|
||||
}
|
||||
|
||||
issuer := &testIssuer{
|
||||
Getter: oid.QEMU{},
|
||||
pcrs: map[uint32][]byte{
|
||||
0: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
1: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
2: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
3: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
},
|
||||
}
|
||||
serverCreds := atlscredentials.New(issuer, nil)
|
||||
initServer := grpc.NewServer(grpc.Creds(serverCreds))
|
||||
initproto.RegisterAPIServer(initServer, initServerAPI)
|
||||
port := strconv.Itoa(constants.BootstrapperPort)
|
||||
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
|
||||
go initServer.Serve(listener)
|
||||
defer initServer.GracefulStop()
|
||||
|
||||
cmd := NewInitCmd()
|
||||
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, existingState, file.OptNone))
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.RemoveProviderExcept(cloudprovider.QEMU)
|
||||
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000")
|
||||
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111")
|
||||
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222")
|
||||
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333")
|
||||
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler)
|
||||
assert.Error(err)
|
||||
// make sure the error is actually a TLS handshake error
|
||||
assert.Contains(err.Error(), "transport: authentication handshake failed")
|
||||
}
|
||||
|
||||
type testValidator struct {
|
||||
oid.Getter
|
||||
pcrs map[uint32][]byte
|
||||
}
|
||||
|
||||
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||||
var attestation struct {
|
||||
UserData []byte
|
||||
PCRs map[uint32][]byte
|
||||
}
|
||||
if err := json.Unmarshal(attDoc, &attestation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, pcr := range v.pcrs {
|
||||
if !bytes.Equal(attestation.PCRs[k], pcr) {
|
||||
return nil, errors.New("invalid PCR value")
|
||||
}
|
||||
}
|
||||
return attestation.UserData, nil
|
||||
}
|
||||
|
||||
type testIssuer struct {
|
||||
oid.Getter
|
||||
pcrs map[uint32][]byte
|
||||
}
|
||||
|
||||
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
|
||||
return json.Marshal(
|
||||
struct {
|
||||
UserData []byte
|
||||
PCRs map[uint32][]byte
|
||||
}{
|
||||
UserData: userData,
|
||||
PCRs: i.pcrs,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
type stubInitServer struct {
|
||||
initResp *initproto.InitResponse
|
||||
initErr error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue