mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-03 11:50:57 -05:00
128 lines
2.8 KiB
Go
128 lines
2.8 KiB
Go
/*
|
|
Copyright (c) Edgeless Systems GmbH
|
|
|
|
SPDX-License-Identifier: AGPL-3.0-only
|
|
*/
|
|
|
|
package atlscredentials
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/asn1"
|
|
"encoding/json"
|
|
"errors"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
|
|
"github.com/edgelesssys/constellation/v2/internal/atls"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m, goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1"))
|
|
}
|
|
|
|
func TestATLSCredentials(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
oid := fakeOID{1, 3, 9900, 1}
|
|
|
|
//
|
|
// Create servers
|
|
//
|
|
|
|
serverCreds := New(fakeIssuer{fakeOID: oid}, nil)
|
|
|
|
const serverCount = 15
|
|
var listeners []*bufconn.Listener
|
|
for i := 0; i < serverCount; i++ {
|
|
api := &fakeAPI{}
|
|
server := grpc.NewServer(grpc.Creds(serverCreds))
|
|
initproto.RegisterAPIServer(server, api)
|
|
|
|
listener := bufconn.Listen(1024)
|
|
listeners = append(listeners, listener)
|
|
|
|
defer server.GracefulStop()
|
|
go server.Serve(listener)
|
|
}
|
|
|
|
//
|
|
// Dial concurrently
|
|
//
|
|
|
|
clientCreds := New(nil, []atls.Validator{fakeValidator{fakeOID: oid}})
|
|
|
|
errChan := make(chan error, serverCount)
|
|
for _, listener := range listeners {
|
|
lis := listener
|
|
go func() {
|
|
var err error
|
|
defer func() { errChan <- err }()
|
|
conn, err := grpc.DialContext(context.Background(), "", grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) {
|
|
return lis.Dial()
|
|
}), grpc.WithTransportCredentials(clientCreds))
|
|
require.NoError(err)
|
|
defer conn.Close()
|
|
|
|
client := initproto.NewAPIClient(conn)
|
|
_, err = client.Init(context.Background(), &initproto.InitRequest{})
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < serverCount; i++ {
|
|
assert.NoError(<-errChan)
|
|
}
|
|
}
|
|
|
|
type fakeIssuer struct {
|
|
fakeOID
|
|
}
|
|
|
|
func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
|
|
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
|
|
}
|
|
|
|
type fakeValidator struct {
|
|
fakeOID
|
|
err error
|
|
}
|
|
|
|
func (v fakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
|
|
var doc fakeDoc
|
|
if err := json.Unmarshal(attDoc, &doc); err != nil {
|
|
return nil, err
|
|
}
|
|
if !bytes.Equal(doc.Nonce, nonce) {
|
|
return nil, errors.New("invalid nonce")
|
|
}
|
|
return doc.UserData, v.err
|
|
}
|
|
|
|
// fakeOID is an object identifier used by the fake issuer and validator in
// this file. Both embed it to gain the OID() method.
type fakeOID asn1.ObjectIdentifier

// OID returns o converted back to an asn1.ObjectIdentifier.
func (o fakeOID) OID() (id asn1.ObjectIdentifier) {
	id = asn1.ObjectIdentifier(o)
	return id
}
|
|
|
|
// fakeDoc is the attestation document exchanged between fakeIssuer and
// fakeValidator; it is serialized as JSON with its default field names.
type fakeDoc struct {
	UserData, Nonce []byte
}
|
|
|
|
type fakeAPI struct {
|
|
initproto.UnimplementedAPIServer
|
|
}
|
|
|
|
func (f *fakeAPI) Init(_ *initproto.InitRequest, stream initproto.API_InitServer) error {
|
|
_ = stream.Send(&initproto.InitResponse{})
|
|
return nil
|
|
}
|