constellation/joinservice/internal/watcher/validator_test.go
Moritz Sanft a5021c52d3
joinservice: cache certificates for Azure SEV-SNP attestation (#2336)
* add ASK caching in joinservice

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* use cached ASK in Azure SEV-SNP attestation

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* update test charts

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix linter

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix typ

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* make caching mechanism less provider-specific

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* update buildfiles

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* add `omitempty` flag

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* frontload certificate getter

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* rename frontloaded function

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* pass cached certificates to constructor

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix race condition

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix marshalling of empty certs

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix validator usage

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* [wip] add certcache tests

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* add certcache tests

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* tidy

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix validator test

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* remove unused fields in validator

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix certificate precedence

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* use separate context

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* tidy

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* linter fixes

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* linter fixes

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* Remove unnecessary comment

Co-authored-by: Thomas Tendyck <51411342+thomasten@users.noreply.github.com>

* use background context

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* Use error format directive

Co-authored-by: Thomas Tendyck <51411342+thomasten@users.noreply.github.com>

* `azure` -> `Azure`

Co-authored-by: Thomas Tendyck <51411342+thomasten@users.noreply.github.com>

* improve error messages

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* add x509 -> PEM util function

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* use crypto util functions

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix certificate replacement logic

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* only require ASK from certcache

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* tidy

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* fix comment typo

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

---------

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
Co-authored-by: Thomas Tendyck <51411342+thomasten@users.noreply.github.com>
2023-09-29 14:29:50 +02:00

268 lines
7.0 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package watcher
import (
"context"
"crypto/x509"
"encoding/asn1"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"testing"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestNewUpdateableValidator(t *testing.T) {
testCases := map[string]struct {
variant variant.Variant
config config.AttestationCfg
snpCerts *stubSnpCerts
wantErr bool
}{
"azure": {
variant: variant.AzureSEVSNP{},
config: config.DefaultForAzureSEVSNP(),
snpCerts: &stubSnpCerts{
ask: &x509.Certificate{},
ark: &x509.Certificate{},
},
},
"gcp": {
variant: variant.GCPSEVES{},
config: &config.GCPSEVES{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
},
"qemu": {
variant: variant.QEMUVTPM{},
config: &config.QEMUVTPM{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
},
"no file": {
variant: variant.AzureSEVSNP{},
wantErr: true,
},
"invalid provider": {
variant: fakeOID{1, 3, 9900, 9999, 9999},
config: &config.GCPSEVES{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := file.NewHandler(afero.NewMemMapFs())
if tc.config != nil {
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
tc.config,
))
}
_, err := NewValidator(
logger.NewTest(t),
tc.variant,
handler,
tc.snpCerts,
)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
type stubSnpCerts struct {
ask *x509.Certificate
ark *x509.Certificate
}
func (s *stubSnpCerts) SevSnpCerts() (ask *x509.Certificate, ark *x509.Certificate) {
return s.ask, s.ark
}
func TestUpdate(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := file.NewHandler(afero.NewMemMapFs())
// create server
validator := &Updatable{
log: logger.NewTest(t),
variant: variant.Dummy{},
fileHandler: handler,
}
// Update should fail if the file does not exist
assert.Error(validator.Update())
// write measurement config
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
&config.DummyCfg{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
))
// call update once to initialize the server's validator
require.NoError(validator.Update())
// create tls config and start the server
serverConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator})
require.NoError(err)
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "hello")
}))
server.TLS = serverConfig
server.StartTLS()
defer server.Close()
// test connection to server
clientOID := variant.Dummy{}
resp, err := testConnection(require, server.URL, clientOID)
require.NoError(err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(err)
assert.EqualValues("hello", body)
// update the server's validator
validator.variant = variant.QEMUVTPM{}
require.NoError(validator.Update())
// client connection should fail now, since the server's validator expects a different OID from the client
resp, err = testConnection(require, server.URL, clientOID)
if err == nil {
defer resp.Body.Close()
}
assert.Error(err)
}
func TestOIDConcurrency(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := file.NewHandler(afero.NewMemMapFs())
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
&config.DummyCfg{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
))
// create server
validator := &Updatable{
log: logger.NewTest(t),
variant: variant.Dummy{},
fileHandler: handler,
}
// call update once to initialize the server's validator
require.NoError(validator.Update())
var wg sync.WaitGroup
wg.Add(2 * 20)
for i := 0; i < 20; i++ {
go func() {
defer wg.Done()
assert.NoError(validator.Update())
}()
go func() {
defer wg.Done()
validator.OID()
}()
}
wg.Wait()
}
func TestUpdateConcurrency(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := file.NewHandler(afero.NewMemMapFs())
validator := &Updatable{
log: logger.NewTest(t),
fileHandler: handler,
variant: variant.Dummy{},
}
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.AttestationConfigFilename),
&config.DummyCfg{Measurements: measurements.M{11: measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)}},
file.OptNone,
))
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
assert.NoError(validator.Update())
}()
}
wg.Wait()
}
func testConnection(require *require.Assertions, url string, oid variant.Getter) (*http.Response, error) {
clientConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{oid}, nil)
require.NoError(err)
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
require.NoError(err)
return client.Do(req)
}
type fakeIssuer struct {
variant.Getter
}
func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
}
type fakeOID asn1.ObjectIdentifier
func (o fakeOID) OID() asn1.ObjectIdentifier {
return asn1.ObjectIdentifier(o)
}
func (o fakeOID) String() string {
return o.OID().String()
}
func (o fakeOID) Equal(other variant.Getter) bool {
return o.OID().Equal(other.OID())
}
type fakeDoc struct {
UserData []byte
Nonce []byte
}