/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package atls import ( "context" "encoding/asn1" "errors" "io" "net/http" "net/http/httptest" "testing" "github.com/edgelesssys/constellation/v2/internal/oid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } func TestTLSConfig(t *testing.T) { oid1 := fakeOID{asn1.ObjectIdentifier{1, 3, 9900, 1}} oid2 := fakeOID{asn1.ObjectIdentifier{1, 3, 9900, 2}} testCases := map[string]struct { clientIssuer Issuer clientValidators []Validator serverIssuer Issuer serverValidators []Validator wantErr bool }{ "client->server basic": { serverIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid1)}, }, "client->server multiple validators": { serverIssuer: NewFakeIssuer(oid2), clientValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)}, }, "client->server validate error": { serverIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{FakeValidator{oid1, errors.New("failed")}}, wantErr: true, }, "client->server unknown oid": { serverIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid2)}, wantErr: true, }, "client->server client cert is not verified": { serverIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid1)}, }, "server->client basic": { serverValidators: []Validator{NewFakeValidator(oid1)}, clientIssuer: NewFakeIssuer(oid1), }, "server->client multiple validators": { serverValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)}, clientIssuer: NewFakeIssuer(oid2), }, "server->client validate error": { serverValidators: []Validator{FakeValidator{oid1, errors.New("failed")}}, clientIssuer: NewFakeIssuer(oid1), wantErr: true, }, "server->client unknown oid": { serverValidators: []Validator{NewFakeValidator(oid2)}, clientIssuer: NewFakeIssuer(oid1), wantErr: true, }, "mutual basic": { serverIssuer: NewFakeIssuer(oid1), serverValidators: []Validator{NewFakeValidator(oid1)}, clientIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid1)}, }, "mutual multiple validators": { serverIssuer: NewFakeIssuer(oid2), serverValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)}, clientIssuer: NewFakeIssuer(oid2), clientValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)}, }, "mutual fails if client sends no attestation": { serverIssuer: NewFakeIssuer(oid1), serverValidators: []Validator{NewFakeValidator(oid1)}, clientValidators: []Validator{NewFakeValidator(oid1)}, wantErr: true, }, "mutual fails if server sends no attestation": { serverValidators: []Validator{NewFakeValidator(oid1)}, clientIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid1)}, wantErr: true, }, "mutual validate error client side": { serverIssuer: NewFakeIssuer(oid1), serverValidators: []Validator{NewFakeValidator(oid1)}, clientIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{FakeValidator{oid1, errors.New("failed")}}, wantErr: true, }, "mutual validate error server side": { serverIssuer: NewFakeIssuer(oid1), serverValidators: []Validator{FakeValidator{oid1, errors.New("failed")}}, clientIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid1)}, wantErr: true, }, "mutual unknown oid from client": { serverIssuer: NewFakeIssuer(oid1), serverValidators: []Validator{NewFakeValidator(oid1)}, clientIssuer: NewFakeIssuer(oid2), clientValidators: []Validator{NewFakeValidator(oid1)}, wantErr: true, }, "mutual unknown oid from server": { serverIssuer: NewFakeIssuer(oid2), serverValidators: []Validator{NewFakeValidator(oid1)}, clientIssuer: NewFakeIssuer(oid1), clientValidators: []Validator{NewFakeValidator(oid1)}, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) // // Create server // serverConfig, err := CreateAttestationServerTLSConfig(tc.serverIssuer, tc.serverValidators) require.NoError(err) server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "hello") })) server.TLS = serverConfig // // Create client // clientConfig, err := CreateAttestationClientTLSConfig(tc.clientIssuer, tc.clientValidators) require.NoError(err) client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} // // Test connection // server.StartTLS() defer server.Close() req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, http.NoBody) require.NoError(err) resp, err := client.Do(req) if tc.wantErr { assert.Error(err) return } require.NoError(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(err) assert.EqualValues("hello", body) }) } } func TestClientConnectionConcurrency(t *testing.T) { require := require.New(t) assert := assert.New(t) // // Create servers. // const serverCount = 15 var urls []string for i := 0; i < serverCount; i++ { serverCfg, err := CreateAttestationServerTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{})) require.NoError(err) server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "hello") })) server.TLS = serverCfg server.StartTLS() defer server.Close() urls = append(urls, server.URL) } // // Create client. // clientConfig, err := CreateAttestationClientTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{})) require.NoError(err) client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} // // Prepare a request for each server. // var reqs []*http.Request for _, url := range urls { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody) require.NoError(err) reqs = append(reqs, req) } // // Do the request concurrently and collect the errors in a channel. // The config of the client is reused, so the nonce isn't fresh. // This explicitly checks for data races on the clientConnection. // errChan := make(chan error, serverCount) for _, req := range reqs { go func(req *http.Request) { resp, err := client.Do(req) if err == nil { resp.Body.Close() } errChan <- err }(req) } // // Wait for the requests to finish and check the errors. // for i := 0; i < serverCount; i++ { assert.NoError(<-errChan) } } func TestServerConnectionConcurrency(t *testing.T) { require := require.New(t) assert := assert.New(t) // // Create servers. // The serverCfg is reused. // const serverCount = 10 var urls []string serverCfg, err := CreateAttestationServerTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{})) require.NoError(err) for i := 0; i < serverCount; i++ { server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "hello") })) server.TLS = serverCfg server.StartTLS() defer server.Close() urls = append(urls, server.URL) } // // Create client. // clientConfig, err := CreateAttestationClientTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{})) require.NoError(err) client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} // // Prepare a request for each server. // var reqs []*http.Request for _, url := range urls { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody) require.NoError(err) reqs = append(reqs, req) } // // Do the request concurrently and collect the errors in a channel. // errChan := make(chan error, serverCount) for _, req := range reqs { go func(req *http.Request) { resp, err := client.Do(req) if err == nil { resp.Body.Close() } errChan <- err }(req) } // // Wait for the requests to finish and check the errors. // for i := 0; i < serverCount; i++ { assert.NoError(<-errChan) } }