mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-08-01 11:36:10 -04:00
Fix potential data race when accessing a validators OID (#640)
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
594b43e629
commit
c2ea937fb5
2 changed files with 54 additions and 1 deletions
|
@ -93,6 +93,8 @@ func (u *Updatable) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||||||
|
|
||||||
// OID returns the validators Object Identifier.
|
// OID returns the validators Object Identifier.
|
||||||
func (u *Updatable) OID() asn1.ObjectIdentifier {
|
func (u *Updatable) OID() asn1.ObjectIdentifier {
|
||||||
|
u.mux.Lock()
|
||||||
|
defer u.mux.Unlock()
|
||||||
return u.Validator.OID()
|
return u.Validator.OID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -117,8 +117,17 @@ func TestUpdate(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
|
// we need safe access for overwriting the fake validator OID
|
||||||
oid := fakeOID{1, 3, 9900, 1}
|
oid := fakeOID{1, 3, 9900, 1}
|
||||||
|
var oidLock sync.Mutex
|
||||||
|
updatedOID := func(newOID fakeOID) {
|
||||||
|
oidLock.Lock()
|
||||||
|
defer oidLock.Unlock()
|
||||||
|
oid = newOID
|
||||||
|
}
|
||||||
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
||||||
|
oidLock.Lock()
|
||||||
|
defer oidLock.Unlock()
|
||||||
return fakeValidator{fakeOID: oid}
|
return fakeValidator{fakeOID: oid}
|
||||||
}
|
}
|
||||||
handler := file.NewHandler(afero.NewMemMapFs())
|
handler := file.NewHandler(afero.NewMemMapFs())
|
||||||
|
@ -174,7 +183,7 @@ func TestUpdate(t *testing.T) {
|
||||||
assert.EqualValues("hello", body)
|
assert.EqualValues("hello", body)
|
||||||
|
|
||||||
// update the server's validator
|
// update the server's validator
|
||||||
oid = fakeOID{1, 3, 9900, 2}
|
updatedOID(fakeOID{1, 3, 9900, 2})
|
||||||
require.NoError(validator.Update())
|
require.NoError(validator.Update())
|
||||||
|
|
||||||
// client connection should fail now, since the server's validator expects a different OID from the client
|
// client connection should fail now, since the server's validator expects a different OID from the client
|
||||||
|
@ -202,6 +211,48 @@ func TestUpdate(t *testing.T) {
|
||||||
assert.NoError(validator.Update())
|
assert.NoError(validator.Update())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDConcurrency(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
handler := file.NewHandler(afero.NewMemMapFs())
|
||||||
|
require.NoError(handler.WriteJSON(
|
||||||
|
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
||||||
|
measurements.M{11: measurements.WithAllBytes(0x00, false)},
|
||||||
|
))
|
||||||
|
require.NoError(handler.Write(
|
||||||
|
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
|
||||||
|
[]byte{},
|
||||||
|
))
|
||||||
|
|
||||||
|
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
||||||
|
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
|
||||||
|
}
|
||||||
|
// create server
|
||||||
|
validator := &Updatable{
|
||||||
|
log: logger.NewTest(t),
|
||||||
|
newValidator: newValidator,
|
||||||
|
fileHandler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
// call update once to initialize the server's validator
|
||||||
|
require.NoError(validator.Update())
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2 * 20)
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
assert.NoError(validator.Update())
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
validator.OID()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateConcurrency(t *testing.T) {
|
func TestUpdateConcurrency(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue