Allow waiting for multiple states (#11)

* Simplify `fetch_pcrs.sh` script

* Allow waiting for multiple states

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-03-29 09:10:22 +02:00 committed by GitHub
parent 9df71da33f
commit eb3411f2c1
7 changed files with 87 additions and 31 deletions

View File

@ -107,7 +107,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpn
}
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil {
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
return fmt.Errorf("failed to wait for peer status: %w", err)
}

View File

@ -7,5 +7,5 @@ import (
)
type statusWaiter interface {
WaitForAll(ctx context.Context, status state.State, endpoints []string) error
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
}

View File

@ -10,6 +10,6 @@ type stubStatusWaiter struct {
waitForAllErr error
}
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
return w.waitForAllErr
}

View File

@ -36,7 +36,7 @@ func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
// to reach the specified state.
func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string) error {
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
@ -45,7 +45,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable {
return err
}
if resp != nil && resp.State == uint32(status) {
if resp != nil && containsState(state.State(resp.State), status...) {
return nil
}
@ -61,7 +61,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
if err != nil {
return err
}
if resp.State == uint32(status) {
if containsState(state.State(resp.State), status...) {
return nil
}
case <-ctx.Done():
@ -84,9 +84,9 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
// endpoints, to reach the specified state.
func (w Waiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
for _, endpoint := range endpoints {
if err := w.WaitFor(ctx, status, endpoint); err != nil {
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
return err
}
}
@ -119,3 +119,13 @@ type ClientConn interface {
grpc.ClientConnInterface
io.Closer
}
// containsState checks if current state is one of the given states.
func containsState(s state.State, states ...state.State) bool {
for _, state := range states {
if state == s {
return true
}
}
return false
}

View File

@ -18,7 +18,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
testCases := map[string]struct {
waiter Waiter
waitForState state.State
waitForState []state.State
wantErr bool
}{
"successful wait": {
@ -27,7 +27,15 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
},
waitForState: state.IsNode,
waitForState: []state.State{state.IsNode},
},
"successful wait multi states": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
},
waitForState: []state.State{state.IsNode, state.ActivatingNodes},
},
"expect timeout": {
waiter: Waiter{
@ -35,7 +43,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
},
waitForState: state.IsNode,
waitForState: []state.State{state.IsNode},
wantErr: true,
},
"fail to check call": {
@ -44,7 +52,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
},
waitForState: state.IsNode,
waitForState: []state.State{state.IsNode},
wantErr: true,
},
"fail to create conn": {
@ -53,7 +61,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
newConn: stubNewConnFunc(someErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
},
waitForState: state.IsNode,
waitForState: []state.State{state.IsNode},
wantErr: true,
},
}
@ -67,7 +75,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
defer cancel()
err := tc.waiter.WaitFor(ctx, tc.waitForState, "someIP")
err := tc.waiter.WaitFor(ctx, "someIP", tc.waitForState...)
if tc.wantErr {
assert.Error(err)
@ -88,7 +96,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
defer cancel()
endpoints := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
err := tc.waiter.WaitForAll(ctx, tc.waitForState, endpoints)
err := tc.waiter.WaitForAll(ctx, endpoints, tc.waitForState...)
if tc.wantErr {
assert.Error(err)
@ -136,3 +144,49 @@ func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetSta
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
return resp, c.checkErr
}
func TestContainsState(t *testing.T) {
testCases := map[string]struct {
s state.State
states []state.State
success bool
}{
"is state": {
s: state.IsNode,
states: []state.State{
state.IsNode,
},
success: true,
},
"is state multi": {
s: state.AcceptingInit,
states: []state.State{
state.AcceptingInit,
state.ActivatingNodes,
},
success: true,
},
"is not state": {
s: state.NodeWaitingForClusterJoin,
states: []state.State{
state.AcceptingInit,
},
},
"is not state multi": {
s: state.NodeWaitingForClusterJoin,
states: []state.State{
state.AcceptingInit,
state.ActivatingNodes,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
res := containsState(tc.s, tc.states...)
assert.Equal(tc.success, res)
})
}
}

View File

@ -6,21 +6,13 @@ trap 'terminate $?' ERR
terminate() {
echo "error: $1"
constellation terminate
popd || exit 1
exit 1
}
main() {
if ! command -v constellation &> /dev/null
then
echo "constellation is not in path"
exit 1
fi
if ! command -v go &> /dev/null
then
echo "go is not in path"
exit 1
fi
command -v constellation > /dev/null
command -v go > /dev/null
command -v jq > /dev/null
mkdir -p ./pcrs

View File

@ -19,7 +19,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/oid"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/spf13/afero"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@ -42,21 +42,21 @@ func main() {
// wait for coordinator to come online
waiter := status.NewWaiter(map[uint32][]byte{})
if err := waiter.WaitFor(ctx, coordinatorstate.AcceptingInit, addr); err != nil {
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
log.Fatal(err)
}
attDocRaw := &[]byte{}
attDocRaw := []byte{}
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
if err != nil {
log.Fatal(err)
}
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(attDocRaw)
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(&attDocRaw)
if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil {
log.Fatal(err)
}
pcrs, err := validatePCRAttDoc(*attDocRaw)
pcrs, err := validatePCRAttDoc(attDocRaw)
if err != nil {
log.Fatal(err)
}