mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-17 18:37:08 -05:00
c2ea937fb5
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
349 lines
9.0 KiB
Go
349 lines
9.0 KiB
Go
/*
|
|
Copyright (c) Edgeless Systems GmbH
|
|
|
|
SPDX-License-Identifier: AGPL-3.0-only
|
|
*/
|
|
|
|
package watcher
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/asn1"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"github.com/edgelesssys/constellation/v2/internal/atls"
	"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
	"github.com/edgelesssys/constellation/v2/internal/constants"
	"github.com/edgelesssys/constellation/v2/internal/file"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m,
|
|
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
|
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
|
|
)
|
|
}
|
|
|
|
func TestNewUpdateableValidator(t *testing.T) {
|
|
testCases := map[string]struct {
|
|
provider string
|
|
writeFile bool
|
|
wantErr bool
|
|
}{
|
|
"azure": {
|
|
provider: "azure",
|
|
writeFile: true,
|
|
},
|
|
"gcp": {
|
|
provider: "gcp",
|
|
writeFile: true,
|
|
},
|
|
"qemu": {
|
|
provider: "qemu",
|
|
writeFile: true,
|
|
},
|
|
"no file": {
|
|
provider: "azure",
|
|
writeFile: false,
|
|
wantErr: true,
|
|
},
|
|
"invalid provider": {
|
|
provider: "invalid",
|
|
writeFile: true,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
handler := file.NewHandler(afero.NewMemMapFs())
|
|
if tc.writeFile {
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
|
map[uint32][]byte{
|
|
11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
|
},
|
|
))
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
|
|
[]uint32{11},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
|
|
[]byte{},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename),
|
|
[]byte("false"),
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.AzureCVM),
|
|
[]byte("true"),
|
|
))
|
|
}
|
|
|
|
_, err := NewValidator(
|
|
logger.NewTest(t),
|
|
tc.provider,
|
|
handler,
|
|
false,
|
|
)
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
} else {
|
|
assert.NoError(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdate(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
// we need safe access for overwriting the fake validator OID
|
|
oid := fakeOID{1, 3, 9900, 1}
|
|
var oidLock sync.Mutex
|
|
updatedOID := func(newOID fakeOID) {
|
|
oidLock.Lock()
|
|
defer oidLock.Unlock()
|
|
oid = newOID
|
|
}
|
|
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
|
oidLock.Lock()
|
|
defer oidLock.Unlock()
|
|
return fakeValidator{fakeOID: oid}
|
|
}
|
|
handler := file.NewHandler(afero.NewMemMapFs())
|
|
|
|
// create server
|
|
validator := &Updatable{
|
|
log: logger.NewTest(t),
|
|
newValidator: newValidator,
|
|
fileHandler: handler,
|
|
}
|
|
|
|
// Update should fail if the file does not exist
|
|
assert.Error(validator.Update())
|
|
|
|
// write measurement config
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
|
measurements.M{11: measurements.WithAllBytes(0x00, false)},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
|
|
[]byte{},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename),
|
|
[]byte("false"),
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.AzureCVM),
|
|
[]byte("true"),
|
|
))
|
|
|
|
// call update once to initialize the server's validator
|
|
require.NoError(validator.Update())
|
|
|
|
// create tls config and start the server
|
|
serverConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator})
|
|
require.NoError(err)
|
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "hello")
|
|
}))
|
|
server.TLS = serverConfig
|
|
server.StartTLS()
|
|
defer server.Close()
|
|
|
|
// test connection to server
|
|
clientOID := fakeOID{1, 3, 9900, 1}
|
|
resp, err := testConnection(require, server.URL, clientOID)
|
|
require.NoError(err)
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(err)
|
|
assert.EqualValues("hello", body)
|
|
|
|
// update the server's validator
|
|
updatedOID(fakeOID{1, 3, 9900, 2})
|
|
require.NoError(validator.Update())
|
|
|
|
// client connection should fail now, since the server's validator expects a different OID from the client
|
|
resp, err = testConnection(require, server.URL, clientOID)
|
|
if err == nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
assert.Error(err)
|
|
|
|
// update should work for legacy measurement format
|
|
// TODO: remove with v2.4.0
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
|
map[uint32][]byte{
|
|
11: bytes.Repeat([]byte{0x0}, 32),
|
|
12: bytes.Repeat([]byte{0x1}, 32),
|
|
},
|
|
file.OptOverwrite,
|
|
))
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
|
|
[]uint32{11},
|
|
))
|
|
|
|
assert.NoError(validator.Update())
|
|
}
|
|
|
|
func TestOIDConcurrency(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
handler := file.NewHandler(afero.NewMemMapFs())
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
|
measurements.M{11: measurements.WithAllBytes(0x00, false)},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
|
|
[]byte{},
|
|
))
|
|
|
|
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
|
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
|
|
}
|
|
// create server
|
|
validator := &Updatable{
|
|
log: logger.NewTest(t),
|
|
newValidator: newValidator,
|
|
fileHandler: handler,
|
|
}
|
|
|
|
// call update once to initialize the server's validator
|
|
require.NoError(validator.Update())
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2 * 20)
|
|
for i := 0; i < 20; i++ {
|
|
go func() {
|
|
defer wg.Done()
|
|
assert.NoError(validator.Update())
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
validator.OID()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestUpdateConcurrency(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
handler := file.NewHandler(afero.NewMemMapFs())
|
|
validator := &Updatable{
|
|
log: logger.NewTest(t),
|
|
fileHandler: handler,
|
|
newValidator: func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
|
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
|
|
},
|
|
}
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
|
map[uint32][]byte{
|
|
11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
|
},
|
|
file.OptNone,
|
|
))
|
|
require.NoError(handler.WriteJSON(
|
|
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
|
|
[]uint32{11},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
|
|
[]byte{},
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename),
|
|
[]byte("false"),
|
|
))
|
|
require.NoError(handler.Write(
|
|
filepath.Join(constants.ServiceBasePath, constants.AzureCVM),
|
|
[]byte("true"),
|
|
))
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
assert.NoError(validator.Update())
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func testConnection(require *require.Assertions, url string, oid fakeOID) (*http.Response, error) {
|
|
clientConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid}, nil)
|
|
require.NoError(err)
|
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
|
|
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
|
require.NoError(err)
|
|
return client.Do(req)
|
|
}
|
|
|
|
type fakeIssuer struct {
|
|
fakeOID
|
|
}
|
|
|
|
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
|
|
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
|
|
}
|
|
|
|
type fakeValidator struct {
|
|
fakeOID
|
|
err error
|
|
}
|
|
|
|
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
|
var doc fakeDoc
|
|
if err := json.Unmarshal(attDoc, &doc); err != nil {
|
|
return nil, err
|
|
}
|
|
if !bytes.Equal(doc.Nonce, nonce) {
|
|
return nil, errors.New("invalid nonce")
|
|
}
|
|
return doc.UserData, v.err
|
|
}
|
|
|
|
// fakeOID is an asn1.ObjectIdentifier that satisfies the OID() getter used by
// the fake issuer and validator.
type fakeOID asn1.ObjectIdentifier

// OID returns id converted back to asn1.ObjectIdentifier. The result shares id's
// backing array.
func (id fakeOID) OID() asn1.ObjectIdentifier {
	converted := asn1.ObjectIdentifier(id)
	return converted
}
|
|
|
|
// fakeDoc is the document that fakeIssuer marshals to JSON and fakeValidator
// unmarshals. Fields are encoded in declaration order: UserData, then Nonce.
type fakeDoc struct {
	UserData, Nonce []byte
}
|