mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
Refactor verify command
This commit is contained in:
parent
019003337f
commit
1317fc2bb2
@ -1,6 +1,8 @@
|
|||||||
package cloudcmd
|
package cloudcmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
@ -13,50 +15,120 @@ import (
|
|||||||
"github.com/edgelesssys/constellation/internal/config"
|
"github.com/edgelesssys/constellation/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const warningStr = "Warning: not verifying the Constellation's %s measurements\n"
|
||||||
|
|
||||||
type Validators struct {
|
type Validators struct {
|
||||||
|
provider cloudprovider.Provider
|
||||||
|
pcrs map[uint32][]byte
|
||||||
validators []atls.Validator
|
validators []atls.Validator
|
||||||
pcrWarnings string
|
|
||||||
pcrWarningsInit string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewValidators(provider cloudprovider.Provider, config *config.Config) (Validators, error) {
|
func NewValidators(provider cloudprovider.Provider, config *config.Config) (*Validators, error) {
|
||||||
v := Validators{}
|
v := Validators{}
|
||||||
switch provider {
|
if provider == cloudprovider.Unknown {
|
||||||
|
return nil, errors.New("unknown cloud provider")
|
||||||
|
}
|
||||||
|
v.provider = provider
|
||||||
|
if err := v.setPCRs(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Validators) UpdateInitPCRs(ownerID, clusterID string) error {
|
||||||
|
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return v.updatePCR(uint32(vtpm.PCRIndexClusterID), clusterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updatePCR adds a new entry to the pcr map of v, or removes the key if the input is an empty string.
|
||||||
|
//
|
||||||
|
// When adding, the input is first decoded from base64.
|
||||||
|
// We then calculate the expected PCR by hashing the input using SHA256,
|
||||||
|
// appending expected PCR for initialization, and then hashing once more.
|
||||||
|
func (v *Validators) updatePCR(pcrIndex uint32, encoded string) error {
|
||||||
|
if encoded == "" {
|
||||||
|
delete(v.pcrs, pcrIndex)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err)
|
||||||
|
}
|
||||||
|
// new_pcr_value := hash(old_pcr_value || data_to_extend)
|
||||||
|
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
|
||||||
|
hashedInput := sha256.Sum256(decoded)
|
||||||
|
expectedPcr := sha256.Sum256(append(v.pcrs[pcrIndex], hashedInput[:]...))
|
||||||
|
v.pcrs[pcrIndex] = expectedPcr[:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Validators) setPCRs(config *config.Config) error {
|
||||||
|
switch v.provider {
|
||||||
case cloudprovider.GCP:
|
case cloudprovider.GCP:
|
||||||
gcpPCRs := *config.Provider.GCP.PCRs
|
gcpPCRs := *config.Provider.GCP.PCRs
|
||||||
if err := v.checkPCRs(gcpPCRs); err != nil {
|
if err := v.checkPCRs(gcpPCRs); err != nil {
|
||||||
return Validators{}, err
|
return err
|
||||||
}
|
|
||||||
v.setPCRWarnings(gcpPCRs)
|
|
||||||
v.validators = []atls.Validator{
|
|
||||||
gcp.NewValidator(gcpPCRs),
|
|
||||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs.
|
|
||||||
}
|
}
|
||||||
|
v.pcrs = gcpPCRs
|
||||||
case cloudprovider.Azure:
|
case cloudprovider.Azure:
|
||||||
azurePCRs := *config.Provider.Azure.PCRs
|
azurePCRs := *config.Provider.Azure.PCRs
|
||||||
if err := v.checkPCRs(azurePCRs); err != nil {
|
if err := v.checkPCRs(azurePCRs); err != nil {
|
||||||
return Validators{}, err
|
return err
|
||||||
}
|
}
|
||||||
v.setPCRWarnings(azurePCRs)
|
v.pcrs = azurePCRs
|
||||||
v.validators = []atls.Validator{
|
|
||||||
azure.NewValidator(azurePCRs),
|
|
||||||
}
|
}
|
||||||
default:
|
return nil
|
||||||
return Validators{}, errors.New("unsupported cloud provider")
|
|
||||||
}
|
|
||||||
return v, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// V returns validators as list of atls.Validator.
|
// V returns validators as list of atls.Validator.
|
||||||
func (v *Validators) V() []atls.Validator {
|
func (v *Validators) V() []atls.Validator {
|
||||||
|
v.updateValidators()
|
||||||
return v.validators
|
return v.validators
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *Validators) updateValidators() {
|
||||||
|
switch v.provider {
|
||||||
|
case cloudprovider.GCP:
|
||||||
|
v.validators = []atls.Validator{
|
||||||
|
gcp.NewValidator(v.pcrs),
|
||||||
|
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs.
|
||||||
|
}
|
||||||
|
case cloudprovider.Azure:
|
||||||
|
v.validators = []atls.Validator{
|
||||||
|
azure.NewValidator(v.pcrs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Warnings returns warnings for the specifc PCR values that are not verified.
|
// Warnings returns warnings for the specifc PCR values that are not verified.
|
||||||
//
|
//
|
||||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||||
func (v *Validators) Warnings() string {
|
func (v *Validators) Warnings() string {
|
||||||
return v.pcrWarnings
|
sb := &strings.Builder{}
|
||||||
|
|
||||||
|
if v.pcrs[0] == nil || v.pcrs[1] == nil {
|
||||||
|
writeFmt(sb, warningStr, "BIOS")
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.pcrs[2] == nil || v.pcrs[3] == nil {
|
||||||
|
writeFmt(sb, warningStr, "OPROM")
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.pcrs[4] == nil || v.pcrs[5] == nil {
|
||||||
|
writeFmt(sb, warningStr, "MBR")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GRUB measures kernel command line and initrd into pcrs 8 and 9
|
||||||
|
if v.pcrs[8] == nil {
|
||||||
|
writeFmt(sb, warningStr, "kernel command line")
|
||||||
|
}
|
||||||
|
if v.pcrs[9] == nil {
|
||||||
|
writeFmt(sb, warningStr, "initrd")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified.
|
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified.
|
||||||
@ -64,10 +136,18 @@ func (v *Validators) Warnings() string {
|
|||||||
//
|
//
|
||||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||||
func (v *Validators) WarningsIncludeInit() string {
|
func (v *Validators) WarningsIncludeInit() string {
|
||||||
return v.pcrWarnings + v.pcrWarningsInit
|
warnings := v.Warnings()
|
||||||
|
if v.pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || v.pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
||||||
|
warnings = warnings + fmt.Sprintf(warningStr, "initialization status")
|
||||||
|
}
|
||||||
|
|
||||||
|
return warnings
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
|
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
|
||||||
|
if len(pcrs) == 0 {
|
||||||
|
return errors.New("no PCR values provided")
|
||||||
|
}
|
||||||
for k, v := range pcrs {
|
for k, v := range pcrs {
|
||||||
if len(v) != 32 {
|
if len(v) != 32 {
|
||||||
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
|
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
|
||||||
@ -76,37 +156,6 @@ func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Validators) setPCRWarnings(pcrs map[uint32][]byte) {
|
|
||||||
const warningStr = "Warning: not verifying the Constellation's %s measurements\n"
|
|
||||||
sb := &strings.Builder{}
|
|
||||||
|
|
||||||
if pcrs[0] == nil || pcrs[1] == nil {
|
|
||||||
writeFmt(sb, warningStr, "BIOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pcrs[2] == nil || pcrs[3] == nil {
|
|
||||||
writeFmt(sb, warningStr, "OPROM")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pcrs[4] == nil || pcrs[5] == nil {
|
|
||||||
writeFmt(sb, warningStr, "MBR")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GRUB measures kernel command line and initrd into pcrs 8 and 9
|
|
||||||
if pcrs[8] == nil {
|
|
||||||
writeFmt(sb, warningStr, "kernel command line")
|
|
||||||
}
|
|
||||||
if pcrs[9] == nil {
|
|
||||||
writeFmt(sb, warningStr, "initrd")
|
|
||||||
}
|
|
||||||
v.pcrWarnings = sb.String()
|
|
||||||
|
|
||||||
// Write init warnings separate.
|
|
||||||
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
|
||||||
v.pcrWarningsInit = fmt.Sprintf(warningStr, "initialization status")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeFmt(sb *strings.Builder, fmtStr string, args ...interface{}) {
|
func writeFmt(sb *strings.Builder, fmtStr string, args ...interface{}) {
|
||||||
sb.WriteString(fmt.Sprintf(fmtStr, args...))
|
sb.WriteString(fmt.Sprintf(fmtStr, args...))
|
||||||
}
|
}
|
||||||
|
@ -1,21 +1,94 @@
|
|||||||
package cloudcmd
|
package cloudcmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||||
"github.com/edgelesssys/constellation/internal/config"
|
"github.com/edgelesssys/constellation/internal/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWarnAboutPCRs(t *testing.T) {
|
func TestNewValidators(t *testing.T) {
|
||||||
|
zero := []byte("00000000000000000000000000000000")
|
||||||
|
one := []byte("11111111111111111111111111111111")
|
||||||
|
testPCRs := map[uint32][]byte{
|
||||||
|
0: zero,
|
||||||
|
1: one,
|
||||||
|
2: zero,
|
||||||
|
3: one,
|
||||||
|
4: zero,
|
||||||
|
5: zero,
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
provider cloudprovider.Provider
|
||||||
|
config *config.Config
|
||||||
|
pcrs map[uint32][]byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"gcp": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: testPCRs,
|
||||||
|
},
|
||||||
|
"azure": {
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
pcrs: testPCRs,
|
||||||
|
},
|
||||||
|
"no pcrs provided": {
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
pcrs: map[uint32][]byte{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"invalid pcr length": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: map[uint32][]byte{0: []byte("0000000000000000000000000000000")},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"unknown provider": {
|
||||||
|
provider: cloudprovider.Unknown,
|
||||||
|
pcrs: testPCRs,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
conf := &config.Config{Provider: &config.ProviderConfig{}}
|
||||||
|
if tc.provider == cloudprovider.GCP {
|
||||||
|
conf.Provider.GCP = &config.GCPConfig{PCRs: &tc.pcrs}
|
||||||
|
}
|
||||||
|
if tc.provider == cloudprovider.Azure {
|
||||||
|
conf.Provider.Azure = &config.AzureConfig{PCRs: &tc.pcrs}
|
||||||
|
}
|
||||||
|
|
||||||
|
validators, err := NewValidators(tc.provider, conf)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal(tc.pcrs, validators.pcrs)
|
||||||
|
assert.Equal(tc.provider, validators.provider)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatorsWarnings(t *testing.T) {
|
||||||
zero := []byte("00000000000000000000000000000000")
|
zero := []byte("00000000000000000000000000000000")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
pcrs map[uint32][]byte
|
pcrs map[uint32][]byte
|
||||||
wantWarnings []string
|
wantWarnings []string
|
||||||
wantWInclude []string
|
wantWInclude []string
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
"no warnings": {
|
"no warnings": {
|
||||||
pcrs: map[uint32][]byte{
|
pcrs: map[uint32][]byte{
|
||||||
@ -143,37 +216,17 @@ func TestWarnAboutPCRs(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantWInclude: []string{"initialization"},
|
wantWInclude: []string{"initialization"},
|
||||||
},
|
},
|
||||||
"bad config": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: []byte("000"),
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, provider := range []string{"gcp", "azure", "unknown"} {
|
|
||||||
t.Run(provider, func(t *testing.T) {
|
|
||||||
for name, tc := range testCases {
|
for name, tc := range testCases {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
|
|
||||||
config := &config.Config{
|
validators := Validators{pcrs: tc.pcrs}
|
||||||
Provider: &config.ProviderConfig{
|
|
||||||
Azure: &config.AzureConfig{PCRs: &tc.pcrs},
|
|
||||||
GCP: &config.GCPConfig{PCRs: &tc.pcrs},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
validators, err := NewValidators(cloudprovider.FromString(provider), config)
|
|
||||||
|
|
||||||
v := validators.V()
|
|
||||||
warnings := validators.Warnings()
|
warnings := validators.Warnings()
|
||||||
warningsInclueInit := validators.WarningsIncludeInit()
|
warningsInclueInit := validators.WarningsIncludeInit()
|
||||||
|
|
||||||
if tc.wantErr || provider == "unknown" {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
if len(tc.wantWarnings) == 0 {
|
if len(tc.wantWarnings) == 0 {
|
||||||
assert.Empty(warnings)
|
assert.Empty(warnings)
|
||||||
}
|
}
|
||||||
@ -191,10 +244,275 @@ func TestWarnAboutPCRs(t *testing.T) {
|
|||||||
for _, w := range tc.wantWInclude {
|
for _, w := range tc.wantWInclude {
|
||||||
assert.Contains(warningsInclueInit, w)
|
assert.Contains(warningsInclueInit, w)
|
||||||
}
|
}
|
||||||
assert.NotEmpty(v)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatorsV(t *testing.T) {
|
||||||
|
zero := []byte("00000000000000000000000000000000")
|
||||||
|
newTestPCRs := func() map[uint32][]byte {
|
||||||
|
return map[uint32][]byte{
|
||||||
|
0: zero,
|
||||||
|
1: zero,
|
||||||
|
2: zero,
|
||||||
|
3: zero,
|
||||||
|
4: zero,
|
||||||
|
5: zero,
|
||||||
|
6: zero,
|
||||||
|
7: zero,
|
||||||
|
8: zero,
|
||||||
|
9: zero,
|
||||||
|
10: zero,
|
||||||
|
11: zero,
|
||||||
|
12: zero,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
provider cloudprovider.Provider
|
||||||
|
pcrs map[uint32][]byte
|
||||||
|
wantVs []atls.Validator
|
||||||
|
}{
|
||||||
|
"gcp": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
wantVs: []atls.Validator{
|
||||||
|
gcp.NewValidator(newTestPCRs()),
|
||||||
|
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: remove when not longer needed.
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"azure": {
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
wantVs: []atls.Validator{
|
||||||
|
azure.NewValidator(newTestPCRs()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs}
|
||||||
|
|
||||||
|
resultValidators := validators.V()
|
||||||
|
|
||||||
|
assert.Equal(len(tc.wantVs), len(resultValidators))
|
||||||
|
for i, resValidator := range resultValidators {
|
||||||
|
assert.Equal(tc.wantVs[i].OID(), resValidator.OID())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatorsUpdateInitPCRs(t *testing.T) {
|
||||||
|
zero := []byte("00000000000000000000000000000000")
|
||||||
|
one := []byte("11111111111111111111111111111111")
|
||||||
|
one64 := base64.StdEncoding.EncodeToString(one)
|
||||||
|
oneHash := sha256.Sum256(one)
|
||||||
|
pcrZeroUpdatedOne := sha256.Sum256(append(zero, oneHash[:]...))
|
||||||
|
newTestPCRs := func() map[uint32][]byte {
|
||||||
|
return map[uint32][]byte{
|
||||||
|
0: zero,
|
||||||
|
1: zero,
|
||||||
|
2: zero,
|
||||||
|
3: zero,
|
||||||
|
4: zero,
|
||||||
|
5: zero,
|
||||||
|
6: zero,
|
||||||
|
7: zero,
|
||||||
|
8: zero,
|
||||||
|
9: zero,
|
||||||
|
10: zero,
|
||||||
|
11: zero,
|
||||||
|
12: zero,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
provider cloudprovider.Provider
|
||||||
|
pcrs map[uint32][]byte
|
||||||
|
ownerID string
|
||||||
|
clusterID string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"gcp update owner ID": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
ownerID: one64,
|
||||||
|
},
|
||||||
|
"gcp update cluster ID": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
clusterID: one64,
|
||||||
|
},
|
||||||
|
"gcp update both": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
ownerID: one64,
|
||||||
|
clusterID: one64,
|
||||||
|
},
|
||||||
|
"azure update owner ID": {
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
ownerID: one64,
|
||||||
|
},
|
||||||
|
"azure update cluster ID": {
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
clusterID: one64,
|
||||||
|
},
|
||||||
|
"azure update both": {
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
ownerID: one64,
|
||||||
|
clusterID: one64,
|
||||||
|
},
|
||||||
|
"owner ID and cluster ID empty": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
},
|
||||||
|
"invalid encoding": {
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: newTestPCRs(),
|
||||||
|
ownerID: "invalid",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs}
|
||||||
|
|
||||||
|
err := validators.UpdateInitPCRs(tc.ownerID, tc.clusterID)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(err)
|
||||||
|
for i := 0; i < len(tc.pcrs); i++ {
|
||||||
|
switch {
|
||||||
|
case i == int(vtpm.PCRIndexClusterID) && tc.clusterID == "":
|
||||||
|
// should be deleted
|
||||||
|
_, ok := validators.pcrs[uint32(i)]
|
||||||
|
assert.False(ok)
|
||||||
|
|
||||||
|
case i == int(vtpm.PCRIndexClusterID):
|
||||||
|
pcr, ok := validators.pcrs[uint32(i)]
|
||||||
|
assert.True(ok)
|
||||||
|
assert.Equal(pcrZeroUpdatedOne[:], pcr)
|
||||||
|
|
||||||
|
case i == int(vtpm.PCRIndexOwnerID) && tc.ownerID == "":
|
||||||
|
// should be deleted
|
||||||
|
_, ok := validators.pcrs[uint32(i)]
|
||||||
|
assert.False(ok)
|
||||||
|
|
||||||
|
case i == int(vtpm.PCRIndexOwnerID):
|
||||||
|
pcr, ok := validators.pcrs[uint32(i)]
|
||||||
|
assert.True(ok)
|
||||||
|
assert.Equal(pcrZeroUpdatedOne[:], pcr)
|
||||||
|
|
||||||
|
default:
|
||||||
|
assert.Equal(zero, validators.pcrs[uint32(i)])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePCR(t *testing.T) {
|
||||||
|
emptyMap := map[uint32][]byte{}
|
||||||
|
defaultMap := map[uint32][]byte{
|
||||||
|
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||||
|
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
pcrMap map[uint32][]byte
|
||||||
|
pcrIndex uint32
|
||||||
|
encoded string
|
||||||
|
wantEntries int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"empty input, empty map": {
|
||||||
|
pcrMap: emptyMap,
|
||||||
|
pcrIndex: 10,
|
||||||
|
encoded: "",
|
||||||
|
wantEntries: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"empty input, default map": {
|
||||||
|
pcrMap: defaultMap,
|
||||||
|
pcrIndex: 10,
|
||||||
|
encoded: "",
|
||||||
|
wantEntries: len(defaultMap),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"correct input, empty map": {
|
||||||
|
pcrMap: emptyMap,
|
||||||
|
pcrIndex: 10,
|
||||||
|
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
|
||||||
|
wantEntries: 1,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"correct input, default map": {
|
||||||
|
pcrMap: defaultMap,
|
||||||
|
pcrIndex: 10,
|
||||||
|
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
|
||||||
|
wantEntries: len(defaultMap) + 1,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
"unencoded input, empty map": {
|
||||||
|
pcrMap: emptyMap,
|
||||||
|
pcrIndex: 10,
|
||||||
|
encoded: "Constellation",
|
||||||
|
wantEntries: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"unencoded input, default map": {
|
||||||
|
pcrMap: defaultMap,
|
||||||
|
pcrIndex: 10,
|
||||||
|
encoded: "Constellation",
|
||||||
|
wantEntries: len(defaultMap),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"empty input at occupied index": {
|
||||||
|
pcrMap: defaultMap,
|
||||||
|
pcrIndex: 0,
|
||||||
|
encoded: "",
|
||||||
|
wantEntries: len(defaultMap) - 1,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
pcrs := make(map[uint32][]byte)
|
||||||
|
for k, v := range tc.pcrMap {
|
||||||
|
pcrs[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
validators := &Validators{
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
|
pcrs: pcrs,
|
||||||
|
}
|
||||||
|
err := validators.updatePCR(tc.pcrIndex, tc.encoded)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(err)
|
||||||
|
}
|
||||||
|
assert.Len(pcrs, tc.wantEntries)
|
||||||
|
for _, v := range pcrs {
|
||||||
|
assert.Len(v, 32)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,18 +3,13 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// ErrInvalidInput is an error where user entered invalid input.
|
||||||
// ErrInvalidInput is an error where user entered invalid input.
|
var ErrInvalidInput = errors.New("user made invalid input")
|
||||||
ErrInvalidInput = errors.New("user made invalid input")
|
|
||||||
warningStr = "Warning: not verifying the Constellation's %s measurements\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
// askToConfirm asks user to confirm an action.
|
// askToConfirm asks user to confirm an action.
|
||||||
// The user will be asked the handed question and can answer with
|
// The user will be asked the handed question and can answer with
|
||||||
@ -38,43 +33,3 @@ func askToConfirm(cmd *cobra.Command, question string) (bool, error) {
|
|||||||
}
|
}
|
||||||
return false, ErrInvalidInput
|
return false, ErrInvalidInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// warnAboutPCRs displays warnings if specifc PCR values are not verified.
|
|
||||||
//
|
|
||||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
|
||||||
func warnAboutPCRs(cmd *cobra.Command, pcrs map[uint32][]byte, checkInit bool) error {
|
|
||||||
for k, v := range pcrs {
|
|
||||||
if len(v) != 32 {
|
|
||||||
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pcrs[0] == nil || pcrs[1] == nil {
|
|
||||||
cmd.PrintErrf(warningStr, "BIOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pcrs[2] == nil || pcrs[3] == nil {
|
|
||||||
cmd.PrintErrf(warningStr, "OPROM")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pcrs[4] == nil || pcrs[5] == nil {
|
|
||||||
cmd.PrintErrf(warningStr, "MBR")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GRUB measures kernel command line and initrd into pcrs 8 and 9
|
|
||||||
if pcrs[8] == nil {
|
|
||||||
cmd.PrintErrf(warningStr, "kernel command line")
|
|
||||||
}
|
|
||||||
if pcrs[9] == nil {
|
|
||||||
cmd.PrintErrf(warningStr, "initrd")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only warn about initialization PCRs if necessary
|
|
||||||
if checkInit {
|
|
||||||
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
|
||||||
cmd.PrintErrf(warningStr, "initialization status")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -61,188 +61,3 @@ func TestAskToConfirm(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWarnAboutPCRs(t *testing.T) {
|
|
||||||
zero := []byte("00000000000000000000000000000000")
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
pcrs map[uint32][]byte
|
|
||||||
dontWarnInit bool
|
|
||||||
wantWarnings []string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"no warnings": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
6: zero,
|
|
||||||
7: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
10: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"no warnings for missing non critical values": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"warn for BIOS": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
wantWarnings: []string{"BIOS"},
|
|
||||||
},
|
|
||||||
"warn for OPROM": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
wantWarnings: []string{"OPROM"},
|
|
||||||
},
|
|
||||||
"warn for MBR": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
wantWarnings: []string{"MBR"},
|
|
||||||
},
|
|
||||||
"warn for kernel": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
wantWarnings: []string{"kernel"},
|
|
||||||
},
|
|
||||||
"warn for initrd": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
11: zero,
|
|
||||||
12: zero,
|
|
||||||
},
|
|
||||||
wantWarnings: []string{"initrd"},
|
|
||||||
},
|
|
||||||
"warn for initialization": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
},
|
|
||||||
dontWarnInit: false,
|
|
||||||
wantWarnings: []string{"initialization"},
|
|
||||||
},
|
|
||||||
"don't warn for initialization": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: zero,
|
|
||||||
1: zero,
|
|
||||||
2: zero,
|
|
||||||
3: zero,
|
|
||||||
4: zero,
|
|
||||||
5: zero,
|
|
||||||
8: zero,
|
|
||||||
9: zero,
|
|
||||||
11: zero,
|
|
||||||
},
|
|
||||||
dontWarnInit: true,
|
|
||||||
},
|
|
||||||
"multi warning": {
|
|
||||||
pcrs: map[uint32][]byte{},
|
|
||||||
wantWarnings: []string{
|
|
||||||
"BIOS",
|
|
||||||
"OPROM",
|
|
||||||
"MBR",
|
|
||||||
"initialization",
|
|
||||||
"initrd",
|
|
||||||
"kernel",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"bad config": {
|
|
||||||
pcrs: map[uint32][]byte{
|
|
||||||
0: []byte("000"),
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
cmd := newInitCmd()
|
|
||||||
var out bytes.Buffer
|
|
||||||
cmd.SetOut(&out)
|
|
||||||
var errOut bytes.Buffer
|
|
||||||
cmd.SetErr(&errOut)
|
|
||||||
|
|
||||||
err := warnAboutPCRs(cmd, tc.pcrs, !tc.dontWarnInit)
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
if len(tc.wantWarnings) == 0 {
|
|
||||||
assert.Empty(errOut.String())
|
|
||||||
} else {
|
|
||||||
for _, warning := range tc.wantWarnings {
|
|
||||||
assert.Contains(errOut.String(), warning)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -3,6 +3,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -33,6 +34,15 @@ func isIntGreaterArg(arg int, i int) cobra.PositionalArgs {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIntLessArg(arg int, i int) cobra.PositionalArgs {
|
||||||
|
return cobra.MatchAll(isIntArg(arg), func(cmd *cobra.Command, args []string) error {
|
||||||
|
if v, _ := strconv.Atoi(args[arg]); v >= i {
|
||||||
|
return fmt.Errorf("argument %d must be less %d, but it's %d", arg, i, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// warnAWS warns that AWS isn't supported.
|
// warnAWS warns that AWS isn't supported.
|
||||||
func warnAWS(providerPos int) cobra.PositionalArgs {
|
func warnAWS(providerPos int) cobra.PositionalArgs {
|
||||||
return func(cmd *cobra.Command, args []string) error {
|
return func(cmd *cobra.Command, args []string) error {
|
||||||
@ -48,6 +58,22 @@ func isIntGreaterZeroArg(arg int) cobra.PositionalArgs {
|
|||||||
return isIntGreaterArg(arg, 0)
|
return isIntGreaterArg(arg, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPort(arg int) cobra.PositionalArgs {
|
||||||
|
return cobra.MatchAll(
|
||||||
|
isIntGreaterArg(arg, -1),
|
||||||
|
isIntLessArg(arg, 65536),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIP(arg int) cobra.PositionalArgs {
|
||||||
|
return func(cmd *cobra.Command, args []string) error {
|
||||||
|
if ip := net.ParseIP(args[arg]); ip == nil {
|
||||||
|
return fmt.Errorf("argument %s isn't a valid IP address", args[arg])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isEC2InstanceType checks if argument at position arg is a key in m.
|
// isEC2InstanceType checks if argument at position arg is a key in m.
|
||||||
// The argument will always be converted to lower case letters.
|
// The argument will always be converted to lower case letters.
|
||||||
func isEC2InstanceType(arg int) cobra.PositionalArgs {
|
func isEC2InstanceType(arg int) cobra.PositionalArgs {
|
||||||
@ -81,6 +107,15 @@ func isAzureInstanceType(arg int) cobra.PositionalArgs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isCloudProvider(arg int) cobra.PositionalArgs {
|
||||||
|
return func(cmd *cobra.Command, args []string) error {
|
||||||
|
if provider := cloudprovider.FromString(args[arg]); provider == cloudprovider.Unknown {
|
||||||
|
return fmt.Errorf("argument %s isn't a valid cloud provider", args[arg])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isInstanceTypeForProvider returns a argument validation function that checks if the argument
|
// isInstanceTypeForProvider returns a argument validation function that checks if the argument
|
||||||
// at position typePos is a valid instance type for the cloud provider string at position
|
// at position typePos is a valid instance type for the cloud provider string at position
|
||||||
// providerPos.
|
// providerPos.
|
||||||
|
@ -100,6 +100,66 @@ func TestIsIntGreaterZeroArg(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPort(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
args []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid port 1": {[]string{"80"}, false},
|
||||||
|
"valid port 2": {[]string{"8080"}, false},
|
||||||
|
"valid port 3": {[]string{"65535"}, false},
|
||||||
|
"invalid port 1": {[]string{"foo"}, true},
|
||||||
|
"invalid port 2": {[]string{"65536"}, true},
|
||||||
|
"invalid port 3": {[]string{"-1"}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
testCmd := &cobra.Command{Args: isPort(0)}
|
||||||
|
|
||||||
|
err := testCmd.ValidateArgs(tc.args)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsIP(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
args []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"valid ip 1": {[]string{"192.168.0.2"}, false},
|
||||||
|
"valid ip 2": {[]string{"127.0.0.1"}, false},
|
||||||
|
"valid ip 3": {[]string{"8.8.8.8"}, false},
|
||||||
|
"invalid ip 1": {[]string{"foo"}, true},
|
||||||
|
"invalid ip 2": {[]string{"foo.bar.baz.1"}, true},
|
||||||
|
"invalid ip 3": {[]string{"800.800.800.800"}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
testCmd := &cobra.Command{Args: isIP(0)}
|
||||||
|
|
||||||
|
err := testCmd.ValidateArgs(tc.args)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsEC2InstanceType(t *testing.T) {
|
func TestIsEC2InstanceType(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
args []string
|
args []string
|
||||||
@ -184,6 +244,36 @@ func TestIsAzureInstanceType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsCloudProvider(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
pos int
|
||||||
|
args []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"gcp": {0, []string{"gcp"}, false},
|
||||||
|
"azure": {1, []string{"foo", "azure"}, false},
|
||||||
|
"foo": {0, []string{"foo"}, true},
|
||||||
|
"empty": {0, []string{""}, true},
|
||||||
|
"unknown": {0, []string{"unknown"}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
testCmd := &cobra.Command{Args: isCloudProvider(tc.pos)}
|
||||||
|
|
||||||
|
err := testCmd.ValidateArgs(tc.args)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsInstanceTypeForProvider(t *testing.T) {
|
func TestIsInstanceTypeForProvider(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
typePos int
|
typePos int
|
||||||
|
@ -2,138 +2,122 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/cli/status"
|
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
"github.com/edgelesssys/constellation/cli/file"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/cli/proto"
|
||||||
|
"github.com/edgelesssys/constellation/internal/config"
|
||||||
|
"github.com/spf13/afero"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
rpcStatus "google.golang.org/grpc/status"
|
rpcStatus "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newVerifyCmd() *cobra.Command {
|
func newVerifyCmd() *cobra.Command {
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "verify azure|gcp",
|
Use: "verify {azure|gcp} IP PORT",
|
||||||
Short: "Verify the confidential properties of your Constellation.",
|
Short: "Verify the confidential properties of your Constellation.",
|
||||||
Long: "Verify the confidential properties of your Constellation.",
|
Long: "Verify the confidential properties of your Constellation.",
|
||||||
|
Args: cobra.MatchAll(
|
||||||
|
cobra.ExactArgs(3),
|
||||||
|
isCloudProvider(0),
|
||||||
|
isIP(1),
|
||||||
|
isPort(2),
|
||||||
|
),
|
||||||
|
RunE: runVerify,
|
||||||
}
|
}
|
||||||
|
cmd.Flags().String("owner-id", "", "verify the Constellation using the owner identity derived from the master secret.")
|
||||||
cmd.PersistentFlags().String("owner-id", "", "verify the Constellation using the owner identity derived from the master secret.")
|
cmd.Flags().String("unique-id", "", "verify the Constellation using the unique cluster identity.")
|
||||||
cmd.PersistentFlags().String("unique-id", "", "verify the Constellation using the unique cluster identity.")
|
|
||||||
|
|
||||||
cmd.AddCommand(newVerifyGCPCmd())
|
|
||||||
cmd.AddCommand(newVerifyAzureCmd())
|
|
||||||
cmd.AddCommand(newVerifyGCPNonCVMCmd())
|
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func runVerify(cmd *cobra.Command, args []string, pcrs map[uint32][]byte, validator atls.Validator) error {
|
func runVerify(cmd *cobra.Command, args []string) error {
|
||||||
if err := warnAboutPCRs(cmd, pcrs, false); err != nil {
|
provider := cloudprovider.FromString(args[0])
|
||||||
return err
|
ip := args[1]
|
||||||
}
|
port := args[2]
|
||||||
|
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||||
verifier := verifier{
|
protoClient := &proto.Client{}
|
||||||
newConn: newVerifiedConn,
|
defer protoClient.Close()
|
||||||
newClient: pubproto.NewAPIClient,
|
return verify(cmd.Context(), cmd, provider, ip, port, fileHandler, protoClient)
|
||||||
}
|
|
||||||
return verify(cmd.Context(), cmd.OutOrStdout(), net.JoinHostPort(args[0], args[1]), []atls.Validator{validator}, verifier)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func verify(ctx context.Context, w io.Writer, target string, validators []atls.Validator, verifier verifier) error {
|
func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Provider, ip, port string, fileHandler file.Handler, protoClient protoClient) error {
|
||||||
conn, err := verifier.newConn(ctx, target, validators)
|
flags, err := parseVerifyFlags(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
client := verifier.newClient(conn)
|
|
||||||
|
|
||||||
if _, err := client.GetState(ctx, &pubproto.GetStateRequest{}); err != nil {
|
config, err := config.FromFile(fileHandler, flags.devConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
validators, err := cloudcmd.NewValidators(provider, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validators.UpdateInitPCRs(flags.ownerID, flags.clusterID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if validators.Warnings() != "" {
|
||||||
|
cmd.Print(validators.Warnings())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := protoClient.Connect(ip, port, validators.V()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := protoClient.GetState(ctx); err != nil {
|
||||||
if err, ok := rpcStatus.FromError(err); ok {
|
if err, ok := rpcStatus.FromError(err); ok {
|
||||||
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
|
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Fprintln(w, "OK")
|
|
||||||
|
cmd.Println("OK")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareValidator parses parameters and updates the PCR map.
|
func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
||||||
func prepareValidator(cmd *cobra.Command, pcrs map[uint32][]byte) error {
|
|
||||||
ownerID, err := cmd.Flags().GetString("owner-id")
|
ownerID, err := cmd.Flags().GetString("owner-id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return verifyFlags{}, err
|
||||||
}
|
}
|
||||||
clusterID, err := cmd.Flags().GetString("unique-id")
|
clusterID, err := cmd.Flags().GetString("unique-id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return verifyFlags{}, err
|
||||||
}
|
}
|
||||||
if ownerID == "" && clusterID == "" {
|
if ownerID == "" && clusterID == "" {
|
||||||
return errors.New("neither owner identity nor unique identity provided to verify the Constellation")
|
return verifyFlags{}, errors.New("neither owner ID nor unique ID provided to verify the Constellation")
|
||||||
}
|
}
|
||||||
|
|
||||||
return updatePCRMap(pcrs, ownerID, clusterID)
|
devConfigPath, err := cmd.Flags().GetString("dev-config")
|
||||||
}
|
|
||||||
|
|
||||||
func updatePCRMap(pcrs map[uint32][]byte, ownerID, clusterID string) error {
|
|
||||||
if err := addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexClusterID), clusterID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addOrSkipPCR adds a new entry to the map, or removes the key if the input is an empty string.
|
|
||||||
//
|
|
||||||
// When adding, the input is first decoded from base64.
|
|
||||||
// We then calculate the expected PCR by hashing the input using SHA256,
|
|
||||||
// appending expected PCR for initialization, and then hashing once more.
|
|
||||||
func addOrSkipPCR(toAdd map[uint32][]byte, pcrIndex uint32, encoded string) error {
|
|
||||||
if encoded == "" {
|
|
||||||
delete(toAdd, pcrIndex)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err)
|
return verifyFlags{}, err
|
||||||
}
|
|
||||||
// new_pcr_value := hash(old_pcr_value || data_to_extend)
|
|
||||||
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
|
|
||||||
hashedInput := sha256.Sum256(decoded)
|
|
||||||
expectedPcr := sha256.Sum256(append(toAdd[pcrIndex], hashedInput[:]...))
|
|
||||||
toAdd[pcrIndex] = expectedPcr[:]
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type verifier struct {
|
|
||||||
newConn func(context.Context, string, []atls.Validator) (status.ClientConn, error)
|
|
||||||
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// newVerifiedConn creates a grpc over aTLS connection to the target, using the provided PCR values to verify the server.
|
|
||||||
func newVerifiedConn(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
|
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return grpc.DialContext(
|
return verifyFlags{
|
||||||
ctx, target, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
devConfigPath: devConfigPath,
|
||||||
)
|
ownerID: ownerID,
|
||||||
|
clusterID: clusterID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type verifyFlags struct {
|
||||||
|
ownerID string
|
||||||
|
clusterID string
|
||||||
|
devConfigPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyCompletion handels the completion of CLI arguments. It is frequently called
|
// verifyCompletion handels the completion of CLI arguments. It is frequently called
|
||||||
// while the user types arguments of the command to suggest completion.
|
// while the user types arguments of the command to suggest completion.
|
||||||
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
switch len(args) {
|
switch len(args) {
|
||||||
case 0, 1:
|
case 0:
|
||||||
|
return []string{"gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
|
||||||
|
case 1, 2:
|
||||||
return []string{}, cobra.ShellCompDirectiveNoFileComp
|
return []string{}, cobra.ShellCompDirectiveNoFileComp
|
||||||
default:
|
default:
|
||||||
return []string{}, cobra.ShellCompDirectiveError
|
return []string{}, cobra.ShellCompDirectiveError
|
||||||
|
@ -1,51 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/edgelesssys/constellation/cli/file"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
|
||||||
"github.com/edgelesssys/constellation/internal/config"
|
|
||||||
"github.com/spf13/afero"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newVerifyAzureCmd() *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "azure IP PORT",
|
|
||||||
Short: "Verify the confidential properties of your Constellation on Azure.",
|
|
||||||
Long: "Verify the confidential properties of your Constellation on Azure.",
|
|
||||||
Args: cobra.ExactArgs(2),
|
|
||||||
ValidArgsFunction: verifyCompletion,
|
|
||||||
RunE: runVerifyAzure,
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func runVerifyAzure(cmd *cobra.Command, args []string) error {
|
|
||||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
|
||||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
config, err := config.FromFile(fileHandler, devConfigName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
validators, err := getAzureValidator(cmd, *config.Provider.GCP.PCRs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAzureValidator returns an Azure validator.
|
|
||||||
func getAzureValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
|
|
||||||
if err := prepareValidator(cmd, pcrs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return azure.NewValidator(pcrs), nil
|
|
||||||
}
|
|
@ -1,66 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetAzureValidator(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
|
||||||
ownerID string
|
|
||||||
clusterID string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"no input": {
|
|
||||||
ownerID: "",
|
|
||||||
clusterID: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded secret ID": {
|
|
||||||
ownerID: "owner-id",
|
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded cluster ID": {
|
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
|
||||||
clusterID: "unique-id",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"correct input": {
|
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
require := require.New(t)
|
|
||||||
|
|
||||||
cmd := newVerifyAzureCmd()
|
|
||||||
cmd.Flags().String("owner-id", "", "")
|
|
||||||
cmd.Flags().String("unique-id", "", "")
|
|
||||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
|
||||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
|
||||||
var out bytes.Buffer
|
|
||||||
cmd.SetOut(&out)
|
|
||||||
var errOut bytes.Buffer
|
|
||||||
cmd.SetErr(&errOut)
|
|
||||||
|
|
||||||
_, err := getAzureValidator(cmd, map[uint32][]byte{
|
|
||||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
|
||||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
|
||||||
})
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,51 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/edgelesssys/constellation/cli/file"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
|
||||||
"github.com/edgelesssys/constellation/internal/config"
|
|
||||||
"github.com/spf13/afero"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newVerifyGCPCmd() *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "gcp IP PORT",
|
|
||||||
Short: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
|
|
||||||
Long: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
|
|
||||||
Args: cobra.ExactArgs(2),
|
|
||||||
ValidArgsFunction: verifyCompletion,
|
|
||||||
RunE: runVerifyGCP,
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func runVerifyGCP(cmd *cobra.Command, args []string) error {
|
|
||||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
|
||||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
config, err := config.FromFile(fileHandler, devConfigName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
validators, err := getGCPValidator(cmd, *config.Provider.GCP.PCRs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getValidators returns a GCP validator.
|
|
||||||
func getGCPValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
|
|
||||||
if err := prepareValidator(cmd, pcrs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return gcp.NewValidator(pcrs), nil
|
|
||||||
}
|
|
@ -1,40 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: Remove this command once we no longer use non cvms.
|
|
||||||
func newVerifyGCPNonCVMCmd() *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "gcp-non-cvm IP PORT",
|
|
||||||
Short: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
|
|
||||||
Long: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
|
|
||||||
Args: cobra.ExactArgs(2),
|
|
||||||
ValidArgsFunction: verifyCompletion,
|
|
||||||
RunE: runVerifyGCPNonCVM,
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func runVerifyGCPNonCVM(cmd *cobra.Command, args []string) error {
|
|
||||||
pcrs := map[uint32][]byte{}
|
|
||||||
validator, err := getGCPNonCVMValidator(cmd, pcrs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return runVerify(cmd, args, pcrs, validator)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getGCPNonCVMValidator returns a GCP validator for regular shielded VMs.
|
|
||||||
func getGCPNonCVMValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
|
|
||||||
if err := prepareValidator(cmd, pcrs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return gcp.NewNonCVMValidator(pcrs), nil
|
|
||||||
}
|
|
@ -1,63 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetGCPNonCVMValidator(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
|
||||||
ownerID string
|
|
||||||
clusterID string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"no input": {
|
|
||||||
ownerID: "",
|
|
||||||
clusterID: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded secret ID": {
|
|
||||||
ownerID: "owner-id",
|
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded cluster ID": {
|
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
|
||||||
clusterID: "unique-id",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"correct input": {
|
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
require := require.New(t)
|
|
||||||
|
|
||||||
cmd := newVerifyGCPNonCVMCmd()
|
|
||||||
cmd.Flags().String("owner-id", "", "")
|
|
||||||
cmd.Flags().String("unique-id", "", "")
|
|
||||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
|
||||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
|
||||||
var out bytes.Buffer
|
|
||||||
cmd.SetOut(&out)
|
|
||||||
var errOut bytes.Buffer
|
|
||||||
cmd.SetErr(&errOut)
|
|
||||||
|
|
||||||
_, err := getGCPNonCVMValidator(cmd, map[uint32][]byte{})
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,66 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetGCPValidator(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
|
||||||
ownerID string
|
|
||||||
clusterID string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"no input": {
|
|
||||||
ownerID: "",
|
|
||||||
clusterID: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded secret ID": {
|
|
||||||
ownerID: "owner-id",
|
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded cluster ID": {
|
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
|
||||||
clusterID: "unique-id",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"correct input": {
|
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
require := require.New(t)
|
|
||||||
|
|
||||||
cmd := newVerifyGCPCmd()
|
|
||||||
cmd.Flags().String("owner-id", "", "")
|
|
||||||
cmd.Flags().String("unique-id", "", "")
|
|
||||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
|
||||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
|
||||||
var out bytes.Buffer
|
|
||||||
cmd.SetOut(&out)
|
|
||||||
var errOut bytes.Buffer
|
|
||||||
cmd.SetErr(&errOut)
|
|
||||||
|
|
||||||
_, err := getGCPValidator(cmd, map[uint32][]byte{
|
|
||||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
|
||||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
|
||||||
})
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -7,152 +7,105 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/cli/status"
|
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/cli/file"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
"github.com/spf13/afero"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
rpcStatus "google.golang.org/grpc/status"
|
rpcStatus "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVerify(t *testing.T) {
|
func TestVerifyCmdArgumentValidation(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
connErr error
|
args []string
|
||||||
checkErr error
|
|
||||||
state state.State
|
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
"connection error": {
|
"no args": {[]string{}, true},
|
||||||
connErr: errors.New("connection error"),
|
"valid azure": {[]string{"azure", "192.0.2.1", "1234"}, false},
|
||||||
checkErr: nil,
|
"valid gcp": {[]string{"gcp", "192.0.2.1", "1234"}, false},
|
||||||
state: 0,
|
"invalid provider": {[]string{"invalid", "192.0.2.1", "1234"}, true},
|
||||||
wantErr: true,
|
"invalid ip": {[]string{"gcp", "invalid", "1234"}, true},
|
||||||
},
|
"invalid port": {[]string{"gcp", "192.0.2.1", "invalid"}, true},
|
||||||
"check error": {
|
"invalid port 2": {[]string{"gcp", "192.0.2.1", "65536"}, true},
|
||||||
connErr: nil,
|
"not enough arguments": {[]string{"gcp", "192.0.2.1"}, true},
|
||||||
checkErr: errors.New("check error"),
|
"too many arguments": {[]string{"gcp", "192.0.2.1", "1234", "5678"}, true},
|
||||||
state: 0,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"check error, rpc status": {
|
|
||||||
connErr: nil,
|
|
||||||
checkErr: rpcStatus.Error(codes.Unavailable, "check error"),
|
|
||||||
state: 0,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"verify on worker node": {
|
|
||||||
connErr: nil,
|
|
||||||
checkErr: nil,
|
|
||||||
state: state.IsNode,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
"verify on master node": {
|
|
||||||
connErr: nil,
|
|
||||||
checkErr: nil,
|
|
||||||
state: state.ActivatingNodes,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, tc := range testCases {
|
for name, tc := range testCases {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
|
|
||||||
ctx := context.Background()
|
cmd := newVerifyCmd()
|
||||||
var out bytes.Buffer
|
err := cmd.ValidateArgs(tc.args)
|
||||||
|
|
||||||
verifier := verifier{
|
|
||||||
newConn: stubNewConnFunc(tc.connErr),
|
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{
|
|
||||||
state: tc.state,
|
|
||||||
checkErr: tc.checkErr,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
pcrs := map[uint32][]byte{
|
|
||||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
|
||||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
|
||||||
}
|
|
||||||
err := verify(ctx, &out, "", []atls.Validator{gcp.NewValidator(pcrs)}, verifier)
|
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
assert.Contains(out.String(), "OK")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func stubNewConnFunc(errStub error) func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
|
func TestVerify(t *testing.T) {
|
||||||
return func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
|
zeroBase64 := base64.StdEncoding.EncodeToString([]byte("00000000000000000000000000000000"))
|
||||||
return &stubClientConn{}, errStub
|
someErr := errors.New("failed")
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubClientConn struct{}
|
|
||||||
|
|
||||||
func (c *stubClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubClientConn) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func stubNewClientFunc(stubClient pubproto.APIClient) func(cc grpc.ClientConnInterface) pubproto.APIClient {
|
|
||||||
return func(cc grpc.ClientConnInterface) pubproto.APIClient {
|
|
||||||
return stubClient
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubPeerStatusClient struct {
|
|
||||||
state state.State
|
|
||||||
checkErr error
|
|
||||||
pubproto.APIClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetStateRequest, opts ...grpc.CallOption) (*pubproto.GetStateResponse, error) {
|
|
||||||
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
|
|
||||||
return resp, c.checkErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPrepareValidator(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
ownerID string
|
setupFs func(*require.Assertions) afero.Fs
|
||||||
clusterID string
|
provider cloudprovider.Provider
|
||||||
|
protoClient protoClient
|
||||||
|
devConfigFlag string
|
||||||
|
ownerIDFlag string
|
||||||
|
clusterIDFlag string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
"no input": {
|
"gcp": {
|
||||||
ownerID: "",
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
clusterID: "",
|
provider: cloudprovider.GCP,
|
||||||
|
ownerIDFlag: zeroBase64,
|
||||||
|
protoClient: &stubProtoClient{},
|
||||||
|
},
|
||||||
|
"azure": {
|
||||||
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
ownerIDFlag: zeroBase64,
|
||||||
|
protoClient: &stubProtoClient{},
|
||||||
|
},
|
||||||
|
"neither owner id nor cluster id set": {
|
||||||
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
|
provider: cloudprovider.GCP,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"unencoded secret ID": {
|
"dev config file not existing": {
|
||||||
ownerID: "owner-id",
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
provider: cloudprovider.GCP,
|
||||||
|
ownerIDFlag: zeroBase64,
|
||||||
|
devConfigFlag: "./file",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"unencoded cluster ID": {
|
"error protoClient Connect": {
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
clusterID: "unique-id",
|
provider: cloudprovider.Azure,
|
||||||
|
ownerIDFlag: zeroBase64,
|
||||||
|
protoClient: &stubProtoClient{connectErr: someErr},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"correct input": {
|
"error protoClient GetState": {
|
||||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
provider: cloudprovider.Azure,
|
||||||
wantErr: false,
|
ownerIDFlag: zeroBase64,
|
||||||
|
protoClient: &stubProtoClient{getStateErr: rpcStatus.Error(codes.Internal, "failed")},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"error protoClient GetState not rpc": {
|
||||||
|
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||||
|
provider: cloudprovider.Azure,
|
||||||
|
ownerIDFlag: zeroBase64,
|
||||||
|
protoClient: &stubProtoClient{getStateErr: someErr},
|
||||||
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,124 +115,29 @@ func TestPrepareValidator(t *testing.T) {
|
|||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
cmd := newVerifyCmd()
|
cmd := newVerifyCmd()
|
||||||
cmd.Flags().String("owner-id", "", "")
|
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
|
||||||
cmd.Flags().String("unique-id", "", "")
|
out := bytes.NewBufferString("")
|
||||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
cmd.SetOut(out)
|
||||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
cmd.SetErr(bytes.NewBufferString(""))
|
||||||
var out bytes.Buffer
|
if tc.devConfigFlag != "" {
|
||||||
cmd.SetOut(&out)
|
require.NoError(cmd.Flags().Set("dev-config", tc.devConfigFlag))
|
||||||
var errOut bytes.Buffer
|
}
|
||||||
cmd.SetErr(&errOut)
|
if tc.ownerIDFlag != "" {
|
||||||
|
require.NoError(cmd.Flags().Set("owner-id", tc.ownerIDFlag))
|
||||||
|
}
|
||||||
|
if tc.clusterIDFlag != "" {
|
||||||
|
require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag))
|
||||||
|
}
|
||||||
|
fileHandler := file.NewHandler(tc.setupFs(require))
|
||||||
|
|
||||||
pcrs := map[uint32][]byte{
|
ctx := context.Background()
|
||||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
err := verify(ctx, cmd, tc.provider, "192.0.2.1", "1234", fileHandler, tc.protoClient)
|
||||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := prepareValidator(cmd, pcrs)
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
if tc.clusterID != "" {
|
|
||||||
assert.Len(pcrs[uint32(vtpm.PCRIndexClusterID)], 32)
|
|
||||||
} else {
|
|
||||||
assert.Nil(pcrs[uint32(vtpm.PCRIndexClusterID)])
|
|
||||||
}
|
|
||||||
if tc.ownerID != "" {
|
|
||||||
assert.Len(pcrs[uint32(vtpm.PCRIndexOwnerID)], 32)
|
|
||||||
} else {
|
|
||||||
assert.Nil(pcrs[uint32(vtpm.PCRIndexOwnerID)])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddOrSkipPcr(t *testing.T) {
|
|
||||||
emptyMap := map[uint32][]byte{}
|
|
||||||
defaultMap := map[uint32][]byte{
|
|
||||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
|
||||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
pcrMap map[uint32][]byte
|
|
||||||
pcrIndex uint32
|
|
||||||
encoded string
|
|
||||||
wantEntries int
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"empty input, empty map": {
|
|
||||||
pcrMap: emptyMap,
|
|
||||||
pcrIndex: 10,
|
|
||||||
encoded: "",
|
|
||||||
wantEntries: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
"empty input, default map": {
|
|
||||||
pcrMap: defaultMap,
|
|
||||||
pcrIndex: 10,
|
|
||||||
encoded: "",
|
|
||||||
wantEntries: len(defaultMap),
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
"correct input, empty map": {
|
|
||||||
pcrMap: emptyMap,
|
|
||||||
pcrIndex: 10,
|
|
||||||
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
|
|
||||||
wantEntries: 1,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
"correct input, default map": {
|
|
||||||
pcrMap: defaultMap,
|
|
||||||
pcrIndex: 10,
|
|
||||||
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
|
|
||||||
wantEntries: len(defaultMap) + 1,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
"unencoded input, empty map": {
|
|
||||||
pcrMap: emptyMap,
|
|
||||||
pcrIndex: 10,
|
|
||||||
encoded: "Constellation",
|
|
||||||
wantEntries: 0,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"unencoded input, default map": {
|
|
||||||
pcrMap: defaultMap,
|
|
||||||
pcrIndex: 10,
|
|
||||||
encoded: "Constellation",
|
|
||||||
wantEntries: len(defaultMap),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"empty input at occupied index": {
|
|
||||||
pcrMap: defaultMap,
|
|
||||||
pcrIndex: 0,
|
|
||||||
encoded: "",
|
|
||||||
wantEntries: len(defaultMap) - 1,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
res := make(map[uint32][]byte)
|
|
||||||
for k, v := range tc.pcrMap {
|
|
||||||
res[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
err := addOrSkipPCR(res, tc.pcrIndex, tc.encoded)
|
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
}
|
assert.Contains(out.String(), "OK")
|
||||||
assert.Len(res, tc.wantEntries)
|
|
||||||
for _, v := range res {
|
|
||||||
assert.Len(v, 32)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -294,18 +152,24 @@ func TestVerifyCompletion(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
"first arg": {
|
"first arg": {
|
||||||
args: []string{},
|
args: []string{},
|
||||||
|
toComplete: "az",
|
||||||
|
wantResult: []string{"gcp", "azure"},
|
||||||
|
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||||
|
},
|
||||||
|
"second arg": {
|
||||||
|
args: []string{"gcp"},
|
||||||
toComplete: "192.0.2.1",
|
toComplete: "192.0.2.1",
|
||||||
wantResult: []string{},
|
wantResult: []string{},
|
||||||
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||||
},
|
},
|
||||||
"second arg": {
|
"third arg": {
|
||||||
args: []string{"192.0.2.1"},
|
args: []string{"gcp", "192.0.2.1"},
|
||||||
toComplete: "443",
|
toComplete: "443",
|
||||||
wantResult: []string{},
|
wantResult: []string{},
|
||||||
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||||
},
|
},
|
||||||
"third arg": {
|
"additional arg": {
|
||||||
args: []string{"192.0.2.1", "443"},
|
args: []string{"gcp", "192.0.2.1", "443"},
|
||||||
toComplete: "./file",
|
toComplete: "./file",
|
||||||
wantResult: []string{},
|
wantResult: []string{},
|
||||||
wantShellCD: cobra.ShellCompDirectiveError,
|
wantShellCD: cobra.ShellCompDirectiveError,
|
||||||
|
@ -15,6 +15,8 @@ import (
|
|||||||
grpcstatus "google.golang.org/grpc/status"
|
grpcstatus "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(katexochen): Use protoClient for waiter?
|
||||||
|
|
||||||
// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs
|
// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs
|
||||||
// to be initialized before usage.
|
// to be initialized before usage.
|
||||||
type Waiter struct {
|
type Waiter struct {
|
||||||
|
Loading…
Reference in New Issue
Block a user