/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package watcher import ( "bytes" "context" "encoding/asn1" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "path/filepath" "sync" "testing" "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/logger" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m, // https://github.com/census-instrumentation/opencensus-go/issues/1262 goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), ) } func TestNewUpdateableValidator(t *testing.T) { testCases := map[string]struct { provider string writeFile bool wantErr bool }{ "azure": { provider: "azure", writeFile: true, }, "gcp": { provider: "gcp", writeFile: true, }, "qemu": { provider: "qemu", writeFile: true, }, "no file": { provider: "azure", writeFile: false, wantErr: true, }, "invalid provider": { provider: "invalid", writeFile: true, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) handler := file.NewHandler(afero.NewMemMapFs()) if tc.writeFile { require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), map[uint32][]byte{ 11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, }, )) require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), []uint32{11}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), []byte{}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename), []byte("false"), )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.AzureCVM), []byte("true"), )) } _, err := NewValidator( logger.NewTest(t), tc.provider, handler, false, ) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) } }) } } func TestUpdate(t *testing.T) { assert := assert.New(t) require := require.New(t) // we need safe access for overwriting the fake validator OID oid := fakeOID{1, 3, 9900, 1} var oidLock sync.Mutex updatedOID := func(newOID fakeOID) { oidLock.Lock() defer oidLock.Unlock() oid = newOID } newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { oidLock.Lock() defer oidLock.Unlock() return fakeValidator{fakeOID: oid} } handler := file.NewHandler(afero.NewMemMapFs()) // create server validator := &Updatable{ log: logger.NewTest(t), newValidator: newValidator, fileHandler: handler, } // Update should fail if the file does not exist assert.Error(validator.Update()) // write measurement config require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), measurements.M{11: measurements.WithAllBytes(0x00, false)}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), []byte{}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename), []byte("false"), )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.AzureCVM), []byte("true"), )) // call update once to initialize the server's validator require.NoError(validator.Update()) // create tls config and start the server serverConfig, err := atls.CreateAttestationServerTLSConfig(nil, []atls.Validator{validator}) require.NoError(err) server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "hello") })) server.TLS = serverConfig server.StartTLS() defer server.Close() // test connection to server clientOID := fakeOID{1, 3, 9900, 1} resp, err := testConnection(require, server.URL, clientOID) require.NoError(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(err) assert.EqualValues("hello", body) // update the server's validator updatedOID(fakeOID{1, 3, 9900, 2}) require.NoError(validator.Update()) // client connection should fail now, since the server's validator expects a different OID from the client resp, err = testConnection(require, server.URL, clientOID) if err == nil { defer resp.Body.Close() } assert.Error(err) // update should work for legacy measurement format // TODO: remove with v2.4.0 require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), map[uint32][]byte{ 11: bytes.Repeat([]byte{0x0}, 32), 12: bytes.Repeat([]byte{0x1}, 32), }, file.OptOverwrite, )) require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), []uint32{11}, )) assert.NoError(validator.Update()) } func TestOIDConcurrency(t *testing.T) { assert := assert.New(t) require := require.New(t) handler := file.NewHandler(afero.NewMemMapFs()) require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), measurements.M{11: measurements.WithAllBytes(0x00, false)}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), []byte{}, )) newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}} } // create server validator := &Updatable{ log: logger.NewTest(t), newValidator: newValidator, fileHandler: handler, } // call update once to initialize the server's validator require.NoError(validator.Update()) var wg sync.WaitGroup wg.Add(2 * 20) for i := 0; i < 20; i++ { go func() { defer wg.Done() assert.NoError(validator.Update()) }() go func() { defer wg.Done() validator.OID() }() } wg.Wait() } func TestUpdateConcurrency(t *testing.T) { assert := assert.New(t) require := require.New(t) handler := file.NewHandler(afero.NewMemMapFs()) validator := &Updatable{ log: logger.NewTest(t), fileHandler: handler, newValidator: func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}} }, } require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), map[uint32][]byte{ 11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, }, file.OptNone, )) require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), []uint32{11}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), []byte{}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename), []byte("false"), )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.AzureCVM), []byte("true"), )) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() assert.NoError(validator.Update()) }() } wg.Wait() } func testConnection(require *require.Assertions, url string, oid fakeOID) (*http.Response, error) { clientConfig, err := atls.CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid}, nil) require.NoError(err) client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody) require.NoError(err) return client.Do(req) } type fakeIssuer struct { fakeOID } func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) } type fakeValidator struct { fakeOID err error } func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { var doc fakeDoc if err := json.Unmarshal(attDoc, &doc); err != nil { return nil, err } if !bytes.Equal(doc.Nonce, nonce) { return nil, errors.New("invalid nonce") } return doc.UserData, v.err } type fakeOID asn1.ObjectIdentifier func (o fakeOID) OID() asn1.ObjectIdentifier { return asn1.ObjectIdentifier(o) } type fakeDoc struct { UserData []byte Nonce []byte }