mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
327 lines
8.5 KiB
Go
327 lines
8.5 KiB
Go
/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/
|
|
|
|
package atls
|
|
|
|
import (
|
|
"context"
|
|
"encoding/asn1"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/edgelesssys/constellation/v2/internal/oid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m)
|
|
}
|
|
|
|
func TestTLSConfig(t *testing.T) {
|
|
oid1 := fakeOID{asn1.ObjectIdentifier{1, 3, 9900, 1}}
|
|
oid2 := fakeOID{asn1.ObjectIdentifier{1, 3, 9900, 2}}
|
|
|
|
testCases := map[string]struct {
|
|
clientIssuer Issuer
|
|
clientValidators []Validator
|
|
serverIssuer Issuer
|
|
serverValidators []Validator
|
|
wantErr bool
|
|
}{
|
|
"client->server basic": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
},
|
|
"client->server multiple validators": {
|
|
serverIssuer: NewFakeIssuer(oid2),
|
|
clientValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)},
|
|
},
|
|
"client->server validate error": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{FakeValidator{oid1, errors.New("failed")}},
|
|
wantErr: true,
|
|
},
|
|
"client->server unknown oid": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid2)},
|
|
wantErr: true,
|
|
},
|
|
"client->server client cert is not verified": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
},
|
|
"server->client basic": {
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
},
|
|
"server->client multiple validators": {
|
|
serverValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)},
|
|
clientIssuer: NewFakeIssuer(oid2),
|
|
},
|
|
"server->client validate error": {
|
|
serverValidators: []Validator{FakeValidator{oid1, errors.New("failed")}},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
wantErr: true,
|
|
},
|
|
"server->client unknown oid": {
|
|
serverValidators: []Validator{NewFakeValidator(oid2)},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
wantErr: true,
|
|
},
|
|
"mutual basic": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
},
|
|
"mutual multiple validators": {
|
|
serverIssuer: NewFakeIssuer(oid2),
|
|
serverValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)},
|
|
clientIssuer: NewFakeIssuer(oid2),
|
|
clientValidators: []Validator{NewFakeValidator(oid1), NewFakeValidator(oid2)},
|
|
},
|
|
"mutual fails if client sends no attestation": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
wantErr: true,
|
|
},
|
|
"mutual fails if server sends no attestation": {
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
wantErr: true,
|
|
},
|
|
"mutual validate error client side": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{FakeValidator{oid1, errors.New("failed")}},
|
|
wantErr: true,
|
|
},
|
|
"mutual validate error server side": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
serverValidators: []Validator{FakeValidator{oid1, errors.New("failed")}},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
wantErr: true,
|
|
},
|
|
"mutual unknown oid from client": {
|
|
serverIssuer: NewFakeIssuer(oid1),
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientIssuer: NewFakeIssuer(oid2),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
wantErr: true,
|
|
},
|
|
"mutual unknown oid from server": {
|
|
serverIssuer: NewFakeIssuer(oid2),
|
|
serverValidators: []Validator{NewFakeValidator(oid1)},
|
|
clientIssuer: NewFakeIssuer(oid1),
|
|
clientValidators: []Validator{NewFakeValidator(oid1)},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
//
|
|
// Create server
|
|
//
|
|
|
|
serverConfig, err := CreateAttestationServerTLSConfig(tc.serverIssuer, tc.serverValidators)
|
|
require.NoError(err)
|
|
|
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "hello")
|
|
}))
|
|
server.TLS = serverConfig
|
|
|
|
//
|
|
// Create client
|
|
//
|
|
|
|
clientConfig, err := CreateAttestationClientTLSConfig(tc.clientIssuer, tc.clientValidators)
|
|
require.NoError(err)
|
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
|
|
|
//
|
|
// Test connection
|
|
//
|
|
|
|
server.StartTLS()
|
|
defer server.Close()
|
|
|
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, http.NoBody)
|
|
require.NoError(err)
|
|
resp, err := client.Do(req)
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
return
|
|
}
|
|
|
|
require.NoError(err)
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(err)
|
|
assert.EqualValues("hello", body)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientConnectionConcurrency(t *testing.T) {
|
|
require := require.New(t)
|
|
assert := assert.New(t)
|
|
|
|
//
|
|
// Create servers.
|
|
//
|
|
|
|
const serverCount = 15
|
|
|
|
var urls []string
|
|
|
|
for i := 0; i < serverCount; i++ {
|
|
serverCfg, err := CreateAttestationServerTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{}))
|
|
require.NoError(err)
|
|
|
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "hello")
|
|
}))
|
|
server.TLS = serverCfg
|
|
|
|
server.StartTLS()
|
|
defer server.Close()
|
|
|
|
urls = append(urls, server.URL)
|
|
}
|
|
|
|
//
|
|
// Create client.
|
|
//
|
|
|
|
clientConfig, err := CreateAttestationClientTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{}))
|
|
require.NoError(err)
|
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
|
|
|
//
|
|
// Prepare a request for each server.
|
|
//
|
|
|
|
var reqs []*http.Request
|
|
for _, url := range urls {
|
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
|
require.NoError(err)
|
|
reqs = append(reqs, req)
|
|
}
|
|
|
|
//
|
|
// Do the request concurrently and collect the errors in a channel.
|
|
// The config of the client is reused, so the nonce isn't fresh.
|
|
// This explicitly checks for data races on the clientConnection.
|
|
//
|
|
|
|
errChan := make(chan error, serverCount)
|
|
|
|
for _, req := range reqs {
|
|
go func(req *http.Request) {
|
|
resp, err := client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
}
|
|
errChan <- err
|
|
}(req)
|
|
}
|
|
|
|
//
|
|
// Wait for the requests to finish and check the errors.
|
|
//
|
|
|
|
for i := 0; i < serverCount; i++ {
|
|
assert.NoError(<-errChan)
|
|
}
|
|
}
|
|
|
|
func TestServerConnectionConcurrency(t *testing.T) {
|
|
require := require.New(t)
|
|
assert := assert.New(t)
|
|
|
|
//
|
|
// Create servers.
|
|
// The serverCfg is reused.
|
|
//
|
|
|
|
const serverCount = 10
|
|
|
|
var urls []string
|
|
|
|
serverCfg, err := CreateAttestationServerTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{}))
|
|
require.NoError(err)
|
|
|
|
for i := 0; i < serverCount; i++ {
|
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "hello")
|
|
}))
|
|
server.TLS = serverCfg
|
|
|
|
server.StartTLS()
|
|
defer server.Close()
|
|
|
|
urls = append(urls, server.URL)
|
|
}
|
|
|
|
//
|
|
// Create client.
|
|
//
|
|
|
|
clientConfig, err := CreateAttestationClientTLSConfig(NewFakeIssuer(oid.Dummy{}), NewFakeValidators(oid.Dummy{}))
|
|
require.NoError(err)
|
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
|
|
|
//
|
|
// Prepare a request for each server.
|
|
//
|
|
|
|
var reqs []*http.Request
|
|
for _, url := range urls {
|
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, http.NoBody)
|
|
require.NoError(err)
|
|
reqs = append(reqs, req)
|
|
}
|
|
|
|
//
|
|
// Do the request concurrently and collect the errors in a channel.
|
|
//
|
|
|
|
errChan := make(chan error, serverCount)
|
|
|
|
for _, req := range reqs {
|
|
go func(req *http.Request) {
|
|
resp, err := client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
}
|
|
errChan <- err
|
|
}(req)
|
|
}
|
|
|
|
//
|
|
// Wait for the requests to finish and check the errors.
|
|
//
|
|
|
|
for i := 0; i < serverCount; i++ {
|
|
assert.NoError(<-errChan)
|
|
}
|
|
}
|