diff --git a/hack/image-measurement/main.go b/hack/image-measurement/main.go index 222f09861..f806b4a65 100644 --- a/hack/image-measurement/main.go +++ b/hack/image-measurement/main.go @@ -16,10 +16,12 @@ import ( "syscall" "github.com/edgelesssys/constellation/hack/image-measurement/server" + "github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/logger" "go.uber.org/multierr" "go.uber.org/zap" "go.uber.org/zap/zapcore" + "gopkg.in/yaml.v3" "libvirt.org/go/libvirt" ) @@ -280,10 +282,10 @@ func (l *libvirtInstance) deleteLibvirtInstance() error { return err } -func (l *libvirtInstance) obtainMeasurements() (err error) { +func (l *libvirtInstance) obtainMeasurements() (measurements config.Measurements, err error) { // sanity check if err := l.deleteLibvirtInstance(); err != nil { - return err + return nil, err } done := make(chan struct{}, 1) serv := server.New(l.log, done) @@ -296,7 +298,7 @@ func (l *libvirtInstance) obtainMeasurements() (err error) { err = multierr.Append(err, l.deleteLibvirtInstance()) }() if err := l.createLibvirtInstance(); err != nil { - return err + return nil, err } sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) @@ -310,19 +312,26 @@ func (l *libvirtInstance) obtainMeasurements() (err error) { signal.Stop(sigs) close(sigs) if err := serv.Shutdown(); err != nil { - return err + return nil, err } close(done) - return nil + + return serv.GetMeasurements(), nil } func main() { - var imageLocation string - var imageType string + var imageLocation, imageType, outFile string + var verboseLog bool flag.StringVar(&imageLocation, "path", "", "path to the image to measure (required)") flag.StringVar(&imageType, "type", "", "type of the image. One of 'qcow2' or 'raw' (required)") + flag.StringVar(&outFile, "file", "-", "path to output file, or '-' for stdout") + flag.BoolVar(&verboseLog, "v", false, "verbose logging") + flag.Parse() - log := logger.New(logger.JSONLog, zapcore.InfoLevel) + log := logger.New(logger.JSONLog, zapcore.DebugLevel) + if !verboseLog { + log = log.WithIncreasedLevel(zapcore.FatalLevel) // Only print fatal errors in non-verbose mode + } if imageLocation == "" || imageType == "" { flag.Usage() @@ -343,8 +352,22 @@ func main() { imagePath: imageLocation, } - if err := lInstance.obtainMeasurements(); err != nil { + measurements, err := lInstance.obtainMeasurements() + if err != nil { log.With(zap.Error(err)).Fatalf("Failed to obtain PCR measurements") } - log.Infof("instaces terminated successfully") + log.Infof("instances terminated successfully") + + output, err := yaml.Marshal(measurements) + if err != nil { + log.With(zap.Error(err)).Fatalf("Failed to marshal measurements") + } + + if outFile == "-" { + fmt.Println(string(output)) + } else { + if err := os.WriteFile(outFile, output, 0o644); err != nil { + log.With(zap.Error(err)).Fatalf("Failed to write measurements to file") + } + } } diff --git a/hack/image-measurement/server/server.go b/hack/image-measurement/server/server.go index 68f385fdc..b4ebe2ce3 100644 --- a/hack/image-measurement/server/server.go +++ b/hack/image-measurement/server/server.go @@ -17,9 +17,10 @@ import ( ) type Server struct { - log *logger.Logger - server http.Server - done chan<- struct{} + log *logger.Logger + server http.Server + measurements map[uint32][]byte + done chan<- struct{} } func New(log *logger.Logger, done chan<- struct{}) *Server { @@ -77,5 +78,12 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) { log.Infof("PCR 4 %x", pcrs[4]) log.Infof("PCR 8 %x", pcrs[8]) log.Infof("PCR 9 %x", pcrs[9]) + + s.measurements = pcrs + s.done <- struct{}{} } + +func (s *Server) GetMeasurements() map[uint32][]byte { + return s.measurements +}