mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-25 23:06:08 -05:00
AB#1903 Push keys to restarting nodes on trigger RPC
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
152e3985f7
commit
37aff14cab
@ -54,7 +54,7 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
|
|||||||
|
|
||||||
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
||||||
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
||||||
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr)
|
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
|
||||||
|
|
||||||
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
|
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
|
||||||
|
|
||||||
|
@ -211,7 +211,7 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
|
|||||||
}
|
}
|
||||||
|
|
||||||
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
|
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
|
||||||
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr)
|
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
|
@ -145,7 +145,7 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
|
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
|
||||||
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr)
|
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil)
|
||||||
|
|
||||||
return core, papi, nil
|
return core, papi, nil
|
||||||
}
|
}
|
||||||
|
@ -7,10 +7,12 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/config"
|
||||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/role"
|
"github.com/edgelesssys/constellation/coordinator/role"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
|
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
@ -185,17 +187,31 @@ func (a *API) ActivateAdditionalNodes(in *pubproto.ActivateAdditionalNodesReques
|
|||||||
|
|
||||||
// RequestStateDiskKey triggers the Coordinator to return a key derived from the Constellation's master secret to the caller.
|
// RequestStateDiskKey triggers the Coordinator to return a key derived from the Constellation's master secret to the caller.
|
||||||
func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestStateDiskKeyRequest) (*pubproto.RequestStateDiskKeyResponse, error) {
|
func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestStateDiskKeyRequest) (*pubproto.RequestStateDiskKeyResponse, error) {
|
||||||
// TODO: Add Coordinator call to restarting node and deliver the key
|
if err := a.core.RequireState(state.ActivatingNodes); err != nil {
|
||||||
/*
|
return nil, status.Errorf(codes.FailedPrecondition, "%v", err)
|
||||||
if err := a.core.RequireState(state.IsNode, state.ActivatingNodes); err != nil {
|
}
|
||||||
return nil, err
|
key, err := a.core.GetDataKey(ctx, in.DiskUuid, config.RNGLengthDefault)
|
||||||
}
|
if err != nil {
|
||||||
_, err := a.core.GetDataKey(ctx, in.DiskUuid, 32)
|
return nil, status.Errorf(codes.Internal, "unable to load key: %v", err)
|
||||||
if err != nil {
|
}
|
||||||
return nil, status.Errorf(codes.Internal, "")
|
|
||||||
}
|
peer, err := a.peerFromContext(ctx)
|
||||||
*/
|
if err != nil {
|
||||||
return &pubproto.RequestStateDiskKeyResponse{}, errors.New("unimplemented")
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := a.dial(ctx, peer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := keyproto.NewAPIClient(conn)
|
||||||
|
if _, err := client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{StateDiskKey: key}); err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "pushing key to peer %q: %v", peer, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pubproto.RequestStateDiskKeyResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) activateNodes(logToCLI logFunc, nodePublicIPs []string) error {
|
func (a *API) activateNodes(logToCLI logFunc, nodePublicIPs []string) error {
|
||||||
|
@ -2,6 +2,7 @@ package pubapi
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -9,6 +10,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/core"
|
||||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||||
"github.com/edgelesssys/constellation/coordinator/oid"
|
"github.com/edgelesssys/constellation/coordinator/oid"
|
||||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
@ -16,11 +19,13 @@ import (
|
|||||||
"github.com/edgelesssys/constellation/coordinator/role"
|
"github.com/edgelesssys/constellation/coordinator/role"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
|
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/zap/zaptest"
|
"go.uber.org/zap/zaptest"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
grpcpeer "google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestActivateAsCoordinator(t *testing.T) {
|
func TestActivateAsCoordinator(t *testing.T) {
|
||||||
@ -126,7 +131,7 @@ func TestActivateAsCoordinator(t *testing.T) {
|
|||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
// spawn nodes
|
// spawn nodes
|
||||||
@ -257,7 +262,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
|
|||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr)
|
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
// spawn nodes
|
// spawn nodes
|
||||||
var nodePublicIPs []string
|
var nodePublicIPs []string
|
||||||
@ -306,7 +311,7 @@ func TestAssemblePeerStruct(t *testing.T) {
|
|||||||
|
|
||||||
vpnPubKey := []byte{2, 3, 4}
|
vpnPubKey := []byte{2, 3, 4}
|
||||||
core := &fakeCore{vpnPubKey: vpnPubKey}
|
core := &fakeCore{vpnPubKey: vpnPubKey}
|
||||||
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr)
|
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
vpnIP, err := core.GetVPNIP()
|
vpnIP, err := core.GetVPNIP()
|
||||||
@ -433,3 +438,107 @@ func (s *stubActivateAdditionalNodesServer) Send(req *pubproto.ActivateAdditiona
|
|||||||
s.sent = append(s.sent, req)
|
s.sent = append(s.sent, req)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequestStateDiskKey(t *testing.T) {
|
||||||
|
defaultKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||||
|
someErr := errors.New("error")
|
||||||
|
testCases := map[string]struct {
|
||||||
|
state state.State
|
||||||
|
dataKey []byte
|
||||||
|
getDataKeyErr error
|
||||||
|
pushKeyErr error
|
||||||
|
errExpected bool
|
||||||
|
}{
|
||||||
|
"success": {
|
||||||
|
state: state.ActivatingNodes,
|
||||||
|
dataKey: defaultKey,
|
||||||
|
},
|
||||||
|
"Coordinator in wrong state": {
|
||||||
|
state: state.IsNode,
|
||||||
|
dataKey: defaultKey,
|
||||||
|
errExpected: true,
|
||||||
|
},
|
||||||
|
"GetDataKey fails": {
|
||||||
|
state: state.ActivatingNodes,
|
||||||
|
dataKey: defaultKey,
|
||||||
|
getDataKeyErr: someErr,
|
||||||
|
errExpected: true,
|
||||||
|
},
|
||||||
|
"key pushing fails": {
|
||||||
|
state: state.ActivatingNodes,
|
||||||
|
dataKey: defaultKey,
|
||||||
|
pushKeyErr: someErr,
|
||||||
|
errExpected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
issuer := core.NewMockIssuer()
|
||||||
|
|
||||||
|
stateDiskServer := &stubStateDiskServer{pushKeyErr: tc.pushKeyErr}
|
||||||
|
|
||||||
|
// we can not use a bufconn here, since we rely on grpcpeer.FromContext() to connect to the caller
|
||||||
|
listener, err := net.Listen("tcp", ":")
|
||||||
|
require.NoError(err)
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer)
|
||||||
|
require.NoError(err)
|
||||||
|
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||||
|
keyproto.RegisterAPIServer(s, stateDiskServer)
|
||||||
|
defer s.GracefulStop()
|
||||||
|
go s.Serve(listener)
|
||||||
|
|
||||||
|
ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: listener.Addr()})
|
||||||
|
getPeerFromContext := func(ctx context.Context) (string, error) {
|
||||||
|
peer, ok := grpcpeer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("unable to get peer from context")
|
||||||
|
}
|
||||||
|
return peer.Addr.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
core := &fakeCore{
|
||||||
|
state: tc.state,
|
||||||
|
dataKey: tc.dataKey,
|
||||||
|
getDataKeyErr: tc.getDataKeyErr,
|
||||||
|
}
|
||||||
|
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
|
||||||
|
|
||||||
|
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
|
||||||
|
if tc.errExpected {
|
||||||
|
assert.Error(err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal(tc.dataKey, stateDiskServer.receivedRequest.StateDiskKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dummyValidator struct {
|
||||||
|
oid.Dummy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyValidator) Validate(attdoc []byte, nonce []byte) ([]byte, error) {
|
||||||
|
var attestation vtpm.AttestationDocument
|
||||||
|
if err := json.Unmarshal(attdoc, &attestation); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return attestation.UserData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubStateDiskServer struct {
|
||||||
|
receivedRequest *keyproto.PushStateDiskKeyRequest
|
||||||
|
pushKeyErr error
|
||||||
|
keyproto.UnimplementedAPIServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubStateDiskServer) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
|
||||||
|
s.receivedRequest = in
|
||||||
|
return &keyproto.PushStateDiskKeyResponse{}, s.pushKeyErr
|
||||||
|
}
|
||||||
|
@ -66,7 +66,7 @@ func TestActivateAsCoordinators(t *testing.T) {
|
|||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
// spawn coordinator
|
// spawn coordinator
|
||||||
@ -133,7 +133,7 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
dialer := testdialer.NewBufconnDialer()
|
||||||
|
|
||||||
api := New(logger, core, dialer, nil, nil, nil)
|
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||||
|
|
||||||
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
|
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
|
||||||
if tc.expectErr {
|
if tc.expectErr {
|
||||||
@ -202,7 +202,7 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
|
|||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
// spawn coordinator
|
// spawn coordinator
|
||||||
|
@ -129,7 +129,7 @@ func TestActivateAsNode(t *testing.T) {
|
|||||||
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
|
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
dialer := testdialer.NewBufconnDialer()
|
||||||
|
|
||||||
api := New(logger, cor, dialer, nil, nil, nil)
|
api := New(logger, cor, dialer, nil, nil, nil, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
vserver := grpc.NewServer()
|
vserver := grpc.NewServer()
|
||||||
@ -217,7 +217,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
|
|||||||
core := &fakeCore{state: tc.state}
|
core := &fakeCore{state: tc.state}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
dialer := testdialer.NewBufconnDialer()
|
||||||
|
|
||||||
api := New(logger, core, dialer, nil, nil, nil)
|
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||||
|
|
||||||
vserver := grpc.NewServer()
|
vserver := grpc.NewServer()
|
||||||
vapi := &stubVPNAPI{
|
vapi := &stubVPNAPI{
|
||||||
@ -292,7 +292,7 @@ func TestJoinCluster(t *testing.T) {
|
|||||||
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
|
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
dialer := testdialer.NewBufconnDialer()
|
||||||
|
|
||||||
api := New(logger, core, dialer, nil, nil, nil)
|
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||||
|
|
||||||
vserver := grpc.NewServer()
|
vserver := grpc.NewServer()
|
||||||
vapi := &stubVPNAPI{
|
vapi := &stubVPNAPI{
|
||||||
|
@ -3,16 +3,19 @@ package pubapi
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
|
"github.com/edgelesssys/constellation/state/setup"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -34,11 +37,12 @@ type API struct {
|
|||||||
stopUpdate chan struct{}
|
stopUpdate chan struct{}
|
||||||
wgClose sync.WaitGroup
|
wgClose sync.WaitGroup
|
||||||
resourceVersion int
|
resourceVersion int
|
||||||
|
peerFromContext PeerFromContextFunc
|
||||||
pubproto.UnimplementedAPIServer
|
pubproto.UnimplementedAPIServer
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new API.
|
// New creates a new API.
|
||||||
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc) *API {
|
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
|
||||||
return &API{
|
return &API{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
core: core,
|
core: core,
|
||||||
@ -47,6 +51,7 @@ func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer
|
|||||||
validator: validator,
|
validator: validator,
|
||||||
getPublicIPAddr: getPublicIPAddr,
|
getPublicIPAddr: getPublicIPAddr,
|
||||||
stopUpdate: make(chan struct{}, 1),
|
stopUpdate: make(chan struct{}, 1),
|
||||||
|
peerFromContext: peerFromContext,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,3 +117,21 @@ type VPNAPIServer interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetIPAddrFunc func() (string, error)
|
type GetIPAddrFunc func() (string, error)
|
||||||
|
|
||||||
|
// PeerFromContextFunc returns a peer endpoint (IP:port) from a given context.
|
||||||
|
type PeerFromContextFunc func(context.Context) (string, error)
|
||||||
|
|
||||||
|
// GetRecoveryPeerFromContext returns the context's IP joined with the Coordinator's default port.
|
||||||
|
func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
|
||||||
|
peer, ok := peer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("unable to get peer from context")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIP, _, err := net.SplitHostPort(peer.Addr.String())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
|
||||||
|
}
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
package pubapi
|
package pubapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
|
grpcpeer "google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@ -14,3 +18,20 @@ func TestMain(m *testing.M) {
|
|||||||
goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"),
|
goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetRecoveryPeerFromContext(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
testIP := "192.0.2.1"
|
||||||
|
testPort := 1234
|
||||||
|
expectedPeer := net.JoinHostPort(testIP, "9000")
|
||||||
|
|
||||||
|
addr := &net.TCPAddr{IP: net.ParseIP(testIP), Port: testPort}
|
||||||
|
ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: addr})
|
||||||
|
|
||||||
|
peer, err := GetRecoveryPeerFromContext(ctx)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal(expectedPeer, peer)
|
||||||
|
|
||||||
|
_, err = GetRecoveryPeerFromContext(context.Background())
|
||||||
|
assert.Error(err)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user