From 86d29a456782a2a6ca9cea239e365cbf5778714f Mon Sep 17 00:00:00 2001 From: Paul Meyer <49727155+katexochen@users.noreply.github.com> Date: Wed, 15 Jun 2022 10:35:15 +0200 Subject: [PATCH] Add concurrency tests for atls connections (#211) --- internal/atls/atls_test.go | 152 +++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/internal/atls/atls_test.go b/internal/atls/atls_test.go index 22a79f458..12d0521c1 100644 --- a/internal/atls/atls_test.go +++ b/internal/atls/atls_test.go @@ -13,8 +13,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestTLSConfig(t *testing.T) { oid1 := fakeOID{1, 3, 9900, 1} oid2 := fakeOID{1, 3, 9900, 2} @@ -170,6 +175,153 @@ func TestTLSConfig(t *testing.T) { } } +func TestClientConnectionConcurrency(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + // + // Create servers. + // + + const serverCount = 15 + + var urls []string + oid1 := fakeOID{1, 3, 9900, 1} + + for i := 0; i < serverCount; i++ { + serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}}) + require.NoError(err) + + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "hello") + })) + server.TLS = serverCfg + + server.StartTLS() + defer server.Close() + + urls = append(urls, server.URL) + } + + // + // Create client. + // + + clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}}) + require.NoError(err) + client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} + + // + // Prepare a request for each server. + // + + var reqs []*http.Request + for _, url := range urls { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody) + require.NoError(err) + reqs = append(reqs, req) + } + + // + // Do the request concurrently and collect the errors in a channel. + // The config of the client is reused, so the nonce isn't fresh. + // This explicitly checks for data races on the clientConnection. + // + + errChan := make(chan error, serverCount) + + for _, req := range reqs { + go func(req *http.Request) { + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + } + errChan <- err + }(req) + } + + // + // Wait for the requests to finish and check the errors. + // + + for i := 0; i < serverCount; i++ { + assert.NoError(<-errChan) + } +} + +func TestServerConnectionConcurrency(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + // + // Create servers. + // The serverCfg is reused. + // + + const serverCount = 10 + + var urls []string + oid1 := fakeOID{1, 3, 9900, 1} + + serverCfg, err := CreateAttestationServerTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}}) + require.NoError(err) + + for i := 0; i < serverCount; i++ { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "hello") + })) + server.TLS = serverCfg + + server.StartTLS() + defer server.Close() + + urls = append(urls, server.URL) + } + + // + // Create client. + // + + clientConfig, err := CreateAttestationClientTLSConfig(fakeIssuer{fakeOID: oid1}, []Validator{fakeValidator{fakeOID: oid1}}) + require.NoError(err) + client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} + + // + // Prepare a request for each server. + // + + var reqs []*http.Request + for _, url := range urls { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody) + require.NoError(err) + reqs = append(reqs, req) + } + + // + // Do the request concurrently and collect the errors in a channel. + // + + errChan := make(chan error, serverCount) + + for _, req := range reqs { + go func(req *http.Request) { + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + } + errChan <- err + }(req) + } + + // + // Wait for the requests to finish and check the errors. + // + + for i := 0; i < serverCount; i++ { + assert.NoError(<-errChan) + } +} + type fakeIssuer struct { fakeOID }