mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
209 lines
6.0 KiB
Go
209 lines
6.0 KiB
Go
|
/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/
|
||
|
|
||
|
// End-to-end test that issues various types of malicious join requests to a cluster.
|
||
|
package main
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"flag"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
|
||
|
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||
|
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||
|
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||
|
"github.com/edgelesssys/constellation/v2/joinservice/joinproto"
|
||
|
"go.uber.org/zap/zapcore"
|
||
|
)
|
||
|
|
||
|
func main() {
|
||
|
jsEndpoint := flag.String("js-endpoint", "", "Join service endpoint to use.")
|
||
|
csp := flag.String("csp", "", "Cloud service provider to use.")
|
||
|
attVariant := flag.String(
|
||
|
"variant",
|
||
|
"",
|
||
|
fmt.Sprintf("Attestation variant to use. Set to \"default\" to use the default attestation variant for the CSP,"+
|
||
|
"or one of: %s", variant.GetAvailableAttestationVariants()),
|
||
|
)
|
||
|
flag.Parse()
|
||
|
fmt.Println(formatFlags(*attVariant, *csp, *jsEndpoint))
|
||
|
|
||
|
testCases := map[string]struct {
|
||
|
fn func(attVariant, csp, jsEndpoint string) error
|
||
|
wantErr bool
|
||
|
}{
|
||
|
"JoinFromUnattestedNode": {
|
||
|
fn: JoinFromUnattestedNode,
|
||
|
wantErr: true,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
allPassed := true
|
||
|
testOutput := &testOutput{
|
||
|
TestCases: make(map[string]testCaseOutput),
|
||
|
}
|
||
|
for name, tc := range testCases {
|
||
|
fmt.Printf("Running testcase %s\n", name)
|
||
|
|
||
|
err := tc.fn(*attVariant, *csp, *jsEndpoint)
|
||
|
|
||
|
switch {
|
||
|
case err == nil && tc.wantErr:
|
||
|
fmt.Printf("Test case %s failed: Expected error but got none\n", name)
|
||
|
testOutput.TestCases[name] = testCaseOutput{
|
||
|
Passed: false,
|
||
|
Message: "Expected error but got none",
|
||
|
}
|
||
|
allPassed = false
|
||
|
case !tc.wantErr && err != nil:
|
||
|
fmt.Printf("Test case %s failed: Got unexpected error: %s\n", name, err)
|
||
|
testOutput.TestCases[name] = testCaseOutput{
|
||
|
Passed: false,
|
||
|
Message: fmt.Sprintf("Got unexpected error: %s", err),
|
||
|
}
|
||
|
allPassed = false
|
||
|
case tc.wantErr && err != nil:
|
||
|
fmt.Printf("Test case %s succeeded\n", name)
|
||
|
testOutput.TestCases[name] = testCaseOutput{
|
||
|
Passed: true,
|
||
|
Message: fmt.Sprintf("Got expected error: %s", err),
|
||
|
}
|
||
|
case !tc.wantErr && err == nil:
|
||
|
fmt.Printf("Test case %s succeeded\n", name)
|
||
|
testOutput.TestCases[name] = testCaseOutput{
|
||
|
Passed: true,
|
||
|
Message: "No error, as expected",
|
||
|
}
|
||
|
default:
|
||
|
panic("invalid result")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
testOutput.AllPassed = allPassed
|
||
|
out, err := json.Marshal(testOutput)
|
||
|
if err != nil {
|
||
|
panic(fmt.Sprintf("marshalling test output: %s", err))
|
||
|
}
|
||
|
fmt.Println(string(out))
|
||
|
}
|
||
|
|
||
|
type testOutput struct {
|
||
|
AllPassed bool `json:"allPassed"`
|
||
|
TestCases map[string]testCaseOutput `json:"testCases"`
|
||
|
}
|
||
|
|
||
|
// testCaseOutput is the result record of a single test case.
type testCaseOutput struct {
	// Passed reports whether the outcome matched the expectation.
	Passed bool `json:"passed"`

	// Message is a human-readable description of the outcome.
	Message string `json:"message"`
}
|
||
|
|
||
|
// formatFlags renders the given flag values as a human-readable, multi-line
// string: a "Using Flags:" header followed by one tab-indented line each for
// js-endpoint, csp and variant (in that order). Every line ends with a newline.
func formatFlags(attVariant, csp, jsEndpoint string) string {
	var sb strings.Builder
	sb.WriteString("Using Flags:\n")
	// Format straight into the builder instead of allocating an intermediate
	// string with fmt.Sprintf for every line.
	fmt.Fprintf(&sb, "\tjs-endpoint: %s\n", jsEndpoint)
	fmt.Fprintf(&sb, "\tcsp: %s\n", csp)
	fmt.Fprintf(&sb, "\tvariant: %s\n", attVariant)
	return sb.String()
}
|
||
|
|
||
|
// JoinFromUnattestedNode simulates a join request from a Node that uses a stub issuer
|
||
|
// and thus cannot be attested correctly.
|
||
|
func JoinFromUnattestedNode(attVariant, csp, jsEndpoint string) error {
|
||
|
log := logger.New(logger.JSONLog, zapcore.DebugLevel)
|
||
|
joiner, err := newMaliciousJoiner(attVariant, csp, jsEndpoint, log)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("creating malicious joiner: %w", err)
|
||
|
}
|
||
|
|
||
|
_, err = joiner.join(context.Background())
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("joining cluster: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// newMaliciousJoiner creates a new malicious joiner, i.e. a simulated node that issues
|
||
|
// an invalid join request.
|
||
|
func newMaliciousJoiner(attVariant, csp, endpoint string, log *logger.Logger) (*maliciousJoiner, error) {
|
||
|
var attVariantOid variant.Variant
|
||
|
var err error
|
||
|
if strings.EqualFold(attVariant, "default") {
|
||
|
attVariantOid = variant.GetDefaultAttestation(cloudprovider.FromString(csp))
|
||
|
} else {
|
||
|
attVariantOid, err = variant.FromString(attVariant)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("parsing attestation variant: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
issuer := newFakeIssuer(attVariantOid)
|
||
|
|
||
|
return &maliciousJoiner{
|
||
|
endpoint: endpoint,
|
||
|
logger: log,
|
||
|
dialer: dialer.New(issuer, nil, &net.Dialer{}),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// maliciousJoiner simulates a malicious node joining a cluster.
|
||
|
type maliciousJoiner struct {
|
||
|
endpoint string
|
||
|
logger *logger.Logger
|
||
|
dialer *dialer.Dialer
|
||
|
}
|
||
|
|
||
|
// join issues a join request to the join service endpoint.
|
||
|
func (j *maliciousJoiner) join(ctx context.Context) (*joinproto.IssueJoinTicketResponse, error) {
|
||
|
j.logger.Debugf("Dialing join service endpoint %s", j.endpoint)
|
||
|
conn, err := j.dialer.Dial(ctx, j.endpoint)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("dialing join service endpoint: %w", err)
|
||
|
}
|
||
|
defer conn.Close()
|
||
|
j.logger.Debugf("Successfully dialed join service endpoint %s", j.endpoint)
|
||
|
|
||
|
protoClient := joinproto.NewAPIClient(conn)
|
||
|
|
||
|
j.logger.Debugf("Issuing join ticket")
|
||
|
req := &joinproto.IssueJoinTicketRequest{
|
||
|
DiskUuid: "",
|
||
|
CertificateRequest: []byte{},
|
||
|
IsControlPlane: false,
|
||
|
}
|
||
|
res, err := protoClient.IssueJoinTicket(ctx, req)
|
||
|
j.logger.Debugf("Got join ticket response: %+v", res)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("issuing join ticket: %w", err)
|
||
|
}
|
||
|
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
// newFakeIssuer creates a new fake issuer for a given attestation variant.
|
||
|
func newFakeIssuer(oid variant.Getter) *fakeIssuer {
|
||
|
return &fakeIssuer{oid}
|
||
|
}
|
||
|
|
||
|
// fakeIssuer simulates an issuer that issues a fake / invalid attestation document.
|
||
|
type fakeIssuer struct {
|
||
|
variant.Getter
|
||
|
}
|
||
|
|
||
|
// Issue issues a fake attestation document.
|
||
|
func (i *fakeIssuer) Issue(_ context.Context, userData, nonce []byte) ([]byte, error) {
|
||
|
return json.Marshal(fakeAttestationDoc{UserData: userData, Nonce: nonce})
|
||
|
}
|
||
|
|
||
|
// fakeAttestationDoc is the payload of a fake attestation document. It is
// JSON-encoded by fakeIssuer.Issue.
type fakeAttestationDoc struct {
	// UserData is the user data passed to Issue.
	UserData []byte

	// Nonce is the nonce passed to Issue.
	Nonce []byte
}
|