Fix potential data race when accessing a validators OID (#640)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-11-24 15:50:59 +01:00 committed by GitHub
parent 594b43e629
commit c2ea937fb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 1 deletions

View File

@ -93,6 +93,8 @@ func (u *Updatable) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
// OID returns the validators Object Identifier.
func (u *Updatable) OID() asn1.ObjectIdentifier {
u.mux.Lock()
defer u.mux.Unlock()
return u.Validator.OID()
}

View File

@ -117,8 +117,17 @@ func TestUpdate(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
// we need safe access for overwriting the fake validator OID
oid := fakeOID{1, 3, 9900, 1}
var oidLock sync.Mutex
updatedOID := func(newOID fakeOID) {
oidLock.Lock()
defer oidLock.Unlock()
oid = newOID
}
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
oidLock.Lock()
defer oidLock.Unlock()
return fakeValidator{fakeOID: oid}
}
handler := file.NewHandler(afero.NewMemMapFs())
@ -174,7 +183,7 @@ func TestUpdate(t *testing.T) {
assert.EqualValues("hello", body)
// update the server's validator
oid = fakeOID{1, 3, 9900, 2}
updatedOID(fakeOID{1, 3, 9900, 2})
require.NoError(validator.Update())
// client connection should fail now, since the server's validator expects a different OID from the client
@ -202,6 +211,48 @@ func TestUpdate(t *testing.T) {
assert.NoError(validator.Update())
}
func TestOIDConcurrency(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := file.NewHandler(afero.NewMemMapFs())
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
measurements.M{11: measurements.WithAllBytes(0x00, false)},
))
require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
[]byte{},
))
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
}
// create server
validator := &Updatable{
log: logger.NewTest(t),
newValidator: newValidator,
fileHandler: handler,
}
// call update once to initialize the server's validator
require.NoError(validator.Update())
var wg sync.WaitGroup
wg.Add(2 * 20)
for i := 0; i < 20; i++ {
go func() {
defer wg.Done()
assert.NoError(validator.Update())
}()
go func() {
defer wg.Done()
validator.OID()
}()
}
wg.Wait()
}
func TestUpdateConcurrency(t *testing.T) {
assert := assert.New(t)
require := require.New(t)