Refactor verify command

This commit is contained in:
katexochen 2022-04-27 11:17:41 +02:00 committed by Paul Meyer
parent 019003337f
commit 1317fc2bb2
15 changed files with 757 additions and 982 deletions

View file

@ -1,6 +1,8 @@
package cloudcmd
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strings"
@ -13,50 +15,120 @@ import (
"github.com/edgelesssys/constellation/internal/config"
)
const warningStr = "Warning: not verifying the Constellation's %s measurements\n"
type Validators struct {
validators []atls.Validator
pcrWarnings string
pcrWarningsInit string
provider cloudprovider.Provider
pcrs map[uint32][]byte
validators []atls.Validator
}
func NewValidators(provider cloudprovider.Provider, config *config.Config) (Validators, error) {
func NewValidators(provider cloudprovider.Provider, config *config.Config) (*Validators, error) {
v := Validators{}
switch provider {
if provider == cloudprovider.Unknown {
return nil, errors.New("unknown cloud provider")
}
v.provider = provider
if err := v.setPCRs(config); err != nil {
return nil, err
}
return &v, nil
}
func (v *Validators) UpdateInitPCRs(ownerID, clusterID string) error {
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
return err
}
return v.updatePCR(uint32(vtpm.PCRIndexClusterID), clusterID)
}
// updatePCR adds a new entry to the pcr map of v, or removes the key if the input is an empty string.
//
// When adding, the input is first decoded from base64.
// We then calculate the expected PCR by hashing the input using SHA256,
// appending expected PCR for initialization, and then hashing once more.
func (v *Validators) updatePCR(pcrIndex uint32, encoded string) error {
if encoded == "" {
delete(v.pcrs, pcrIndex)
return nil
}
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err)
}
// new_pcr_value := hash(old_pcr_value || data_to_extend)
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
hashedInput := sha256.Sum256(decoded)
expectedPcr := sha256.Sum256(append(v.pcrs[pcrIndex], hashedInput[:]...))
v.pcrs[pcrIndex] = expectedPcr[:]
return nil
}
func (v *Validators) setPCRs(config *config.Config) error {
switch v.provider {
case cloudprovider.GCP:
gcpPCRs := *config.Provider.GCP.PCRs
if err := v.checkPCRs(gcpPCRs); err != nil {
return Validators{}, err
}
v.setPCRWarnings(gcpPCRs)
v.validators = []atls.Validator{
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs.
return err
}
v.pcrs = gcpPCRs
case cloudprovider.Azure:
azurePCRs := *config.Provider.Azure.PCRs
if err := v.checkPCRs(azurePCRs); err != nil {
return Validators{}, err
return err
}
v.setPCRWarnings(azurePCRs)
v.validators = []atls.Validator{
azure.NewValidator(azurePCRs),
}
default:
return Validators{}, errors.New("unsupported cloud provider")
v.pcrs = azurePCRs
}
return v, nil
return nil
}
// V returns validators as list of atls.Validator.
func (v *Validators) V() []atls.Validator {
v.updateValidators()
return v.validators
}
func (v *Validators) updateValidators() {
switch v.provider {
case cloudprovider.GCP:
v.validators = []atls.Validator{
gcp.NewValidator(v.pcrs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs.
}
case cloudprovider.Azure:
v.validators = []atls.Validator{
azure.NewValidator(v.pcrs),
}
}
}
// Warnings returns warnings for the specifc PCR values that are not verified.
//
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func (v *Validators) Warnings() string {
return v.pcrWarnings
sb := &strings.Builder{}
if v.pcrs[0] == nil || v.pcrs[1] == nil {
writeFmt(sb, warningStr, "BIOS")
}
if v.pcrs[2] == nil || v.pcrs[3] == nil {
writeFmt(sb, warningStr, "OPROM")
}
if v.pcrs[4] == nil || v.pcrs[5] == nil {
writeFmt(sb, warningStr, "MBR")
}
// GRUB measures kernel command line and initrd into pcrs 8 and 9
if v.pcrs[8] == nil {
writeFmt(sb, warningStr, "kernel command line")
}
if v.pcrs[9] == nil {
writeFmt(sb, warningStr, "initrd")
}
return sb.String()
}
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified.
@ -64,10 +136,18 @@ func (v *Validators) Warnings() string {
//
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func (v *Validators) WarningsIncludeInit() string {
return v.pcrWarnings + v.pcrWarningsInit
warnings := v.Warnings()
if v.pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || v.pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
warnings = warnings + fmt.Sprintf(warningStr, "initialization status")
}
return warnings
}
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
if len(pcrs) == 0 {
return errors.New("no PCR values provided")
}
for k, v := range pcrs {
if len(v) != 32 {
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
@ -76,37 +156,6 @@ func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
return nil
}
func (v *Validators) setPCRWarnings(pcrs map[uint32][]byte) {
const warningStr = "Warning: not verifying the Constellation's %s measurements\n"
sb := &strings.Builder{}
if pcrs[0] == nil || pcrs[1] == nil {
writeFmt(sb, warningStr, "BIOS")
}
if pcrs[2] == nil || pcrs[3] == nil {
writeFmt(sb, warningStr, "OPROM")
}
if pcrs[4] == nil || pcrs[5] == nil {
writeFmt(sb, warningStr, "MBR")
}
// GRUB measures kernel command line and initrd into pcrs 8 and 9
if pcrs[8] == nil {
writeFmt(sb, warningStr, "kernel command line")
}
if pcrs[9] == nil {
writeFmt(sb, warningStr, "initrd")
}
v.pcrWarnings = sb.String()
// Write init warnings separate.
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
v.pcrWarningsInit = fmt.Sprintf(warningStr, "initialization status")
}
}
func writeFmt(sb *strings.Builder, fmtStr string, args ...interface{}) {
sb.WriteString(fmt.Sprintf(fmtStr, args...))
}

View file

@ -1,21 +1,94 @@
package cloudcmd
import (
"crypto/sha256"
"encoding/base64"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/config"
"github.com/stretchr/testify/assert"
)
func TestWarnAboutPCRs(t *testing.T) {
func TestNewValidators(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
one := []byte("11111111111111111111111111111111")
testPCRs := map[uint32][]byte{
0: zero,
1: one,
2: zero,
3: one,
4: zero,
5: zero,
}
testCases := map[string]struct {
provider cloudprovider.Provider
config *config.Config
pcrs map[uint32][]byte
wantErr bool
}{
"gcp": {
provider: cloudprovider.GCP,
pcrs: testPCRs,
},
"azure": {
provider: cloudprovider.Azure,
pcrs: testPCRs,
},
"no pcrs provided": {
provider: cloudprovider.Azure,
pcrs: map[uint32][]byte{},
wantErr: true,
},
"invalid pcr length": {
provider: cloudprovider.GCP,
pcrs: map[uint32][]byte{0: []byte("0000000000000000000000000000000")},
wantErr: true,
},
"unknown provider": {
provider: cloudprovider.Unknown,
pcrs: testPCRs,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
conf := &config.Config{Provider: &config.ProviderConfig{}}
if tc.provider == cloudprovider.GCP {
conf.Provider.GCP = &config.GCPConfig{PCRs: &tc.pcrs}
}
if tc.provider == cloudprovider.Azure {
conf.Provider.Azure = &config.AzureConfig{PCRs: &tc.pcrs}
}
validators, err := NewValidators(tc.provider, conf)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.pcrs, validators.pcrs)
assert.Equal(tc.provider, validators.provider)
}
})
}
}
func TestValidatorsWarnings(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
testCases := map[string]struct {
pcrs map[uint32][]byte
wantWarnings []string
wantWInclude []string
wantErr bool
}{
"no warnings": {
pcrs: map[uint32][]byte{
@ -143,57 +216,302 @@ func TestWarnAboutPCRs(t *testing.T) {
},
wantWInclude: []string{"initialization"},
},
"bad config": {
pcrs: map[uint32][]byte{
0: []byte("000"),
},
wantErr: true,
},
}
for _, provider := range []string{"gcp", "azure", "unknown"} {
t.Run(provider, func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
config := &config.Config{
Provider: &config.ProviderConfig{
Azure: &config.AzureConfig{PCRs: &tc.pcrs},
GCP: &config.GCPConfig{PCRs: &tc.pcrs},
},
}
validators := Validators{pcrs: tc.pcrs}
validators, err := NewValidators(cloudprovider.FromString(provider), config)
warnings := validators.Warnings()
warningsInclueInit := validators.WarningsIncludeInit()
v := validators.V()
warnings := validators.Warnings()
warningsInclueInit := validators.WarningsIncludeInit()
if tc.wantErr || provider == "unknown" {
assert.Error(err)
} else {
assert.NoError(err)
if len(tc.wantWarnings) == 0 {
assert.Empty(warnings)
}
for _, w := range tc.wantWarnings {
assert.Contains(warnings, w)
}
for _, w := range tc.wantWarnings {
assert.Contains(warningsInclueInit, w)
}
if len(tc.wantWInclude) == 0 {
assert.Equal(len(warnings), len(warningsInclueInit))
} else {
assert.Greater(len(warningsInclueInit), len(warnings))
}
for _, w := range tc.wantWInclude {
assert.Contains(warningsInclueInit, w)
}
assert.NotEmpty(v)
}
})
if len(tc.wantWarnings) == 0 {
assert.Empty(warnings)
}
for _, w := range tc.wantWarnings {
assert.Contains(warnings, w)
}
for _, w := range tc.wantWarnings {
assert.Contains(warningsInclueInit, w)
}
if len(tc.wantWInclude) == 0 {
assert.Equal(len(warnings), len(warningsInclueInit))
} else {
assert.Greater(len(warningsInclueInit), len(warnings))
}
for _, w := range tc.wantWInclude {
assert.Contains(warningsInclueInit, w)
}
})
}
}
func TestValidatorsV(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
newTestPCRs := func() map[uint32][]byte {
return map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
}
}
testCases := map[string]struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
wantVs []atls.Validator
}{
"gcp": {
provider: cloudprovider.GCP,
pcrs: newTestPCRs(),
wantVs: []atls.Validator{
gcp.NewValidator(newTestPCRs()),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: remove when not longer needed.
},
},
"azure": {
provider: cloudprovider.Azure,
pcrs: newTestPCRs(),
wantVs: []atls.Validator{
azure.NewValidator(newTestPCRs()),
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs}
resultValidators := validators.V()
assert.Equal(len(tc.wantVs), len(resultValidators))
for i, resValidator := range resultValidators {
assert.Equal(tc.wantVs[i].OID(), resValidator.OID())
}
})
}
}
func TestValidatorsUpdateInitPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
one := []byte("11111111111111111111111111111111")
one64 := base64.StdEncoding.EncodeToString(one)
oneHash := sha256.Sum256(one)
pcrZeroUpdatedOne := sha256.Sum256(append(zero, oneHash[:]...))
newTestPCRs := func() map[uint32][]byte {
return map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
}
}
testCases := map[string]struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
ownerID string
clusterID string
wantErr bool
}{
"gcp update owner ID": {
provider: cloudprovider.GCP,
pcrs: newTestPCRs(),
ownerID: one64,
},
"gcp update cluster ID": {
provider: cloudprovider.GCP,
pcrs: newTestPCRs(),
clusterID: one64,
},
"gcp update both": {
provider: cloudprovider.GCP,
pcrs: newTestPCRs(),
ownerID: one64,
clusterID: one64,
},
"azure update owner ID": {
provider: cloudprovider.Azure,
pcrs: newTestPCRs(),
ownerID: one64,
},
"azure update cluster ID": {
provider: cloudprovider.Azure,
pcrs: newTestPCRs(),
clusterID: one64,
},
"azure update both": {
provider: cloudprovider.Azure,
pcrs: newTestPCRs(),
ownerID: one64,
clusterID: one64,
},
"owner ID and cluster ID empty": {
provider: cloudprovider.GCP,
pcrs: newTestPCRs(),
},
"invalid encoding": {
provider: cloudprovider.GCP,
pcrs: newTestPCRs(),
ownerID: "invalid",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs}
err := validators.UpdateInitPCRs(tc.ownerID, tc.clusterID)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
for i := 0; i < len(tc.pcrs); i++ {
switch {
case i == int(vtpm.PCRIndexClusterID) && tc.clusterID == "":
// should be deleted
_, ok := validators.pcrs[uint32(i)]
assert.False(ok)
case i == int(vtpm.PCRIndexClusterID):
pcr, ok := validators.pcrs[uint32(i)]
assert.True(ok)
assert.Equal(pcrZeroUpdatedOne[:], pcr)
case i == int(vtpm.PCRIndexOwnerID) && tc.ownerID == "":
// should be deleted
_, ok := validators.pcrs[uint32(i)]
assert.False(ok)
case i == int(vtpm.PCRIndexOwnerID):
pcr, ok := validators.pcrs[uint32(i)]
assert.True(ok)
assert.Equal(pcrZeroUpdatedOne[:], pcr)
default:
assert.Equal(zero, validators.pcrs[uint32(i)])
}
}
})
}
}
func TestUpdatePCR(t *testing.T) {
emptyMap := map[uint32][]byte{}
defaultMap := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
testCases := map[string]struct {
pcrMap map[uint32][]byte
pcrIndex uint32
encoded string
wantEntries int
wantErr bool
}{
"empty input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: "",
wantEntries: 0,
wantErr: false,
},
"empty input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: "",
wantEntries: len(defaultMap),
wantErr: false,
},
"correct input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
wantEntries: 1,
wantErr: false,
},
"correct input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
wantEntries: len(defaultMap) + 1,
wantErr: false,
},
"unencoded input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: "Constellation",
wantEntries: 0,
wantErr: true,
},
"unencoded input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: "Constellation",
wantEntries: len(defaultMap),
wantErr: true,
},
"empty input at occupied index": {
pcrMap: defaultMap,
pcrIndex: 0,
encoded: "",
wantEntries: len(defaultMap) - 1,
wantErr: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
pcrs := make(map[uint32][]byte)
for k, v := range tc.pcrMap {
pcrs[k] = v
}
validators := &Validators{
provider: cloudprovider.GCP,
pcrs: pcrs,
}
err := validators.updatePCR(tc.pcrIndex, tc.encoded)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
assert.Len(pcrs, tc.wantEntries)
for _, v := range pcrs {
assert.Len(v, 32)
}
})
}