mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
Allow waiting for multiple states (#11)
* Simplify `fetch_pcrs.sh` script * Allow waiting for multiple states Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
9df71da33f
commit
eb3411f2c1
@ -107,7 +107,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpn
|
||||
}
|
||||
|
||||
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
|
||||
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil {
|
||||
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
|
||||
return fmt.Errorf("failed to wait for peer status: %w", err)
|
||||
}
|
||||
|
||||
|
@ -7,5 +7,5 @@ import (
|
||||
)
|
||||
|
||||
type statusWaiter interface {
|
||||
WaitForAll(ctx context.Context, status state.State, endpoints []string) error
|
||||
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
||||
}
|
||||
|
@ -10,6 +10,6 @@ type stubStatusWaiter struct {
|
||||
waitForAllErr error
|
||||
}
|
||||
|
||||
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
|
||||
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
return w.waitForAllErr
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
|
||||
|
||||
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
|
||||
// to reach the specified state.
|
||||
func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string) error {
|
||||
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@ -45,7 +45,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
|
||||
if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable {
|
||||
return err
|
||||
}
|
||||
if resp != nil && resp.State == uint32(status) {
|
||||
if resp != nil && containsState(state.State(resp.State), status...) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.State == uint32(status) {
|
||||
if containsState(state.State(resp.State), status...) {
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@ -84,9 +84,9 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
|
||||
|
||||
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
|
||||
// endpoints, to reach the specified state.
|
||||
func (w Waiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
|
||||
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
for _, endpoint := range endpoints {
|
||||
if err := w.WaitFor(ctx, status, endpoint); err != nil {
|
||||
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -119,3 +119,13 @@ type ClientConn interface {
|
||||
grpc.ClientConnInterface
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// containsState checks if current state is one of the given states.
|
||||
func containsState(s state.State, states ...state.State) bool {
|
||||
for _, state := range states {
|
||||
if state == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
waiter Waiter
|
||||
waitForState state.State
|
||||
waitForState []state.State
|
||||
wantErr bool
|
||||
}{
|
||||
"successful wait": {
|
||||
@ -27,7 +27,15 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||
},
|
||||
waitForState: state.IsNode,
|
||||
waitForState: []state.State{state.IsNode},
|
||||
},
|
||||
"successful wait multi states": {
|
||||
waiter: Waiter{
|
||||
interval: time.Millisecond,
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||
},
|
||||
waitForState: []state.State{state.IsNode, state.ActivatingNodes},
|
||||
},
|
||||
"expect timeout": {
|
||||
waiter: Waiter{
|
||||
@ -35,7 +43,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
|
||||
},
|
||||
waitForState: state.IsNode,
|
||||
waitForState: []state.State{state.IsNode},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail to check call": {
|
||||
@ -44,7 +52,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
|
||||
},
|
||||
waitForState: state.IsNode,
|
||||
waitForState: []state.State{state.IsNode},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail to create conn": {
|
||||
@ -53,7 +61,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
newConn: stubNewConnFunc(someErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
|
||||
},
|
||||
waitForState: state.IsNode,
|
||||
waitForState: []state.State{state.IsNode},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
@ -67,7 +75,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := tc.waiter.WaitFor(ctx, tc.waitForState, "someIP")
|
||||
err := tc.waiter.WaitFor(ctx, "someIP", tc.waitForState...)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
@ -88,7 +96,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
endpoints := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
|
||||
err := tc.waiter.WaitForAll(ctx, tc.waitForState, endpoints)
|
||||
err := tc.waiter.WaitForAll(ctx, endpoints, tc.waitForState...)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
@ -136,3 +144,49 @@ func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetSta
|
||||
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
|
||||
return resp, c.checkErr
|
||||
}
|
||||
|
||||
func TestContainsState(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
s state.State
|
||||
states []state.State
|
||||
success bool
|
||||
}{
|
||||
"is state": {
|
||||
s: state.IsNode,
|
||||
states: []state.State{
|
||||
state.IsNode,
|
||||
},
|
||||
success: true,
|
||||
},
|
||||
"is state multi": {
|
||||
s: state.AcceptingInit,
|
||||
states: []state.State{
|
||||
state.AcceptingInit,
|
||||
state.ActivatingNodes,
|
||||
},
|
||||
success: true,
|
||||
},
|
||||
"is not state": {
|
||||
s: state.NodeWaitingForClusterJoin,
|
||||
states: []state.State{
|
||||
state.AcceptingInit,
|
||||
},
|
||||
},
|
||||
"is not state multi": {
|
||||
s: state.NodeWaitingForClusterJoin,
|
||||
states: []state.State{
|
||||
state.AcceptingInit,
|
||||
state.ActivatingNodes,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
res := containsState(tc.s, tc.states...)
|
||||
assert.Equal(tc.success, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,21 +6,13 @@ trap 'terminate $?' ERR
|
||||
terminate() {
|
||||
echo "error: $1"
|
||||
constellation terminate
|
||||
popd || exit 1
|
||||
exit 1
|
||||
}
|
||||
|
||||
main() {
|
||||
if ! command -v constellation &> /dev/null
|
||||
then
|
||||
echo "constellation is not in path"
|
||||
exit 1
|
||||
fi
|
||||
if ! command -v go &> /dev/null
|
||||
then
|
||||
echo "go is not in path"
|
||||
exit 1
|
||||
fi
|
||||
command -v constellation > /dev/null
|
||||
command -v go > /dev/null
|
||||
command -v jq > /dev/null
|
||||
|
||||
mkdir -p ./pcrs
|
||||
|
||||
|
@ -19,7 +19,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/oid"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/spf13/afero"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@ -42,21 +42,21 @@ func main() {
|
||||
|
||||
// wait for coordinator to come online
|
||||
waiter := status.NewWaiter(map[uint32][]byte{})
|
||||
if err := waiter.WaitFor(ctx, coordinatorstate.AcceptingInit, addr); err != nil {
|
||||
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
attDocRaw := &[]byte{}
|
||||
attDocRaw := []byte{}
|
||||
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(attDocRaw)
|
||||
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(&attDocRaw)
|
||||
if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
pcrs, err := validatePCRAttDoc(*attDocRaw)
|
||||
pcrs, err := validatePCRAttDoc(attDocRaw)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user