Daniel Weiße 869448c3e1 Add mutual aTLS support (#176)
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
2022-05-24 16:33:44 +02:00

206 lines
6.1 KiB
Go

package atls
import (
"bytes"
"context"
"encoding/asn1"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestTLSConfig performs an HTTPS round trip between an httptest server
// configured via CreateAttestationServerTLSConfig and a client configured via
// CreateAttestationClientTLSConfig. Every case selects the issuer and the
// validators of both sides and states whether the request must fail.
func TestTLSConfig(t *testing.T) {
	oid1 := fakeOID{1, 3, 9900, 1}
	oid2 := fakeOID{1, 3, 9900, 2}

	// Fixtures shared by the cases below: one issuer and one validator per
	// OID, plus a validator for oid1 that always reports an error.
	issuer1 := fakeIssuer{fakeOID: oid1}
	issuer2 := fakeIssuer{fakeOID: oid2}
	valid1 := fakeValidator{fakeOID: oid1}
	valid2 := fakeValidator{fakeOID: oid2}
	failing1 := fakeValidator{fakeOID: oid1, err: errors.New("failed")}

	cases := map[string]struct {
		clientIssuer     Issuer
		clientValidators []Validator
		serverIssuer     Issuer
		serverValidators []Validator
		wantErr          bool
	}{
		// Only the client is given validators.
		"client->server basic": {
			serverIssuer:     issuer1,
			clientValidators: []Validator{valid1},
		},
		"client->server multiple validators": {
			serverIssuer:     issuer2,
			clientValidators: []Validator{valid1, valid2},
		},
		"client->server validate error": {
			serverIssuer:     issuer1,
			clientValidators: []Validator{failing1},
			wantErr:          true,
		},
		"client->server unknown oid": {
			serverIssuer:     issuer1,
			clientValidators: []Validator{valid2},
			wantErr:          true,
		},
		"client->server client cert is not verified": {
			serverIssuer:     issuer1,
			clientIssuer:     issuer1,
			clientValidators: []Validator{valid1},
		},

		// Only the server is given validators.
		"server->client basic": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{valid1},
			clientIssuer:     issuer1,
		},
		"server->client multiple validators": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{valid1, valid2},
			clientIssuer:     issuer2,
		},
		"server->client validate error": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{failing1},
			clientIssuer:     issuer1,
			wantErr:          true,
		},
		"server->client unknown oid": {
			serverIssuer:     issuer2,
			serverValidators: []Validator{valid2},
			clientIssuer:     issuer1,
			wantErr:          true,
		},

		// Both sides are given validators.
		"mutual basic": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{valid1},
			clientIssuer:     issuer1,
			clientValidators: []Validator{valid1},
		},
		"mutual multiple validators": {
			serverIssuer:     issuer2,
			serverValidators: []Validator{valid1, valid2},
			clientIssuer:     issuer2,
			clientValidators: []Validator{valid1, valid2},
		},
		"mutual fails if client sends no cert": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{valid1},
			clientValidators: []Validator{valid1},
			wantErr:          true,
		},
		"mutual validate error client side": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{valid1},
			clientIssuer:     issuer1,
			clientValidators: []Validator{failing1},
			wantErr:          true,
		},
		"mutual validate error server side": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{failing1},
			clientIssuer:     issuer1,
			clientValidators: []Validator{valid1},
			wantErr:          true,
		},
		"mutual unknown oid from client": {
			serverIssuer:     issuer1,
			serverValidators: []Validator{valid1},
			clientIssuer:     issuer2,
			clientValidators: []Validator{valid1},
			wantErr:          true,
		},
		"mutual unknown oid from server": {
			serverIssuer:     issuer2,
			serverValidators: []Validator{valid1},
			clientIssuer:     issuer1,
			clientValidators: []Validator{valid1},
			wantErr:          true,
		},
	}

	for caseName, tt := range cases {
		t.Run(caseName, func(t *testing.T) {
			// Server: created unstarted so the TLS config under test can be
			// installed before the listener begins serving.
			srvCfg, err := CreateAttestationServerTLSConfig(tt.serverIssuer, tt.serverValidators)
			require.NoError(t, err)
			hello := func(w http.ResponseWriter, _ *http.Request) {
				_, _ = io.WriteString(w, "hello")
			}
			srv := httptest.NewUnstartedServer(http.HandlerFunc(hello))
			srv.TLS = srvCfg

			// Client: a dedicated transport carrying the TLS config under test.
			cliCfg, err := CreateAttestationClientTLSConfig(tt.clientIssuer, tt.clientValidators)
			require.NoError(t, err)
			transport := &http.Transport{TLSClientConfig: cliCfg}
			client := &http.Client{Transport: transport}

			// Connection: start serving and issue a single GET.
			srv.StartTLS()
			defer srv.Close()

			req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, http.NoBody)
			require.NoError(t, err)

			resp, err := client.Do(req)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			defer resp.Body.Close()

			got, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			assert.EqualValues(t, "hello", got)
		})
	}
}
// fakeIssuer is a test double for the issuing side. It embeds fakeOID and
// produces JSON-encoded fakeDoc values instead of real attestation documents.
type fakeIssuer struct {
	fakeOID
}

// Issue wraps userData and nonce in a fakeDoc and returns its JSON encoding
// together with any error reported by json.Marshal.
func (fakeIssuer) Issue(userData, nonce []byte) ([]byte, error) {
	doc := fakeDoc{
		UserData: userData,
		Nonce:    nonce,
	}
	return json.Marshal(doc)
}
// fakeValidator is a test double for the validating side. It embeds fakeOID
// and carries an optional err that Validate hands back to the caller.
type fakeValidator struct {
	fakeOID
	err error
}

// Validate decodes attDoc as a JSON fakeDoc and compares its nonce with the
// given one. A decoding failure or a nonce mismatch yields a nil slice and an
// error; otherwise the document's user data is returned along with v.err.
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
	doc := fakeDoc{}
	switch decodeErr := json.Unmarshal(attDoc, &doc); {
	case decodeErr != nil:
		return nil, decodeErr
	case !bytes.Equal(doc.Nonce, nonce):
		return nil, errors.New("invalid nonce")
	}
	// The configured error is returned next to the user data, not instead of it.
	return doc.UserData, v.err
}
// fakeOID is an ASN.1 object identifier that the fakes embed so that they
// gain an OID method.
type fakeOID asn1.ObjectIdentifier

// OID returns the receiver converted to asn1.ObjectIdentifier.
func (id fakeOID) OID() asn1.ObjectIdentifier {
	converted := asn1.ObjectIdentifier(id)
	return converted
}
// fakeDoc is the document that fakeIssuer encodes as JSON and fakeValidator
// decodes again. Both fields are exported so that encoding/json includes them.
type fakeDoc struct {
	UserData, Nonce []byte
}