// constellation/internal/watcher/validator_test.go
//
// 222 lines
// 5.4 KiB
// Go
// Raw Normal View History

package watcher
import (
"bytes"
"context"
"encoding/asn1"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"sync"
"testing"
"github.com/edgelesssys/constellation/internal/atls"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewUpdateableValidator checks that NewValidator succeeds for the
// supported providers when a measurements file is present, and fails when the
// file is missing or the provider name is unknown.
func TestNewUpdateableValidator(t *testing.T) {
	measurementsPath := filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename)

	testCases := map[string]struct {
		provider  string
		writeFile bool
		wantErr   bool
	}{
		"azure": {
			provider:  "azure",
			writeFile: true,
		},
		"gcp": {
			provider:  "gcp",
			writeFile: true,
		},
		"qemu": {
			provider:  "qemu",
			writeFile: true,
		},
		"no file": {
			provider:  "azure",
			writeFile: false,
			wantErr:   true,
		},
		"invalid provider": {
			provider:  "invalid",
			writeFile: true,
			wantErr:   true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			fs := afero.NewMemMapFs()
			handler := file.NewHandler(fs)

			if tc.writeFile {
				measurements := map[uint32][]byte{
					11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
				}
				require.NoError(handler.WriteJSON(measurementsPath, measurements, file.OptNone))
			}

			_, err := NewValidator(logger.NewTest(t), tc.provider, handler)
			if !tc.wantErr {
				assert.NoError(err)
				return
			}
			assert.Error(err)
		})
	}
}
// TestUpdate verifies that Updatable.Update reads the measurements file and
// swaps the validator used by an already-running aTLS server, so that a client
// which was accepted before the update is rejected afterwards.
func TestUpdate(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	// newValidator closes over the variable oid (not a copy of its value), so
	// reassigning oid later in the test changes the OID of the validator built
	// by the next Update call.
	oid := fakeOID{1, 3, 9900, 1}
	newValidator := func(_ map[uint32][]byte) atls.Validator {
		return fakeValidator{fakeOID: oid}
	}
	handler := file.NewHandler(afero.NewMemMapFs())

	// create server
	validator := &Updatable{
		log:          logger.NewTest(t),
		newValidator: newValidator,
		fileHandler:  handler,
	}

	// Update should fail if the file does not exist
	assert.Error(validator.Update())

	// write measurement config
	require.NoError(handler.WriteJSON(
		filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
		map[uint32][]byte{
			11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
		},
		file.OptNone,
	))

	// call update once to initialize the server's validator
	require.NoError(validator.Update())

	// create tls config and start the server
	serverConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator})
	require.NoError(err)
	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, "hello")
	}))
	server.TLS = serverConfig
	server.StartTLS()
	defer server.Close()

	// test connection to server
	clientOID := fakeOID{1, 3, 9900, 1}
	resp, err := testConnection(require, server.URL, clientOID)
	require.NoError(err)
	defer resp.Body.Close()
	assert.Equal(http.StatusOK, resp.StatusCode)
	body, err := io.ReadAll(resp.Body)
	require.NoError(err)
	assert.EqualValues("hello", body)

	// update the server's validator
	oid = fakeOID{1, 3, 9900, 2}
	require.NoError(validator.Update())

	// client connection should fail now, since the server's validator expects a different OID from the client
	unexpectedResp, err := testConnection(require, server.URL, clientOID)
	if err == nil {
		// The request unexpectedly succeeded: close the body so the connection
		// is not leaked. assert.Error below still reports the failure.
		unexpectedResp.Body.Close()
	}
	assert.Error(err)
}
// TestUpdateConcurrency calls Update from several goroutines at once and
// expects every call to succeed. Run with -race to also surface data races in
// Updatable.
func TestUpdateConcurrency(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	fileHandler := file.NewHandler(afero.NewMemMapFs())
	updatable := &Updatable{
		log:         logger.NewTest(t),
		fileHandler: fileHandler,
		// Every rebuild returns the same fake validator; only the success of
		// the concurrent Update calls matters here.
		newValidator: func(m map[uint32][]byte) atls.Validator {
			return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
		},
	}

	measurementsPath := filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename)
	measurements := map[uint32][]byte{
		11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
	}
	require.NoError(fileHandler.WriteJSON(measurementsPath, measurements, file.OptNone))

	const workers = 10
	var done sync.WaitGroup
	done.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer done.Done()
			assert.NoError(updatable.Update())
		}()
	}
	done.Wait()
}
// testConnection sends a GET request to url over aTLS. The client side uses a
// fakeIssuer carrying oid and passes nil as the second argument of
// CreateAttestationClientTLSConfig (presumably: no client-side validators).
// A fresh http.Transport is created on every call, so no pooled connection is
// reused and each call performs its own TLS handshake. Setup failures abort the
// test via require. The request result is returned unchanged, and the caller
// must close the response body.
func testConnection(require *require.Assertions, url string, oid fakeOID) (*http.Response, error) {
	tlsConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid}, nil)
	require.NoError(err)

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
	require.NoError(err)

	transport := &http.Transport{TLSClientConfig: tlsConfig}
	return (&http.Client{Transport: transport}).Do(req)
}
// fakeIssuer is the client-side issuer used by testConnection. Its attestation
// document is simply a JSON-encoded fakeDoc, which fakeValidator decodes.
type fakeIssuer struct {
	fakeOID
}

// Issue returns the JSON encoding of a fakeDoc that carries userData and nonce.
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
	doc := fakeDoc{
		Nonce:    nonce,
		UserData: userData,
	}
	return json.Marshal(doc)
}
// fakeValidator is an atls.Validator for tests. It decodes a JSON fakeDoc and
// checks only that the document's nonce matches the expected one.
type fakeValidator struct {
	fakeOID
	// err is returned, together with the decoded user data, once the document
	// has been parsed and its nonce has matched.
	err error
}

// Validate decodes attDoc as a fakeDoc. It returns an error if decoding fails
// or the nonce differs from nonce. Otherwise it returns the document's user
// data and v.err.
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
	var parsed fakeDoc
	decodeErr := json.Unmarshal(attDoc, &parsed)

	switch {
	case decodeErr != nil:
		return nil, decodeErr
	case !bytes.Equal(parsed.Nonce, nonce):
		return nil, errors.New("invalid nonce")
	default:
		return parsed.UserData, v.err
	}
}
// fakeOID is an ASN.1 object identifier that fakeIssuer and fakeValidator
// embed to provide their OID method.
type fakeOID asn1.ObjectIdentifier

// OID returns the identifier converted back to asn1.ObjectIdentifier.
func (id fakeOID) OID() asn1.ObjectIdentifier {
	converted := asn1.ObjectIdentifier(id)
	return converted
}
// fakeDoc is the fake attestation document exchanged between fakeIssuer
// (which JSON-encodes it) and fakeValidator (which decodes it). The field
// names double as the JSON keys.
type fakeDoc struct {
	UserData, Nonce []byte
}