mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-12 16:09:39 -05:00
131 lines
2.8 KiB
Go
131 lines
2.8 KiB
Go
|
package atls
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/asn1"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestTLSConfig(t *testing.T) {
|
||
|
oid1 := fakeOID{1, 3, 9900, 1}
|
||
|
oid2 := fakeOID{1, 3, 9900, 2}
|
||
|
|
||
|
testCases := map[string]struct {
|
||
|
issuer Issuer
|
||
|
validators []Validator
|
||
|
expectErr bool
|
||
|
}{
|
||
|
"basic": {
|
||
|
issuer: fakeIssuer{fakeOID: oid1},
|
||
|
validators: []Validator{fakeValidator{fakeOID: oid1}},
|
||
|
},
|
||
|
"multiple validators": {
|
||
|
issuer: fakeIssuer{fakeOID: oid2},
|
||
|
validators: []Validator{fakeValidator{fakeOID: oid1}, fakeValidator{fakeOID: oid2}},
|
||
|
},
|
||
|
"validate error": {
|
||
|
issuer: fakeIssuer{fakeOID: oid1},
|
||
|
validators: []Validator{fakeValidator{fakeOID: oid1, err: errors.New("failed")}},
|
||
|
expectErr: true,
|
||
|
},
|
||
|
"unknown oid": {
|
||
|
issuer: fakeIssuer{fakeOID: oid1},
|
||
|
validators: []Validator{fakeValidator{fakeOID: oid2}},
|
||
|
expectErr: true,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for name, tc := range testCases {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
assert := assert.New(t)
|
||
|
require := require.New(t)
|
||
|
|
||
|
//
|
||
|
// Create server
|
||
|
//
|
||
|
|
||
|
serverConfig, err := CreateAttestationServerTLSConfig(tc.issuer)
|
||
|
require.NoError(err)
|
||
|
|
||
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
_, _ = io.WriteString(w, "hello")
|
||
|
}))
|
||
|
server.TLS = serverConfig
|
||
|
|
||
|
//
|
||
|
// Create client
|
||
|
//
|
||
|
|
||
|
clientConfig, err := CreateAttestationClientTLSConfig(tc.validators)
|
||
|
require.NoError(err)
|
||
|
client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}}
|
||
|
|
||
|
//
|
||
|
// Test connection
|
||
|
//
|
||
|
|
||
|
server.StartTLS()
|
||
|
defer server.Close()
|
||
|
|
||
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, http.NoBody)
|
||
|
require.NoError(err)
|
||
|
resp, err := client.Do(req)
|
||
|
if tc.expectErr {
|
||
|
assert.Error(err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
require.NoError(err)
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
body, err := io.ReadAll(resp.Body)
|
||
|
require.NoError(err)
|
||
|
assert.EqualValues("hello", body)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type fakeIssuer struct {
|
||
|
fakeOID
|
||
|
}
|
||
|
|
||
|
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
|
||
|
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
|
||
|
}
|
||
|
|
||
|
type fakeValidator struct {
|
||
|
fakeOID
|
||
|
err error
|
||
|
}
|
||
|
|
||
|
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||
|
var doc fakeDoc
|
||
|
if err := json.Unmarshal(attDoc, &doc); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if !bytes.Equal(doc.Nonce, nonce) {
|
||
|
return nil, errors.New("invalid nonce")
|
||
|
}
|
||
|
return doc.UserData, v.err
|
||
|
}
|
||
|
|
||
|
// fakeOID is a named type over asn1.ObjectIdentifier, which lets it be
// embedded into other test doubles to give them an OID method.
type fakeOID asn1.ObjectIdentifier

// OID returns the receiver converted back to an asn1.ObjectIdentifier.
func (o fakeOID) OID() asn1.ObjectIdentifier {
	id := asn1.ObjectIdentifier(o)
	return id
}
|
||
|
|
||
|
// fakeDoc is the document exchanged between fakeIssuer and fakeValidator.
// Both fields are exported so that encoding/json includes them; UserData
// holds the caller-supplied payload and Nonce the nonce it was issued for.
type fakeDoc struct {
	UserData, Nonce []byte
}
|