2022-09-05 03:06:08 -04:00
|
|
|
/*
|
|
|
|
Copyright (c) Edgeless Systems GmbH
|
|
|
|
|
|
|
|
SPDX-License-Identifier: AGPL-3.0-only
|
|
|
|
*/
|
|
|
|
|
2022-07-18 04:18:29 -04:00
|
|
|
package client
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
|
|
|
)
|
|
|
|
|
|
|
|
// GetNodeImage returns the image name of the node.
|
|
|
|
func (c *Client) GetNodeImage(ctx context.Context, providerID string) (string, error) {
|
|
|
|
_, resourceGroup, scaleSet, instanceID, err := scaleSetInformationFromProviderID(providerID)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
resp, err := c.virtualMachineScaleSetVMsAPI.Get(ctx, resourceGroup, scaleSet, instanceID, nil)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
if resp.Properties == nil ||
|
|
|
|
resp.Properties.StorageProfile == nil ||
|
|
|
|
resp.Properties.StorageProfile.ImageReference == nil ||
|
2022-08-30 09:15:51 -04:00
|
|
|
resp.Properties.StorageProfile.ImageReference.ID == nil && resp.Properties.StorageProfile.ImageReference.CommunityGalleryImageID == nil {
|
2022-07-18 04:18:29 -04:00
|
|
|
return "", fmt.Errorf("node %q does not have valid image reference", providerID)
|
|
|
|
}
|
2022-08-30 09:15:51 -04:00
|
|
|
if resp.Properties.StorageProfile.ImageReference.ID != nil {
|
|
|
|
return *resp.Properties.StorageProfile.ImageReference.ID, nil
|
|
|
|
}
|
2022-10-05 09:02:46 -04:00
|
|
|
return *resp.Properties.StorageProfile.ImageReference.CommunityGalleryImageID, nil
|
2022-07-18 04:18:29 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
// GetScalingGroupID returns the scaling group ID of the node.
|
|
|
|
func (c *Client) GetScalingGroupID(ctx context.Context, providerID string) (string, error) {
|
|
|
|
subscriptionID, resourceGroup, scaleSet, _, err := scaleSetInformationFromProviderID(providerID)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
return joinVMSSID(subscriptionID, resourceGroup, scaleSet), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// CreateNode creates a node in the specified scaling group.
|
|
|
|
func (c *Client) CreateNode(ctx context.Context, scalingGroupID string) (nodeName, providerID string, err error) {
|
|
|
|
_, resourceGroup, scaleSet, err := splitVMSSID(scalingGroupID)
|
|
|
|
if err != nil {
|
|
|
|
return "", "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// get list of instance IDs before scaling,
|
|
|
|
var oldVMIDs []string
|
|
|
|
pager := c.virtualMachineScaleSetVMsAPI.NewListPager(resourceGroup, scaleSet, nil)
|
|
|
|
for pager.More() {
|
|
|
|
page, err := pager.NextPage(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return "", "", err
|
|
|
|
}
|
|
|
|
for _, vm := range page.Value {
|
|
|
|
if vm == nil || vm.ID == nil {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
oldVMIDs = append(oldVMIDs, *vm.ID)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// increase the number of instances by one
|
|
|
|
resp, err := c.scaleSetsAPI.Get(ctx, resourceGroup, scaleSet, nil)
|
|
|
|
if err != nil {
|
|
|
|
return "", "", err
|
|
|
|
}
|
|
|
|
if resp.SKU == nil || resp.SKU.Capacity == nil {
|
|
|
|
return "", "", fmt.Errorf("scale set %q does not have valid capacity", scaleSet)
|
|
|
|
}
|
|
|
|
wantedCapacity := *resp.SKU.Capacity + 1
|
|
|
|
_, err = c.scaleSetsAPI.BeginUpdate(ctx, resourceGroup, scaleSet, armcompute.VirtualMachineScaleSetUpdate{
|
|
|
|
SKU: &armcompute.SKU{
|
|
|
|
Capacity: &wantedCapacity,
|
|
|
|
},
|
|
|
|
}, nil)
|
|
|
|
if err != nil {
|
|
|
|
return "", "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
poller := c.capacityPollerGenerator(resourceGroup, scaleSet, wantedCapacity)
|
|
|
|
if _, err := poller.PollUntilDone(ctx, c.pollerOptions); err != nil {
|
|
|
|
return "", "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the list of instances again
|
|
|
|
// and find the new instance id by comparing the old and new lists
|
|
|
|
pager = c.virtualMachineScaleSetVMsAPI.NewListPager(resourceGroup, scaleSet, nil)
|
|
|
|
for pager.More() {
|
|
|
|
page, err := pager.NextPage(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return "", "", err
|
|
|
|
}
|
|
|
|
for _, vm := range page.Value {
|
|
|
|
// check if the instance already existed in the old list
|
|
|
|
if !hasKnownVMID(vm, oldVMIDs) {
|
|
|
|
return strings.ToLower(*vm.Properties.OSProfile.ComputerName), "azure://" + *vm.ID, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return "", "", fmt.Errorf("failed to find new node after scaling up")
|
|
|
|
}
|
|
|
|
|
|
|
|
// DeleteNode deletes a node specified by its provider ID.
|
|
|
|
func (c *Client) DeleteNode(ctx context.Context, providerID string) error {
|
|
|
|
_, resourceGroup, scaleSet, instanceID, err := scaleSetInformationFromProviderID(providerID)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
ids := armcompute.VirtualMachineScaleSetVMInstanceRequiredIDs{InstanceIDs: []*string{&instanceID}}
|
|
|
|
_, err = c.scaleSetsAPI.BeginDeleteInstances(ctx, resourceGroup, scaleSet, ids, nil)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// capacityPollingHandler polls a scale set
// until its capacity reaches the desired value.
type capacityPollingHandler struct {
	// done is updated by every successful Poll: true when the reported
	// capacity equals wantedCapacity.
	done bool
	// wantedCapacity is the value Poll compares the scale set's SKU capacity against.
	wantedCapacity int64
	// resourceGroup and scaleSet identify the scale set that Poll queries.
	resourceGroup string
	scaleSet      string
	// scaleSetsAPI is embedded; Poll calls its Get method.
	scaleSetsAPI
}
|
|
|
|
|
|
|
|
// Done reports whether the most recent successful Poll observed the wanted
// capacity. It is false until Poll has set it.
func (h *capacityPollingHandler) Done() bool {
	return h.done
}
|
|
|
|
|
|
|
|
func (h *capacityPollingHandler) Poll(ctx context.Context) error {
|
|
|
|
resp, err := h.scaleSetsAPI.Get(ctx, h.resourceGroup, h.scaleSet, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if resp.SKU == nil || resp.SKU.Capacity == nil {
|
|
|
|
return fmt.Errorf("scale set %q does not have valid capacity", h.scaleSet)
|
|
|
|
}
|
|
|
|
h.done = *resp.SKU.Capacity == h.wantedCapacity
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (h *capacityPollingHandler) Result(ctx context.Context, out *int64) error {
|
|
|
|
if !h.done {
|
|
|
|
return fmt.Errorf("failed to scale up")
|
|
|
|
}
|
|
|
|
*out = h.wantedCapacity
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// hasKnownVMID returns true if the vmID is found in the vm ID list.
|
|
|
|
func hasKnownVMID(vm *armcompute.VirtualMachineScaleSetVM, vmIDs []string) bool {
|
|
|
|
for _, id := range vmIDs {
|
|
|
|
if vm != nil && vm.ID != nil && *vm.ID == id {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false
|
|
|
|
}
|