mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-08-01 03:26:08 -04:00
Allow waiting for multiple states (#11)
* Simplify `fetch_pcrs.sh` script * Allow waiting for multiple states Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
9df71da33f
commit
eb3411f2c1
7 changed files with 87 additions and 31 deletions
|
@ -107,7 +107,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpn
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
|
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
|
||||||
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil {
|
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
|
||||||
return fmt.Errorf("failed to wait for peer status: %w", err)
|
return fmt.Errorf("failed to wait for peer status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,5 +7,5 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type statusWaiter interface {
|
type statusWaiter interface {
|
||||||
WaitForAll(ctx context.Context, status state.State, endpoints []string) error
|
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,6 @@ type stubStatusWaiter struct {
|
||||||
waitForAllErr error
|
waitForAllErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
|
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||||
return w.waitForAllErr
|
return w.waitForAllErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
|
||||||
|
|
||||||
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
|
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
|
||||||
// to reach the specified state.
|
// to reach the specified state.
|
||||||
func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string) error {
|
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
|
||||||
ticker := time.NewTicker(w.interval)
|
ticker := time.NewTicker(w.interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
|
||||||
if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable {
|
if err != nil && grpcstatus.Code(err) != grpccodes.Unavailable {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp != nil && resp.State == uint32(status) {
|
if resp != nil && containsState(state.State(resp.State), status...) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ func (w Waiter) WaitFor(ctx context.Context, status state.State, endpoint string
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp.State == uint32(status) {
|
if containsState(state.State(resp.State), status...) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -84,9 +84,9 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
|
||||||
|
|
||||||
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
|
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
|
||||||
// endpoints, to reach the specified state.
|
// endpoints, to reach the specified state.
|
||||||
func (w Waiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
|
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||||
for _, endpoint := range endpoints {
|
for _, endpoint := range endpoints {
|
||||||
if err := w.WaitFor(ctx, status, endpoint); err != nil {
|
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -119,3 +119,13 @@ type ClientConn interface {
|
||||||
grpc.ClientConnInterface
|
grpc.ClientConnInterface
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// containsState checks if current state is one of the given states.
|
||||||
|
func containsState(s state.State, states ...state.State) bool {
|
||||||
|
for _, state := range states {
|
||||||
|
if state == s {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
waiter Waiter
|
waiter Waiter
|
||||||
waitForState state.State
|
waitForState []state.State
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
"successful wait": {
|
"successful wait": {
|
||||||
|
@ -27,7 +27,15 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||||
},
|
},
|
||||||
waitForState: state.IsNode,
|
waitForState: []state.State{state.IsNode},
|
||||||
|
},
|
||||||
|
"successful wait multi states": {
|
||||||
|
waiter: Waiter{
|
||||||
|
interval: time.Millisecond,
|
||||||
|
newConn: stubNewConnFunc(noErr),
|
||||||
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||||
|
},
|
||||||
|
waitForState: []state.State{state.IsNode, state.ActivatingNodes},
|
||||||
},
|
},
|
||||||
"expect timeout": {
|
"expect timeout": {
|
||||||
waiter: Waiter{
|
waiter: Waiter{
|
||||||
|
@ -35,7 +43,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
|
||||||
},
|
},
|
||||||
waitForState: state.IsNode,
|
waitForState: []state.State{state.IsNode},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"fail to check call": {
|
"fail to check call": {
|
||||||
|
@ -44,7 +52,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
|
||||||
},
|
},
|
||||||
waitForState: state.IsNode,
|
waitForState: []state.State{state.IsNode},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"fail to create conn": {
|
"fail to create conn": {
|
||||||
|
@ -53,7 +61,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
newConn: stubNewConnFunc(someErr),
|
newConn: stubNewConnFunc(someErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
|
||||||
},
|
},
|
||||||
waitForState: state.IsNode,
|
waitForState: []state.State{state.IsNode},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -67,7 +75,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
|
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := tc.waiter.WaitFor(ctx, tc.waitForState, "someIP")
|
err := tc.waiter.WaitFor(ctx, "someIP", tc.waitForState...)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
|
@ -88,7 +96,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
endpoints := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
|
endpoints := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
|
||||||
err := tc.waiter.WaitForAll(ctx, tc.waitForState, endpoints)
|
err := tc.waiter.WaitForAll(ctx, endpoints, tc.waitForState...)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
|
@ -136,3 +144,49 @@ func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetSta
|
||||||
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
|
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
|
||||||
return resp, c.checkErr
|
return resp, c.checkErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContainsState(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
s state.State
|
||||||
|
states []state.State
|
||||||
|
success bool
|
||||||
|
}{
|
||||||
|
"is state": {
|
||||||
|
s: state.IsNode,
|
||||||
|
states: []state.State{
|
||||||
|
state.IsNode,
|
||||||
|
},
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
"is state multi": {
|
||||||
|
s: state.AcceptingInit,
|
||||||
|
states: []state.State{
|
||||||
|
state.AcceptingInit,
|
||||||
|
state.ActivatingNodes,
|
||||||
|
},
|
||||||
|
success: true,
|
||||||
|
},
|
||||||
|
"is not state": {
|
||||||
|
s: state.NodeWaitingForClusterJoin,
|
||||||
|
states: []state.State{
|
||||||
|
state.AcceptingInit,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"is not state multi": {
|
||||||
|
s: state.NodeWaitingForClusterJoin,
|
||||||
|
states: []state.State{
|
||||||
|
state.AcceptingInit,
|
||||||
|
state.ActivatingNodes,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
res := containsState(tc.s, tc.states...)
|
||||||
|
assert.Equal(tc.success, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,21 +6,13 @@ trap 'terminate $?' ERR
|
||||||
terminate() {
|
terminate() {
|
||||||
echo "error: $1"
|
echo "error: $1"
|
||||||
constellation terminate
|
constellation terminate
|
||||||
popd || exit 1
|
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
main() {
|
main() {
|
||||||
if ! command -v constellation &> /dev/null
|
command -v constellation > /dev/null
|
||||||
then
|
command -v go > /dev/null
|
||||||
echo "constellation is not in path"
|
command -v jq > /dev/null
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
if ! command -v go &> /dev/null
|
|
||||||
then
|
|
||||||
echo "go is not in path"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
mkdir -p ./pcrs
|
mkdir -p ./pcrs
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||||
"github.com/edgelesssys/constellation/coordinator/oid"
|
"github.com/edgelesssys/constellation/coordinator/oid"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
@ -42,21 +42,21 @@ func main() {
|
||||||
|
|
||||||
// wait for coordinator to come online
|
// wait for coordinator to come online
|
||||||
waiter := status.NewWaiter(map[uint32][]byte{})
|
waiter := status.NewWaiter(map[uint32][]byte{})
|
||||||
if err := waiter.WaitFor(ctx, coordinatorstate.AcceptingInit, addr); err != nil {
|
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
attDocRaw := &[]byte{}
|
attDocRaw := []byte{}
|
||||||
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(attDocRaw)
|
tlsConfig.VerifyPeerCertificate = getVerifyPeerCertificateFunc(&attDocRaw)
|
||||||
if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil {
|
if err := connectToCoordinator(ctx, addr, tlsConfig); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pcrs, err := validatePCRAttDoc(*attDocRaw)
|
pcrs, err := validatePCRAttDoc(attDocRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue