mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-02 11:26:25 -05:00
Fix potential data race when accessing a validators OID (#640)
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
594b43e629
commit
c2ea937fb5
@ -93,6 +93,8 @@ func (u *Updatable) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||||
|
||||
// OID returns the validators Object Identifier.
|
||||
func (u *Updatable) OID() asn1.ObjectIdentifier {
|
||||
u.mux.Lock()
|
||||
defer u.mux.Unlock()
|
||||
return u.Validator.OID()
|
||||
}
|
||||
|
||||
|
@ -117,8 +117,17 @@ func TestUpdate(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
// we need safe access for overwriting the fake validator OID
|
||||
oid := fakeOID{1, 3, 9900, 1}
|
||||
var oidLock sync.Mutex
|
||||
updatedOID := func(newOID fakeOID) {
|
||||
oidLock.Lock()
|
||||
defer oidLock.Unlock()
|
||||
oid = newOID
|
||||
}
|
||||
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
||||
oidLock.Lock()
|
||||
defer oidLock.Unlock()
|
||||
return fakeValidator{fakeOID: oid}
|
||||
}
|
||||
handler := file.NewHandler(afero.NewMemMapFs())
|
||||
@ -174,7 +183,7 @@ func TestUpdate(t *testing.T) {
|
||||
assert.EqualValues("hello", body)
|
||||
|
||||
// update the server's validator
|
||||
oid = fakeOID{1, 3, 9900, 2}
|
||||
updatedOID(fakeOID{1, 3, 9900, 2})
|
||||
require.NoError(validator.Update())
|
||||
|
||||
// client connection should fail now, since the server's validator expects a different OID from the client
|
||||
@ -202,6 +211,48 @@ func TestUpdate(t *testing.T) {
|
||||
assert.NoError(validator.Update())
|
||||
}
|
||||
|
||||
func TestOIDConcurrency(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
handler := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(handler.WriteJSON(
|
||||
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
|
||||
measurements.M{11: measurements.WithAllBytes(0x00, false)},
|
||||
))
|
||||
require.NoError(handler.Write(
|
||||
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
|
||||
[]byte{},
|
||||
))
|
||||
|
||||
newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
|
||||
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
|
||||
}
|
||||
// create server
|
||||
validator := &Updatable{
|
||||
log: logger.NewTest(t),
|
||||
newValidator: newValidator,
|
||||
fileHandler: handler,
|
||||
}
|
||||
|
||||
// call update once to initialize the server's validator
|
||||
require.NoError(validator.Update())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2 * 20)
|
||||
for i := 0; i < 20; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
assert.NoError(validator.Update())
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
validator.OID()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestUpdateConcurrency(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
Loading…
Reference in New Issue
Block a user