168 lines
4.8 KiB
Go
Raw Normal View History

package main
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/edgelesssys/constellation/cli/status"
	"github.com/edgelesssys/constellation/coordinator/atls"
	"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
	"github.com/edgelesssys/constellation/coordinator/oid"
	"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
	coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
	"github.com/spf13/afero"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
)
// Command-line flags of the PCR reader.
var (
	// coordIP is the IP address of the VM hosting the Coordinator (-coord-ip).
	coordIP = flag.String("coord-ip", "", "IP of the VM the Coordinator is running on")

	// coordinatorPort is the port of the Coordinator's public API (-coord-port).
	coordinatorPort = flag.String("coord-port", "9000", "Port of the Coordinator's pub API")

	// export is the output path for the PCRs rendered as Go code (-o); empty disables export.
	export = flag.String("o", "", "Write PCRs, formatted as Go code, to file")

	// quiet disables output when set (-q).
	quiet = flag.Bool("q", false, "Set to disable output")
)
// main runs the PCR reader. It performs these steps:
//   - wait for the Coordinator at -coord-ip:-coord-port to reach the AcceptingInit state;
//   - connect to it so the attestation document embedded in its TLS certificate is captured;
//   - validate the PCRs in that document;
//   - print the PCRs (unless -q is set) and/or write them as Go code to the -o file.
//
// Every failure is fatal.
func main() {
	flag.Parse()
	if *coordIP == "" {
		log.Fatal("missing required flag -coord-ip")
	}

	// JoinHostPort brackets IPv6 addresses, so addr is always dialable and printable.
	addr := net.JoinHostPort(*coordIP, *coordinatorPort)
	// -q is documented as disabling output, so the progress message respects it too.
	if !*quiet {
		fmt.Printf("connecting to Coordinator at %s\n", addr)
	}

	// A single deadline bounds both waiting for the Coordinator and connecting to it.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	// Wait for the Coordinator to come online. The waiter is given an empty PCR map;
	// presumably this means no PCR checks while waiting. NOTE(review): confirm in status.NewWaiter.
	waiter := status.NewWaiter(map[uint32][]byte{})
	if err := waiter.WaitFor(ctx, coordinatorstate.AcceptingInit, addr); err != nil {
		log.Fatalf("waiting for Coordinator at %s: %v", addr, err)
	}

	// The VerifyPeerCertificate hook stores the attestation document found in the
	// Coordinator's certificate into attDocRaw during the TLS handshake.
	attDocRaw := &[]byte{}
	tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
	if err != nil {
		log.Fatalf("creating client TLS config: %v", err)
	}
	tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(attDocRaw)
	if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil {
		log.Fatalf("connecting to Coordinator at %s: %v", addr, err)
	}

	pcrs, err := validatePCRAttDoc(*attDocRaw)
	if err != nil {
		log.Fatalf("validating attestation document: %v", err)
	}

	if !*quiet {
		if err := printPCRs(os.Stdout, pcrs); err != nil {
			log.Fatalf("printing PCRs: %v", err)
		}
	}
	if *export != "" {
		if err := exportToFile(*export, pcrs, &afero.Afero{Fs: afero.NewOsFs()}); err != nil {
			log.Fatalf("exporting PCRs to %s: %v", *export, err)
		}
	}
}
// connectToCoordinator dials the Coordinator's public API at addr using tlsConfig
// and issues one GetState RPC. The RPC requires an established connection, so it
// forces the TLS handshake and runs any VerifyPeerCertificate hook set on tlsConfig.
//
// It does NOT return the attestation document. Callers capture the document through
// the VerifyPeerCertificate callback they install on tlsConfig. The returned error
// reports dial or RPC failures; the connection is always closed before returning.
func connectToCoordinator(ctx context.Context, addr string, tlsConfig *tls.Config) error {
	conn, err := grpc.DialContext(
		ctx, addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
	)
	if err != nil {
		return fmt.Errorf("dialing %s: %w", addr, err)
	}
	defer conn.Close()

	client := pubproto.NewAPIClient(conn)
	if _, err := client.GetState(ctx, &pubproto.GetStateRequest{}); err != nil {
		return fmt.Errorf("requesting state from %s: %w", addr, err)
	}
	return nil
}
// getVerifyPeerCertificateFunc returns a VerifyPeerCertificate callback that reads the
// peer's leaf certificate (rawCerts[0]) and stores its attestation document in *attDoc.
//
// The callback looks for an extension whose OID is the Azure, GCP, or GCP non-CVM
// attestation OID:
//   - If the extension value unmarshals as a JSON []byte (a base64 string), the decoded
//     bytes are used.
//   - Otherwise the raw extension value is used.
//   - If several matching extensions exist, the last one wins.
//
// *attDoc is written only when a non-empty document was found. A later handshake whose
// certificate lacks the extension therefore fails, instead of silently reusing a
// document captured by an earlier handshake.
//
// The callback does not verify the certificate chain; verifiedChains is ignored.
func getVerifyPeerCertificateFunc(attDoc *[]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
	return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return errors.New("rawCerts is empty")
		}
		cert, err := x509.ParseCertificate(rawCerts[0])
		if err != nil {
			return fmt.Errorf("parsing peer certificate: %w", err)
		}

		// Collect into a local so a failed lookup never leaves stale data in *attDoc.
		var doc []byte
		for _, ex := range cert.Extensions {
			if ex.Id.Equal(oid.Azure{}.OID()) || ex.Id.Equal(oid.GCP{}.OID()) || ex.Id.Equal(oid.GCPNonCVM{}.OID()) {
				// A JSON base64 string is decoded; anything else is taken verbatim.
				if err := json.Unmarshal(ex.Value, &doc); err != nil {
					doc = ex.Value
				}
			}
		}
		if len(doc) == 0 {
			return errors.New("did not receive attestation document in certificate extension")
		}
		*attDoc = doc
		return nil
	}
}
// validatePCRAttDoc unmarshals a JSON-encoded vtpm.AttestationDocument and returns the
// PCR map of its SHA-256 quote (the quote selected by vtpm.GetSHA256QuoteIndex).
//
// It returns an error if any of the following holds:
//   - the document cannot be unmarshaled;
//   - it carries no attestation;
//   - no SHA-256 quote is found;
//   - the selected quote contains no PCRs;
//   - any PCR value is not exactly 32 bytes (the SHA-256 digest size).
func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) {
	attDoc := vtpm.AttestationDocument{}
	if err := json.Unmarshal(attDocRaw, &attDoc); err != nil {
		return nil, fmt.Errorf("unmarshaling attestation document: %w", err)
	}
	if attDoc.Attestation == nil {
		return nil, errors.New("empty attestation")
	}
	qIdx, err := vtpm.GetSHA256QuoteIndex(attDoc.Attestation.Quotes)
	if err != nil {
		return nil, err
	}

	// The generated protobuf getters are nil-safe, so a quote without a PCR message
	// yields an empty map instead of a nil-pointer panic.
	pcrs := attDoc.Attestation.Quotes[qIdx].GetPcrs().GetPcrs()
	if len(pcrs) == 0 {
		return nil, errors.New("SHA-256 quote contains no PCRs")
	}
	for idx, pcr := range pcrs {
		if len(pcr) != 32 {
			return nil, fmt.Errorf("incomplete PCR at index: %d", idx)
		}
	}
	return pcrs, nil
}
// printPCRs formats pcrs as indented JSON and writes it to w, preceded by a "PCRs:"
// header line and followed by a newline.
//
// encoding/json writes map keys in sorted order and encodes each []byte value as a
// base64 string. Errors from marshaling or from writing to w are returned.
func printPCRs(w io.Writer, pcrs map[uint32][]byte) error {
	pcrJSON, err := json.MarshalIndent(pcrs, "", " ")
	if err != nil {
		return err
	}
	// Propagate write failures (e.g. a closed stdout) instead of dropping them.
	if _, err := fmt.Fprintf(w, "PCRs:\n%s\n", pcrJSON); err != nil {
		return err
	}
	return nil
}
// exportToFile writes pcrs to path as Go source code that declares
// `var pcrs = map[uint32][]byte{...}` in package pcrs.
//
// Entries are emitted in ascending PCR-index order. Each byte is written as an
// upper-case hex literal (for example 0X0A), and an empty digest is written as {}.
// The map's contents (which indices exist, digest lengths) are not validated; that
// is the caller's responsibility. The file is written with mode 0644 via fs.
func exportToFile(path string, pcrs map[uint32][]byte, fs *afero.Afero) error {
	goCode := `package pcrs
var pcrs = map[uint32][]byte{%s
}
`
	// Iterate over the actual keys in sorted order. Indexing 0..len(pcrs)-1 would
	// panic when the keys are not contiguous from zero.
	indices := make([]uint32, 0, len(pcrs))
	for idx := range pcrs {
		indices = append(indices, idx)
	}
	sort.Slice(indices, func(i, j int) bool { return indices[i] < indices[j] })

	// strings.Builder avoids re-copying the whole output on every byte.
	// Its Write methods never return an error, so Fprintf's error is not checked.
	var entries strings.Builder
	for _, idx := range indices {
		fmt.Fprintf(&entries, "\n\t%d: {", idx)
		for j, b := range pcrs[idx] {
			if j > 0 {
				entries.WriteString(", ")
			}
			fmt.Fprintf(&entries, "%#02X", b)
		}
		entries.WriteString("},")
	}
	return fs.WriteFile(path, []byte(fmt.Sprintf(goCode, entries.String())), 0o644)
}