init: overwrite kubeconfig address (#2393)

This commit is contained in:
3u13r 2023-09-29 14:01:40 +02:00 committed by GitHub
parent 85b4101dc3
commit eebaef9ddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 71 additions and 11 deletions

View File

@ -172,6 +172,8 @@ go_test(
"@io_k8s_api//core/v1:core",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/apis/meta/v1:meta",
"@io_k8s_client_go//tools/clientcmd",
"@io_k8s_client_go//tools/clientcmd/api",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//status",

View File

@ -14,6 +14,7 @@ import (
"fmt"
"io"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
@ -472,7 +473,28 @@ func (i *initCmd) writeOutput(
tw.Flush()
fmt.Fprintln(wr)
if err := i.fileHandler.Write(constants.AdminConfFilename, initResp.GetKubeconfig(), file.OptNone); err != nil {
i.log.Debugf("Rewriting cluster server address in kubeconfig to %s", idFile.IP)
kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig())
if err != nil {
return fmt.Errorf("loading kubeconfig: %w", err)
}
if len(kubeconfig.Clusters) != 1 {
return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters))
}
for _, cluster := range kubeconfig.Clusters {
kubeEndpoint, err := url.Parse(cluster.Server)
if err != nil {
return fmt.Errorf("parsing kubeconfig server URL: %w", err)
}
kubeEndpoint.Host = net.JoinHostPort(idFile.IP, kubeEndpoint.Port())
cluster.Server = kubeEndpoint.String()
}
kubeconfigBytes, err := clientcmd.Write(*kubeconfig)
if err != nil {
return fmt.Errorf("marshaling kubeconfig: %w", err)
}
if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err)
}
i.log.Debugf("Kubeconfig written to %s", i.pf.PrefixPrintablePath(constants.AdminConfFilename))

View File

@ -12,6 +12,7 @@ import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"strconv"
@ -44,6 +45,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"k8s.io/client-go/tools/clientcmd"
k8sclientapi "k8s.io/client-go/tools/clientcmd/api"
)
func TestInitArgumentValidation(t *testing.T) {
@ -56,6 +59,18 @@ func TestInitArgumentValidation(t *testing.T) {
}
func TestInitialize(t *testing.T) {
require := require.New(t)
respKubeconfig := k8sclientapi.Config{
Clusters: map[string]*k8sclientapi.Cluster{
"cluster": {
Server: "https://192.0.2.1:6443",
},
},
}
respKubeconfigBytes, err := clientcmd.Write(respKubeconfig)
require.NoError(err)
gcpServiceAccKey := &gcpshared.ServiceAccountKey{
Type: "service_account",
ProjectID: "project_id",
@ -69,7 +84,7 @@ func TestInitialize(t *testing.T) {
ClientX509CertURL: "client_cert",
}
testInitResp := &initproto.InitSuccessResponse{
Kubeconfig: []byte("kubeconfig"),
Kubeconfig: respKubeconfigBytes,
OwnerId: []byte("ownerID"),
ClusterId: []byte("clusterID"),
}
@ -160,7 +175,7 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) {
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
require.NoError(t, err)
require.NoError(err)
c.KubernetesVersion = res
},
},
@ -170,7 +185,7 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) {
v, err := semver.New(versions.SupportedK8sVersions()[0])
require.NoError(t, err)
require.NoError(err)
outdatedPatchVer := semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
c.KubernetesVersion = versions.ValidK8sVersion(outdatedPatchVer)
},
@ -182,8 +197,6 @@ func TestInitialize(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
// Networking
netDialer := testdialer.NewBufconnDialer()
newDialer := func(atls.Validator) *dialer.Dialer {
@ -339,12 +352,34 @@ func TestWriteOutput(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
clusterEndpoint := "cluster-endpoint"
expectedKubeconfig := k8sclientapi.Config{
Clusters: map[string]*k8sclientapi.Cluster{
"cluster": {
Server: fmt.Sprintf("https://%s:6443", clusterEndpoint),
},
},
}
expectedKubeconfigBytes, err := clientcmd.Write(expectedKubeconfig)
require.NoError(err)
respKubeconfig := k8sclientapi.Config{
Clusters: map[string]*k8sclientapi.Cluster{
"cluster": {
Server: "https://192.0.2.1:6443",
},
},
}
respKubeconfigBytes, err := clientcmd.Write(respKubeconfig)
require.NoError(err)
resp := &initproto.InitResponse{
Kind: &initproto.InitResponse_InitSuccess{
InitSuccess: &initproto.InitSuccessResponse{
OwnerId: []byte("ownerID"),
ClusterId: []byte("clusterID"),
Kubeconfig: []byte("kubeconfig"),
Kubeconfig: respKubeconfigBytes,
},
},
}
@ -355,7 +390,7 @@ func TestWriteOutput(t *testing.T) {
expectedIDFile := clusterid.File{
ClusterID: clusterID,
OwnerID: ownerID,
IP: "cluster-ip",
IP: clusterEndpoint,
UID: "test-uid",
}
@ -365,10 +400,10 @@ func TestWriteOutput(t *testing.T) {
idFile := clusterid.File{
UID: "test-uid",
IP: "cluster-ip",
IP: clusterEndpoint,
}
i := newInitCmd(nil, fileHandler, &nopSpinner{}, &stubMerger{}, logger.NewTest(t))
err := i.writeOutput(idFile, resp.GetInitSuccess(), false, &out)
err = i.writeOutput(idFile, resp.GetInitSuccess(), false, &out)
require.NoError(err)
// assert.Contains(out.String(), ownerID)
assert.Contains(out.String(), clusterID)
@ -377,7 +412,8 @@ func TestWriteOutput(t *testing.T) {
afs := afero.Afero{Fs: testFs}
adminConf, err := afs.ReadFile(constants.AdminConfFilename)
assert.NoError(err)
assert.Equal(string(resp.GetInitSuccess().GetKubeconfig()), string(adminConf))
assert.Contains(string(adminConf), clusterEndpoint)
assert.Equal(string(expectedKubeconfigBytes), string(adminConf))
idsFile, err := afs.ReadFile(constants.ClusterIDsFilename)
assert.NoError(err)