Add concurrency tests for atls connections (#211)

This commit is contained in:
Paul Meyer 2022-06-15 10:35:15 +02:00 committed by Thomas Tendyck
parent e9916a7d3a
commit 86d29a4567

View File

@ -13,8 +13,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak"
) )
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestTLSConfig(t *testing.T) { func TestTLSConfig(t *testing.T) {
oid1 := fakeOID{1, 3, 9900, 1} oid1 := fakeOID{1, 3, 9900, 1}
oid2 := fakeOID{1, 3, 9900, 2} oid2 := fakeOID{1, 3, 9900, 2}
@ -170,6 +175,153 @@ func TestTLSConfig(t *testing.T) {
} }
} }
func TestClientConnectionConcurrency(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
//
// Create servers.
//
const serverCount = 15
var urls []string
oid1 := fakeOID{1, 3, 9900, 1}
for i := 0; i < serverCount; i++ {
serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
require.NoError(err)
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "hello")
}))
server.TLS = serverCfg
server.StartTLS()
defer server.Close()
urls = append(urls, server.URL)
}
//
// Create client.
//
clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
require.NoError(err)
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
//
// Prepare a request for each server.
//
var reqs []*http.Request
for _, url := range urls {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
require.NoError(err)
reqs = append(reqs, req)
}
//
// Do the request concurrently and collect the errors in a channel.
// The config of the client is reused, so the nonce isn't fresh.
// This explicitly checks for data races on the clientConnection.
//
errChan := make(chan error, serverCount)
for _, req := range reqs {
go func(req *http.Request) {
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
}
errChan <- err
}(req)
}
//
// Wait for the requests to finish and check the errors.
//
for i := 0; i < serverCount; i++ {
assert.NoError(<-errChan)
}
}
func TestServerConnectionConcurrency(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
//
// Create servers.
// The serverCfg is reused.
//
const serverCount = 10
var urls []string
oid1 := fakeOID{1, 3, 9900, 1}
serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
require.NoError(err)
for i := 0; i < serverCount; i++ {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "hello")
}))
server.TLS = serverCfg
server.StartTLS()
defer server.Close()
urls = append(urls, server.URL)
}
//
// Create client.
//
clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
require.NoError(err)
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
//
// Prepare a request for each server.
//
var reqs []*http.Request
for _, url := range urls {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
require.NoError(err)
reqs = append(reqs, req)
}
//
// Do the request concurrently and collect the errors in a channel.
//
errChan := make(chan error, serverCount)
for _, req := range reqs {
go func(req *http.Request) {
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
}
errChan <- err
}(req)
}
//
// Wait for the requests to finish and check the errors.
//
for i := 0; i < serverCount; i++ {
assert.NoError(<-errChan)
}
}
type fakeIssuer struct { type fakeIssuer struct {
fakeOID fakeOID
} }