config: add attestation variant (#1413)

* Add attestation type to config (optional for now)

* Get attestation variant from config in CLI

* Set attestation variant for Constellation services in helm deployments

* Remove AzureCVM variable from helm deployments

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-03-14 11:46:27 +01:00 committed by GitHub
parent 8679988b6c
commit 6ea5588bdc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
44 changed files with 379 additions and 383 deletions

View file

@ -7,11 +7,9 @@ SPDX-License-Identifier: AGPL-3.0-only
package watcher
import (
"bytes"
"context"
"encoding/asn1"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
@ -22,9 +20,11 @@ import (
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/idkeydigest"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -40,29 +40,29 @@ func TestMain(m *testing.M) {
func TestNewUpdateableValidator(t *testing.T) {
testCases := map[string]struct {
provider string
variant oid.Getter
writeFile bool
wantErr bool
}{
"azure": {
provider: "azure",
variant: oid.AzureSEVSNP{},
writeFile: true,
},
"gcp": {
provider: "gcp",
variant: oid.GCPSEVES{},
writeFile: true,
},
"qemu": {
provider: "qemu",
variant: oid.QEMUVTPM{},
writeFile: true,
},
"no file": {
provider: "azure",
variant: oid.AzureSEVSNP{},
writeFile: false,
wantErr: true,
},
"invalid provider": {
provider: "invalid",
variant: fakeOID{1, 3, 9900, 9999, 9999},
writeFile: true,
wantErr: true,
},
@ -77,33 +77,24 @@ func TestNewUpdateableValidator(t *testing.T) {
if tc.writeFile {
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{
11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
measurements.M{11: measurements.WithAllBytes(0x00, false)},
))
keyDigest, err := json.Marshal(idkeydigest.DefaultsFor(cloudprovider.Azure))
require.NoError(err)
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
[]byte{},
keyDigest,
))
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename),
[]byte("false"),
))
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.AzureCVM),
[]byte("true"),
))
}
_, err := NewValidator(
logger.NewTest(t),
tc.provider,
tc.variant,
handler,
false,
)
if tc.wantErr {
assert.Error(err)
@ -118,26 +109,13 @@ func TestUpdate(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
// we need safe access for overwriting the fake validator OID
oid := fakeOID{1, 3, 9900, 1}
var oidLock sync.Mutex
updatedOID := func(newOID fakeOID) {
oidLock.Lock()
defer oidLock.Unlock()
oid = newOID
}
newValidator := func(m measurements.M, digest idkeydigest.IDKeyDigests, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
oidLock.Lock()
defer oidLock.Unlock()
return fakeValidator{fakeOID: oid}
}
handler := file.NewHandler(afero.NewMemMapFs())
// create server
validator := &Updatable{
log: logger.NewTest(t),
newValidator: newValidator,
fileHandler: handler,
log: logger.NewTest(t),
variant: oid.Dummy{},
fileHandler: handler,
}
// Update should fail if the file does not exist
@ -156,10 +134,6 @@ func TestUpdate(t *testing.T) {
filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename),
[]byte("false"),
))
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.AzureCVM),
[]byte("true"),
))
// call update once to initialize the server's validator
require.NoError(validator.Update())
@ -175,7 +149,7 @@ func TestUpdate(t *testing.T) {
defer server.Close()
// test connection to server
clientOID := fakeOID{1, 3, 9900, 1}
clientOID := oid.Dummy{}
resp, err := testConnection(require, server.URL, clientOID)
require.NoError(err)
defer resp.Body.Close()
@ -184,7 +158,7 @@ func TestUpdate(t *testing.T) {
assert.EqualValues("hello", body)
// update the server's validator
updatedOID(fakeOID{1, 3, 9900, 2})
validator.variant = oid.QEMUVTPM{}
require.NoError(validator.Update())
// client connection should fail now, since the server's validator expects a different OID from the client
@ -193,23 +167,6 @@ func TestUpdate(t *testing.T) {
defer resp.Body.Close()
}
assert.Error(err)
// update should work for legacy measurement format
// TODO: remove with v2.4.0
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{
11: bytes.Repeat([]byte{0x0}, 32),
12: bytes.Repeat([]byte{0x1}, 32),
},
file.OptOverwrite,
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
))
assert.NoError(validator.Update())
}
func TestOIDConcurrency(t *testing.T) {
@ -226,14 +183,11 @@ func TestOIDConcurrency(t *testing.T) {
[]byte{},
))
newValidator := func(m measurements.M, digest idkeydigest.IDKeyDigests, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
}
// create server
validator := &Updatable{
log: logger.NewTest(t),
newValidator: newValidator,
fileHandler: handler,
log: logger.NewTest(t),
variant: oid.Dummy{},
fileHandler: handler,
}
// call update once to initialize the server's validator
@ -262,21 +216,13 @@ func TestUpdateConcurrency(t *testing.T) {
validator := &Updatable{
log: logger.NewTest(t),
fileHandler: handler,
newValidator: func(m measurements.M, digest idkeydigest.IDKeyDigests, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
},
variant: oid.Dummy{},
}
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{
11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
measurements.M{11: measurements.WithAllBytes(0x00, false)},
file.OptNone,
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
))
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
[]byte{},
@ -285,10 +231,6 @@ func TestUpdateConcurrency(t *testing.T) {
filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename),
[]byte("false"),
))
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.AzureCVM),
[]byte("true"),
))
var wg sync.WaitGroup
@ -303,8 +245,8 @@ func TestUpdateConcurrency(t *testing.T) {
wg.Wait()
}
func testConnection(require *require.Assertions, url string, oid fakeOID) (*http.Response, error) {
clientConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid}, nil)
func testConnection(require *require.Assertions, url string, oid oid.Getter) (*http.Response, error) {
clientConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{oid}, nil)
require.NoError(err)
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
@ -314,29 +256,13 @@ func testConnection(require *require.Assertions, url string, oid fakeOID) (*http
}
type fakeIssuer struct {
fakeOID
oid.Getter
}
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
}
type fakeValidator struct {
fakeOID
err error
}
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var doc fakeDoc
if err := json.Unmarshal(attDoc, &doc); err != nil {
return nil, err
}
if !bytes.Equal(doc.Nonce, nonce) {
return nil, errors.New("invalid nonce")
}
return doc.UserData, v.err
}
type fakeOID asn1.ObjectIdentifier
func (o fakeOID) OID() asn1.ObjectIdentifier {