mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-12-25 07:29:38 -05:00
Add concurrency tests for atls connections (#211)
This commit is contained in:
parent
e9916a7d3a
commit
86d29a4567
@ -13,8 +13,13 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestTLSConfig(t *testing.T) {
|
||||
oid1 := fakeOID{1, 3, 9900, 1}
|
||||
oid2 := fakeOID{1, 3, 9900, 2}
|
||||
@ -170,6 +175,153 @@ func TestTLSConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnectionConcurrency(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
//
|
||||
// Create servers.
|
||||
//
|
||||
|
||||
const serverCount = 15
|
||||
|
||||
var urls []string
|
||||
oid1 := fakeOID{1, 3, 9900, 1}
|
||||
|
||||
for i := 0; i < serverCount; i++ {
|
||||
serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||
require.NoError(err)
|
||||
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "hello")
|
||||
}))
|
||||
server.TLS = serverCfg
|
||||
|
||||
server.StartTLS()
|
||||
defer server.Close()
|
||||
|
||||
urls = append(urls, server.URL)
|
||||
}
|
||||
|
||||
//
|
||||
// Create client.
|
||||
//
|
||||
|
||||
clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||
require.NoError(err)
|
||||
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
||||
|
||||
//
|
||||
// Prepare a request for each server.
|
||||
//
|
||||
|
||||
var reqs []*http.Request
|
||||
for _, url := range urls {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
||||
require.NoError(err)
|
||||
reqs = append(reqs, req)
|
||||
}
|
||||
|
||||
//
|
||||
// Do the request concurrently and collect the errors in a channel.
|
||||
// The config of the client is reused, so the nonce isn't fresh.
|
||||
// This explicitly checks for data races on the clientConnection.
|
||||
//
|
||||
|
||||
errChan := make(chan error, serverCount)
|
||||
|
||||
for _, req := range reqs {
|
||||
go func(req *http.Request) {
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
errChan <- err
|
||||
}(req)
|
||||
}
|
||||
|
||||
//
|
||||
// Wait for the requests to finish and check the errors.
|
||||
//
|
||||
|
||||
for i := 0; i < serverCount; i++ {
|
||||
assert.NoError(<-errChan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerConnectionConcurrency(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
//
|
||||
// Create servers.
|
||||
// The serverCfg is reused.
|
||||
//
|
||||
|
||||
const serverCount = 10
|
||||
|
||||
var urls []string
|
||||
oid1 := fakeOID{1, 3, 9900, 1}
|
||||
|
||||
serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||
require.NoError(err)
|
||||
|
||||
for i := 0; i < serverCount; i++ {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "hello")
|
||||
}))
|
||||
server.TLS = serverCfg
|
||||
|
||||
server.StartTLS()
|
||||
defer server.Close()
|
||||
|
||||
urls = append(urls, server.URL)
|
||||
}
|
||||
|
||||
//
|
||||
// Create client.
|
||||
//
|
||||
|
||||
clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||
require.NoError(err)
|
||||
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
||||
|
||||
//
|
||||
// Prepare a request for each server.
|
||||
//
|
||||
|
||||
var reqs []*http.Request
|
||||
for _, url := range urls {
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
||||
require.NoError(err)
|
||||
reqs = append(reqs, req)
|
||||
}
|
||||
|
||||
//
|
||||
// Do the request concurrently and collect the errors in a channel.
|
||||
//
|
||||
|
||||
errChan := make(chan error, serverCount)
|
||||
|
||||
for _, req := range reqs {
|
||||
go func(req *http.Request) {
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
errChan <- err
|
||||
}(req)
|
||||
}
|
||||
|
||||
//
|
||||
// Wait for the requests to finish and check the errors.
|
||||
//
|
||||
|
||||
for i := 0; i < serverCount; i++ {
|
||||
assert.NoError(<-errChan)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeIssuer struct {
|
||||
fakeOID
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user