Allow waiting for multiple states (#11)

* Simplify `fetch_pcrs.sh` script

* Allow waiting for multiple states

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-03-29 09:10:22 +02:00 committed by GitHub
parent 9df71da33f
commit eb3411f2c1
7 changed files with 87 additions and 31 deletions

View file

@ -107,7 +107,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpn
} }
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort) endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil { if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
return fmt.Errorf("failed to wait for peer status: %w", err) return fmt.Errorf("failed to wait for peer status: %w", err)
} }

View file

@ -7,5 +7,5 @@ import (
) )
type statusWaiter interface { type statusWaiter interface {
WaitForAll(ctx context.Context, status state.State, endpoints []string) error WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
} }

View file

@ -10,6 +10,6 @@ type stubStatusWaiter struct {
waitForAllErr error waitForAllErr error
} }
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error { func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
return w.waitForAllErr return w.waitForAllErr
} }

View file

@ -36,7 +36,7 @@ func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint // WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
// to reach the specified state. // to reach the specified state.
func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string) error { func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
ticker := time.NewTicker(w.interval) ticker := time.NewTicker(w.interval)
defer ticker.Stop() defer ticker.Stop()
@ -45,7 +45,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable { if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable {
return err return err
} }
if resp != nil && resp.State == uint32(status) { if resp != nil && containsState(state.State(resp.State), status...) {
return nil return nil
} }
@ -61,7 +61,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
if err != nil { if err != nil {
return err return err
} }
if resp.State == uint32(status) { if containsState(state.State(resp.State), status...) {
return nil return nil
} }
case <-ctx.Done(): case <-ctx.Done():
@ -84,9 +84,9 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed // WaitForAll waits for a list of PeerStatusServers, which listen on the handed
// endpoints, to reach the specified state. // endpoints, to reach the specified state.
func (w Waiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error { func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
if err := w.WaitFor(ctx, status, endpoint); err != nil { if err := w.WaitFor(ctx, endpoint, status...); err != nil {
return err return err
} }
} }
@ -119,3 +119,13 @@ type ClientConn interface {
grpc.ClientConnInterface grpc.ClientConnInterface
io.Closer io.Closer
} }
// containsState checks if current state is one of the given states.
func containsState(s state.State, states ...state.State) bool {
for _, state := range states {
if state == s {
return true
}
}
return false
}

View file

@ -18,7 +18,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
waiter Waiter waiter Waiter
waitForState state.State waitForState []state.State
wantErr bool wantErr bool
}{ }{
"successful wait": { "successful wait": {
@ -27,7 +27,15 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(noErr), newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
}, },
waitForState: state.IsNode, waitForState: []state.State{state.IsNode},
},
"successful wait multi states": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
},
waitForState: []state.State{state.IsNode, state.ActivatingNodes},
}, },
"expect timeout": { "expect timeout": {
waiter: Waiter{ waiter: Waiter{
@ -35,7 +43,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(noErr), newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}), newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
}, },
waitForState: state.IsNode, waitForState: []state.State{state.IsNode},
wantErr: true, wantErr: true,
}, },
"fail to check call": { "fail to check call": {
@ -44,7 +52,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(noErr), newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}), newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
}, },
waitForState: state.IsNode, waitForState: []state.State{state.IsNode},
wantErr: true, wantErr: true,
}, },
"fail to create conn": { "fail to create conn": {
@ -53,7 +61,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(someErr), newConn: stubNewConnFunc(someErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{}), newClient: stubNewClientFunc(&stubPeerStatusClient{}),
}, },
waitForState: state.IsNode, waitForState: []state.State{state.IsNode},
wantErr: true, wantErr: true,
}, },
} }
@ -67,7 +75,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
defer cancel() defer cancel()
err := tc.waiter.WaitFor(ctx, tc.waitForState, "someIP") err := tc.waiter.WaitFor(ctx, "someIP", tc.waitForState...)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -88,7 +96,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
defer cancel() defer cancel()
endpoints := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"} endpoints := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
err := tc.waiter.WaitForAll(ctx, tc.waitForState, endpoints) err := tc.waiter.WaitForAll(ctx, endpoints, tc.waitForState...)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -136,3 +144,49 @@ func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetSta
resp := &pubproto.GetStateResponse{State: uint32(c.state)} resp := &pubproto.GetStateResponse{State: uint32(c.state)}
return resp, c.checkErr return resp, c.checkErr
} }
func TestContainsState(t *testing.T) {
testCases := map[string]struct {
s state.State
states []state.State
success bool
}{
"is state": {
s: state.IsNode,
states: []state.State{
state.IsNode,
},
success: true,
},
"is state multi": {
s: state.AcceptingInit,
states: []state.State{
state.AcceptingInit,
state.ActivatingNodes,
},
success: true,
},
"is not state": {
s: state.NodeWaitingForClusterJoin,
states: []state.State{
state.AcceptingInit,
},
},
"is not state multi": {
s: state.NodeWaitingForClusterJoin,
states: []state.State{
state.AcceptingInit,
state.ActivatingNodes,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
res := containsState(tc.s, tc.states...)
assert.Equal(tc.success, res)
})
}
}

View file

@ -6,21 +6,13 @@ trap 'terminate $?' ERR
terminate() { terminate() {
echo "error: $1" echo "error: $1"
constellation terminate constellation terminate
popd || exit 1
exit 1 exit 1
} }
main() { main() {
if ! command -v constellation &> /dev/null command -v constellation > /dev/null
then command -v go > /dev/null
echo "constellation is not in path" command -v jq > /dev/null
exit 1
fi
if ! command -v go &> /dev/null
then
echo "go is not in path"
exit 1
fi
mkdir -p ./pcrs mkdir -p ./pcrs

View file

@ -19,7 +19,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm" "github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/oid" "github.com/edgelesssys/constellation/coordinator/oid"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/spf13/afero" "github.com/spf13/afero"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@ -42,21 +42,21 @@ func main() {
// wait for coordinator to come online // wait for coordinator to come online
waiter := status.NewWaiter(map[uint32][]byte{}) waiter := status.NewWaiter(map[uint32][]byte{})
if err := waiter.WaitFor(ctx, coordinatorstate.AcceptingInit, addr); err != nil { if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
log.Fatal(err) log.Fatal(err)
} }
attDocRaw := &[]byte{} attDocRaw := []byte{}
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig() tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(attDocRaw) tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(&attDocRaw)
if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil { if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil {
log.Fatal(err) log.Fatal(err)
} }
pcrs, err := validatePCRAttDoc(*attDocRaw) pcrs, err := validatePCRAttDoc(attDocRaw)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }