cli: refactor flag parsing code (#2425)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-10-16 15:05:29 +02:00 committed by GitHub
parent adfe443b28
commit c52086c5ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 1490 additions and 1726 deletions

View file

@ -58,6 +58,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
@ -66,6 +67,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{},
},
@ -74,6 +76,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC),
formatter: &stubAttDocFormatter{},
},
@ -81,6 +84,7 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -106,12 +110,14 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: ":::::",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
"neither owner id nor cluster id set": {
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -127,6 +133,7 @@ func TestVerify(t *testing.T) {
provider: cloudprovider.GCP,
clusterIDFlag: zeroBase64,
nodeEndpointFlag: "192.0.2.1:1234",
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
skipConfigCreation: true,
wantErr: true,
@ -136,6 +143,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -144,6 +152,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{verifyErr: someErr},
stateFile: state.New(),
formatter: &stubAttDocFormatter{},
wantErr: true,
},
@ -152,6 +161,7 @@ func TestVerify(t *testing.T) {
nodeEndpointFlag: "192.0.2.1:1234",
clusterIDFlag: zeroBase64,
protoClient: &stubVerifyClient{},
stateFile: state.New(),
wantEndpoint: "192.0.2.1:1234",
formatter: &stubAttDocFormatter{formatErr: someErr},
wantErr: true,
@ -164,31 +174,28 @@ func TestVerify(t *testing.T) {
require := require.New(t)
cmd := NewVerifyCmd()
cmd.Flags().String("workspace", "", "") // register persistent flag manually
cmd.Flags().Bool("force", true, "") // register persistent flag manually
out := &bytes.Buffer{}
cmd.SetErr(out)
if tc.clusterIDFlag != "" {
require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag))
}
if tc.nodeEndpointFlag != "" {
require.NoError(cmd.Flags().Set("node-endpoint", tc.nodeEndpointFlag))
}
fileHandler := file.NewHandler(afero.NewMemMapFs())
if !tc.skipConfigCreation {
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))
}
if tc.stateFile != nil {
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
}
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
v := &verifyCmd{log: logger.NewTest(t)}
v := &verifyCmd{
fileHandler: fileHandler,
log: logger.NewTest(t),
flags: verifyFlags{
clusterID: tc.clusterIDFlag,
endpoint: tc.nodeEndpointFlag,
},
}
formatterFac := func(_ string, _ cloudprovider.Provider, _ debugLog) (attestationDocFormatter, error) {
return tc.formatter, nil
}
err := v.verify(cmd, fileHandler, tc.protoClient, formatterFac, stubAttestationFetcher{})
err := v.verify(cmd, tc.protoClient, formatterFac, stubAttestationFetcher{})
if tc.wantErr {
assert.Error(err)
} else {