mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-12-26 07:59:37 -05:00
Add concurrency tests for atls connections (#211)
This commit is contained in:
parent
e9916a7d3a
commit
86d29a4567
@ -13,8 +13,13 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/goleak"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
goleak.VerifyTestMain(m)
|
||||||
|
}
|
||||||
|
|
||||||
func TestTLSConfig(t *testing.T) {
|
func TestTLSConfig(t *testing.T) {
|
||||||
oid1 := fakeOID{1, 3, 9900, 1}
|
oid1 := fakeOID{1, 3, 9900, 1}
|
||||||
oid2 := fakeOID{1, 3, 9900, 2}
|
oid2 := fakeOID{1, 3, 9900, 2}
|
||||||
@ -170,6 +175,153 @@ func TestTLSConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientConnectionConcurrency(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Create servers.
|
||||||
|
//
|
||||||
|
|
||||||
|
const serverCount = 15
|
||||||
|
|
||||||
|
var urls []string
|
||||||
|
oid1 := fakeOID{1, 3, 9900, 1}
|
||||||
|
|
||||||
|
for i := 0; i < serverCount; i++ {
|
||||||
|
serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||||
|
require.NoError(err)
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "hello")
|
||||||
|
}))
|
||||||
|
server.TLS = serverCfg
|
||||||
|
|
||||||
|
server.StartTLS()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
urls = append(urls, server.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Create client.
|
||||||
|
//
|
||||||
|
|
||||||
|
clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||||
|
require.NoError(err)
|
||||||
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Prepare a request for each server.
|
||||||
|
//
|
||||||
|
|
||||||
|
var reqs []*http.Request
|
||||||
|
for _, url := range urls {
|
||||||
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
||||||
|
require.NoError(err)
|
||||||
|
reqs = append(reqs, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Do the request concurrently and collect the errors in a channel.
|
||||||
|
// The config of the client is reused, so the nonce isn't fresh.
|
||||||
|
// This explicitly checks for data races on the clientConnection.
|
||||||
|
//
|
||||||
|
|
||||||
|
errChan := make(chan error, serverCount)
|
||||||
|
|
||||||
|
for _, req := range reqs {
|
||||||
|
go func(req *http.Request) {
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
errChan <- err
|
||||||
|
}(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Wait for the requests to finish and check the errors.
|
||||||
|
//
|
||||||
|
|
||||||
|
for i := 0; i < serverCount; i++ {
|
||||||
|
assert.NoError(<-errChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerConnectionConcurrency(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Create servers.
|
||||||
|
// The serverCfg is reused.
|
||||||
|
//
|
||||||
|
|
||||||
|
const serverCount = 10
|
||||||
|
|
||||||
|
var urls []string
|
||||||
|
oid1 := fakeOID{1, 3, 9900, 1}
|
||||||
|
|
||||||
|
serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||||
|
require.NoError(err)
|
||||||
|
|
||||||
|
for i := 0; i < serverCount; i++ {
|
||||||
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "hello")
|
||||||
|
}))
|
||||||
|
server.TLS = serverCfg
|
||||||
|
|
||||||
|
server.StartTLS()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
urls = append(urls, server.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Create client.
|
||||||
|
//
|
||||||
|
|
||||||
|
clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}})
|
||||||
|
require.NoError(err)
|
||||||
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Prepare a request for each server.
|
||||||
|
//
|
||||||
|
|
||||||
|
var reqs []*http.Request
|
||||||
|
for _, url := range urls {
|
||||||
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
||||||
|
require.NoError(err)
|
||||||
|
reqs = append(reqs, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Do the request concurrently and collect the errors in a channel.
|
||||||
|
//
|
||||||
|
|
||||||
|
errChan := make(chan error, serverCount)
|
||||||
|
|
||||||
|
for _, req := range reqs {
|
||||||
|
go func(req *http.Request) {
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
errChan <- err
|
||||||
|
}(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Wait for the requests to finish and check the errors.
|
||||||
|
//
|
||||||
|
|
||||||
|
for i := 0; i < serverCount; i++ {
|
||||||
|
assert.NoError(<-errChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type fakeIssuer struct {
|
type fakeIssuer struct {
|
||||||
fakeOID
|
fakeOID
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user