AB#2305 Fix missing atls verifier in init call (#352)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-08-09 14:04:40 +02:00 committed by GitHub
parent aee3f2afa2
commit 8f5f84deb5
9 changed files with 184 additions and 70 deletions

View File

@ -18,14 +18,14 @@ import (
const warningStr = "Warning: not verifying the Constellation cluster's %s measurements\n" const warningStr = "Warning: not verifying the Constellation cluster's %s measurements\n"
type Validators struct { type Validator struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs map[uint32][]byte pcrs map[uint32][]byte
validators []atls.Validator validator atls.Validator
} }
func NewValidators(provider cloudprovider.Provider, config *config.Config) (*Validators, error) { func NewValidator(provider cloudprovider.Provider, config *config.Config) (*Validator, error) {
v := Validators{} v := Validator{}
if provider == cloudprovider.Unknown { if provider == cloudprovider.Unknown {
return nil, errors.New("unknown cloud provider") return nil, errors.New("unknown cloud provider")
} }
@ -36,7 +36,7 @@ func NewValidators(provider cloudprovider.Provider, config *config.Config) (*Val
return &v, nil return &v, nil
} }
func (v *Validators) UpdateInitPCRs(ownerID, clusterID string) error { func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error {
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil { if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
return err return err
} }
@ -48,7 +48,7 @@ func (v *Validators) UpdateInitPCRs(ownerID, clusterID string) error {
// When adding, the input is first decoded from base64. // When adding, the input is first decoded from base64.
// We then calculate the expected PCR by hashing the input using SHA256, // We then calculate the expected PCR by hashing the input using SHA256,
// appending expected PCR for initialization, and then hashing once more. // appending expected PCR for initialization, and then hashing once more.
func (v *Validators) updatePCR(pcrIndex uint32, encoded string) error { func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error {
if encoded == "" { if encoded == "" {
delete(v.pcrs, pcrIndex) delete(v.pcrs, pcrIndex)
return nil return nil
@ -65,7 +65,7 @@ func (v *Validators) updatePCR(pcrIndex uint32, encoded string) error {
return nil return nil
} }
func (v *Validators) setPCRs(config *config.Config) error { func (v *Validator) setPCRs(config *config.Config) error {
switch v.provider { switch v.provider {
case cloudprovider.GCP: case cloudprovider.GCP:
gcpPCRs := config.Provider.GCP.Measurements gcpPCRs := config.Provider.GCP.Measurements
@ -89,33 +89,32 @@ func (v *Validators) setPCRs(config *config.Config) error {
return nil return nil
} }
// V returns validators as list of atls.Validator. // V returns the validator as atls.Validator.
func (v *Validators) V() []atls.Validator { func (v *Validator) V() atls.Validator {
v.updateValidators() v.updateValidator()
return v.validators return v.validator
} }
func (v *Validators) updateValidators() { // PCRS returns the validator's PCR map.
func (v *Validator) PCRS() map[uint32][]byte {
return v.pcrs
}
func (v *Validator) updateValidator() {
switch v.provider { switch v.provider {
case cloudprovider.GCP: case cloudprovider.GCP:
v.validators = []atls.Validator{ v.validator = gcp.NewValidator(v.pcrs)
gcp.NewValidator(v.pcrs),
}
case cloudprovider.Azure: case cloudprovider.Azure:
v.validators = []atls.Validator{ v.validator = azure.NewValidator(v.pcrs)
azure.NewValidator(v.pcrs),
}
case cloudprovider.QEMU: case cloudprovider.QEMU:
v.validators = []atls.Validator{ v.validator = qemu.NewValidator(v.pcrs)
qemu.NewValidator(v.pcrs),
}
} }
} }
// Warnings returns warnings for the specifc PCR values that are not verified. // Warnings returns warnings for the specifc PCR values that are not verified.
// //
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1 // PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func (v *Validators) Warnings() string { func (v *Validator) Warnings() string {
sb := &strings.Builder{} sb := &strings.Builder{}
if v.pcrs[0] == nil || v.pcrs[1] == nil { if v.pcrs[0] == nil || v.pcrs[1] == nil {
@ -141,11 +140,11 @@ func (v *Validators) Warnings() string {
return sb.String() return sb.String()
} }
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified. // WarningsIncludeInit returns warnings for the specific PCR values that are not verified.
// Warnings regarding the initialization are included. // Warnings regarding the initialization are included.
// //
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1 // PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func (v *Validators) WarningsIncludeInit() string { func (v *Validator) WarningsIncludeInit() string {
warnings := v.Warnings() warnings := v.Warnings()
if v.pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || v.pcrs[uint32(vtpm.PCRIndexClusterID)] == nil { if v.pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || v.pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
warnings = warnings + fmt.Sprintf(warningStr, "initialization status") warnings = warnings + fmt.Sprintf(warningStr, "initialization status")
@ -154,7 +153,7 @@ func (v *Validators) WarningsIncludeInit() string {
return warnings return warnings
} }
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error { func (v *Validator) checkPCRs(pcrs map[uint32][]byte) error {
if len(pcrs) == 0 { if len(pcrs) == 0 {
return errors.New("no PCR values provided") return errors.New("no PCR values provided")
} }

View File

@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewValidators(t *testing.T) { func TestNewValidator(t *testing.T) {
zero := []byte("00000000000000000000000000000000") zero := []byte("00000000000000000000000000000000")
one := []byte("11111111111111111111111111111111") one := []byte("11111111111111111111111111111111")
testPCRs := map[uint32][]byte{ testPCRs := map[uint32][]byte{
@ -80,7 +80,7 @@ func TestNewValidators(t *testing.T) {
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements} conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
} }
validators, err := NewValidators(tc.provider, conf) validators, err := NewValidator(tc.provider, conf)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -93,7 +93,7 @@ func TestNewValidators(t *testing.T) {
} }
} }
func TestValidatorsWarnings(t *testing.T) { func TestValidatorWarnings(t *testing.T) {
zero := []byte("00000000000000000000000000000000") zero := []byte("00000000000000000000000000000000")
testCases := map[string]struct { testCases := map[string]struct {
@ -233,7 +233,7 @@ func TestValidatorsWarnings(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
validators := Validators{pcrs: tc.pcrs} validators := Validator{pcrs: tc.pcrs}
warnings := validators.Warnings() warnings := validators.Warnings()
warningsInclueInit := validators.WarningsIncludeInit() warningsInclueInit := validators.WarningsIncludeInit()
@ -259,7 +259,7 @@ func TestValidatorsWarnings(t *testing.T) {
} }
} }
func TestValidatorsV(t *testing.T) { func TestValidatorV(t *testing.T) {
zero := []byte("00000000000000000000000000000000") zero := []byte("00000000000000000000000000000000")
newTestPCRs := func() map[uint32][]byte { newTestPCRs := func() map[uint32][]byte {
return map[uint32][]byte{ return map[uint32][]byte{
@ -282,28 +282,22 @@ func TestValidatorsV(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs map[uint32][]byte pcrs map[uint32][]byte
wantVs []atls.Validator wantVs atls.Validator
}{ }{
"gcp": { "gcp": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: []atls.Validator{ wantVs: gcp.NewValidator(newTestPCRs()),
gcp.NewValidator(newTestPCRs()),
},
}, },
"azure": { "azure": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: []atls.Validator{ wantVs: azure.NewValidator(newTestPCRs()),
azure.NewValidator(newTestPCRs()),
},
}, },
"qemu": { "qemu": {
provider: cloudprovider.QEMU, provider: cloudprovider.QEMU,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: []atls.Validator{ wantVs: qemu.NewValidator(newTestPCRs()),
qemu.NewValidator(newTestPCRs()),
},
}, },
} }
@ -311,19 +305,16 @@ func TestValidatorsV(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs} validators := &Validator{provider: tc.provider, pcrs: tc.pcrs}
resultValidators := validators.V() resultValidator := validators.V()
assert.Equal(len(tc.wantVs), len(resultValidators)) assert.Equal(tc.wantVs.OID(), resultValidator.OID())
for i, resValidator := range resultValidators {
assert.Equal(tc.wantVs[i].OID(), resValidator.OID())
}
}) })
} }
} }
func TestValidatorsUpdateInitPCRs(t *testing.T) { func TestValidatorUpdateInitPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000") zero := []byte("00000000000000000000000000000000")
one := []byte("11111111111111111111111111111111") one := []byte("11111111111111111111111111111111")
one64 := base64.StdEncoding.EncodeToString(one) one64 := base64.StdEncoding.EncodeToString(one)
@ -402,7 +393,7 @@ func TestValidatorsUpdateInitPCRs(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
validators := &Validators{provider: tc.provider, pcrs: tc.pcrs} validators := &Validator{provider: tc.provider, pcrs: tc.pcrs}
err := validators.UpdateInitPCRs(tc.ownerID, tc.clusterID) err := validators.UpdateInitPCRs(tc.ownerID, tc.clusterID)
@ -515,7 +506,7 @@ func TestUpdatePCR(t *testing.T) {
pcrs[k] = v pcrs[k] = v
} }
validators := &Validators{ validators := &Validator{
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
pcrs: pcrs, pcrs: pcrs,
} }

View File

@ -52,13 +52,15 @@ func NewInitCmd() *cobra.Command {
func runInitialize(cmd *cobra.Command, args []string) error { func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
serviceAccountCreator := cloudcmd.NewServiceAccountCreator() serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
dialer := dialer.New(nil, nil, &net.Dialer{}) newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
return initialize(cmd, dialer, serviceAccountCreator, fileHandler) return dialer.New(nil, validator.V(), &net.Dialer{})
}
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler)
} }
// initialize initializes a Constellation. // initialize initializes a Constellation.
func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator serviceAccountCreator, func initialize(cmd *cobra.Command, newDialer func(*cloudcmd.Validator) *dialer.Dialer,
fileHandler file.Handler, serviceAccCreator serviceAccountCreator, fileHandler file.Handler,
) error { ) error {
flags, err := evalFlagArgs(cmd, fileHandler) flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil { if err != nil {
@ -88,11 +90,11 @@ func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator service
}) })
} }
validators, err := cloudcmd.NewValidators(provider, config) validator, err := cloudcmd.NewValidator(provider, config)
if err != nil { if err != nil {
return err return err
} }
cmd.Print(validators.WarningsIncludeInit()) cmd.Print(validator.WarningsIncludeInit())
cmd.Println("Creating service account ...") cmd.Println("Creating service account ...")
serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config) serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config)
@ -103,7 +105,7 @@ func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator service
return err return err
} }
controlPlanes, workers, err := getScalingGroupsFromState(stat, config) _, workers, err := getScalingGroupsFromState(stat, config)
if err != nil { if err != nil {
return err return err
} }
@ -125,12 +127,12 @@ func initialize(cmd *cobra.Command, dialer grpcDialer, serviceAccCreator service
KubernetesVersion: config.KubernetesVersion, KubernetesVersion: config.KubernetesVersion,
SshUserKeys: ssh.ToProtoSlice(sshUsers), SshUserKeys: ssh.ToProtoSlice(sshUsers),
} }
resp, err := initCall(cmd.Context(), dialer, stat.BootstrapperHost, req) resp, err := initCall(cmd.Context(), newDialer(validator), stat.BootstrapperHost, req)
if err != nil { if err != nil {
return err return err
} }
return writeOutput(resp, controlPlanes.PublicIPs()[0], cmd.OutOrStdout(), fileHandler) return writeOutput(resp, stat.BootstrapperHost, cmd.OutOrStdout(), fileHandler)
} }
func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitResponse, error) { func initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitResponse, error) {

View File

@ -13,12 +13,16 @@ import (
"time" "time"
"github.com/edgelesssys/constellation/bootstrapper/initproto" "github.com/edgelesssys/constellation/bootstrapper/initproto"
"github.com/edgelesssys/constellation/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes" "github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer" "github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer" "github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/internal/state" "github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -131,7 +135,9 @@ func TestInitialize(t *testing.T) {
require := require.New(t) require := require.New(t)
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer) newDialer := func(*cloudcmd.Validator) *dialer.Dialer {
return dialer.New(nil, nil, netDialer)
}
serverCreds := atlscredentials.New(nil, nil) serverCreds := atlscredentials.New(nil, nil)
initServer := grpc.NewServer(grpc.Creds(serverCreds)) initServer := grpc.NewServer(grpc.Creds(serverCreds))
initproto.RegisterAPIServer(initServer, tc.initServerAPI) initproto.RegisterAPIServer(initServer, tc.initServerAPI)
@ -145,7 +151,7 @@ func TestInitialize(t *testing.T) {
cmd.SetOut(&out) cmd.SetOut(&out)
var errOut bytes.Buffer var errOut bytes.Buffer
cmd.SetErr(&errOut) cmd.SetErr(&errOut)
cmd.Flags().String("config", "", "") // register persisten flag manually cmd.Flags().String("config", "", "") // register persistent flag manually
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs) fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone)) require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
@ -156,7 +162,7 @@ func TestInitialize(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
err := initialize(cmd, dialer, &tc.serviceAccountCreator, fileHandler) err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -364,6 +370,122 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
} }
} }
func TestAttestation(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
initServerAPI := &stubInitServer{initResp: &initproto.InitResponse{
Kubeconfig: []byte("kubeconfig"),
OwnerId: []byte("ownerID"),
ClusterId: []byte("clusterID"),
}}
existingState := state.ConstellationState{
CloudProvider: "QEMU",
BootstrapperHost: "192.0.2.1",
QEMUWorkers: cloudtypes.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
QEMUControlPlane: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
netDialer := testdialer.NewBufconnDialer()
newDialer := func(v *cloudcmd.Validator) *dialer.Dialer {
validator := &testValidator{
Getter: oid.QEMU{},
pcrs: v.PCRS(),
}
return dialer.New(nil, validator, netDialer)
}
issuer := &testIssuer{
Getter: oid.QEMU{},
pcrs: map[uint32][]byte{
0: []byte("ffffffffffffffffffffffffffffffff"),
1: []byte("ffffffffffffffffffffffffffffffff"),
2: []byte("ffffffffffffffffffffffffffffffff"),
3: []byte("ffffffffffffffffffffffffffffffff"),
},
}
serverCreds := atlscredentials.New(issuer, nil)
initServer := grpc.NewServer(grpc.Creds(serverCreds))
initproto.RegisterAPIServer(initServer, initServerAPI)
port := strconv.Itoa(constants.BootstrapperPort)
listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", port))
go initServer.Serve(listener)
defer initServer.GracefulStop()
cmd := NewInitCmd()
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(constants.StateFilename, existingState, file.OptNone))
cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000")
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111")
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222")
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333")
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
cmd.SetContext(ctx)
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler)
assert.Error(err)
// make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed")
}
type testValidator struct {
oid.Getter
pcrs map[uint32][]byte
}
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var attestation struct {
UserData []byte
PCRs map[uint32][]byte
}
if err := json.Unmarshal(attDoc, &attestation); err != nil {
return nil, err
}
for k, pcr := range v.pcrs {
if !bytes.Equal(attestation.PCRs[k], pcr) {
return nil, errors.New("invalid PCR value")
}
}
return attestation.UserData, nil
}
type testIssuer struct {
oid.Getter
pcrs map[uint32][]byte
}
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(
struct {
UserData []byte
PCRs map[uint32][]byte
}{
UserData: userData,
PCRs: i.pcrs,
},
)
}
type stubInitServer struct { type stubInitServer struct {
initResp *initproto.InitResponse initResp *initproto.InitResponse
initErr error initErr error

View File

@ -68,7 +68,7 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recove
return fmt.Errorf("reading and validating config: %w", err) return fmt.Errorf("reading and validating config: %w", err)
} }
validators, err := cloudcmd.NewValidators(provider, config) validators, err := cloudcmd.NewValidator(provider, config)
if err != nil { if err != nil {
return err return err
} }

View File

@ -8,7 +8,7 @@ import (
) )
type recoveryClient interface { type recoveryClient interface {
Connect(endpoint string, validators []atls.Validator) error Connect(endpoint string, validators atls.Validator) error
PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error PushStateDiskKey(ctx context.Context, stateDiskKey, measurementSecret []byte) error
io.Closer io.Closer
} }

View File

@ -15,7 +15,7 @@ type stubRecoveryClient struct {
pushStateDiskKeyKey []byte pushStateDiskKeyKey []byte
} }
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error { func (c *stubRecoveryClient) Connect(_ string, _ atls.Validator) error {
c.conn = true c.conn = true
return c.connectErr return c.connectErr
} }

View File

@ -62,7 +62,7 @@ func verify(
return fmt.Errorf("reading and validating config: %w", err) return fmt.Errorf("reading and validating config: %w", err)
} }
validators, err := cloudcmd.NewValidators(provider, config) validators, err := cloudcmd.NewValidator(provider, config)
if err != nil { if err != nil {
return err return err
} }
@ -90,7 +90,7 @@ func verify(
Nonce: nonce, Nonce: nonce,
UserData: userData, UserData: userData,
}, },
validators.V()[0], validators.V(),
); err != nil { ); err != nil {
return err return err
} }

View File

@ -21,8 +21,8 @@ type KeyClient struct {
// The connection must be closed using Close(). If connect is // The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error { func (c *KeyClient) Connect(endpoint string, validators atls.Validator) error {
creds := atlscredentials.New(nil, validators) creds := atlscredentials.New(nil, []atls.Validator{validators})
conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds)) conn, err := grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
if err != nil { if err != nil {