mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
monorepo
Co-authored-by: Malte Poll <mp@edgeless.systems> Co-authored-by: katexochen <katexochen@users.noreply.github.com> Co-authored-by: Daniel Weiße <dw@edgeless.systems> Co-authored-by: Thomas Tendyck <tt@edgeless.systems> Co-authored-by: Benedict Schlueter <bs@edgeless.systems> Co-authored-by: leongross <leon.gross@rub.de> Co-authored-by: Moritz Eckert <m1gh7ym0@gmail.com>
This commit is contained in:
commit
2d8fcd9bf4
31
.dockerignore
Normal file
31
.dockerignore
Normal file
@ -0,0 +1,31 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
#ignore build files
|
||||
/build
|
||||
admin.conf
|
||||
coordinatorConfig.json
|
||||
coordinator-*
|
||||
|
||||
/debugd
|
||||
/images
|
||||
|
||||
# Dockerfiles
|
||||
Dockerfile
|
||||
Dockerfile.*
|
||||
|
||||
# GitHub actions
|
||||
.github
|
||||
|
||||
# VS Code configuration folder
|
||||
.vscode
|
55
.github/workflows/build-ami.yml
vendored
Normal file
55
.github/workflows/build-ami.yml
vendored
Normal file
@ -0,0 +1,55 @@
|
||||
name: Build the AMI Template
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
secrets:
|
||||
AWS_ACCESS_KEY_ID:
|
||||
required: true
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
required: true
|
||||
AWS_DEFAULT_REGION:
|
||||
required: true
|
||||
BUCKET_NAME:
|
||||
required: true
|
||||
|
||||
|
||||
jobs:
|
||||
build-enclave:
|
||||
name: "Build the AMI"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
working-directory: images/aws/ec2
|
||||
steps:
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install AWS CLI
|
||||
id: prepare
|
||||
run: sudo apt-get update && sudo apt-get -y install awscli
|
||||
|
||||
- name: Download eif
|
||||
id: download_eif
|
||||
run: aws s3 cp s3://${{ secrets.BUCKET_NAME }}/eif/ ${{ github.workspace }}/${{ env.working-directory }}/ --recursive --quiet
|
||||
|
||||
- name: Download gvproxy
|
||||
id: download_gvproxy
|
||||
run: aws s3 cp s3://${{ secrets.BUCKET_NAME }}/gvproxy/gvproxy ${{ github.workspace }}/${{ env.working-directory }}/ --quiet
|
||||
|
||||
- name: Install build dependencies
|
||||
run: sudo apt-get -y install packer
|
||||
|
||||
- name: Init packer
|
||||
run: packer init .
|
||||
working-directory: ${{ env.working-directory }}
|
||||
|
||||
- name: Validate packer
|
||||
run: packer validate -syntax-only .
|
||||
working-directory: ${{ env.working-directory }}
|
||||
|
||||
- name: Build packer
|
||||
run: packer build -color=false .
|
||||
working-directory: ${{ env.working-directory }}
|
107
.github/workflows/build-coordinator.yml
vendored
Normal file
107
.github/workflows/build-coordinator.yml
vendored
Normal file
@ -0,0 +1,107 @@
|
||||
name: Build and Upload the Coordinator
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
build-coordinator:
|
||||
name: "Build the Coordinator"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
outputs:
|
||||
coordinator-name: ${{ steps.copy.outputs.coordinator-name }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-
|
||||
|
||||
- name: Install Dependencies
|
||||
id: prepare
|
||||
run: sudo apt-get update && sudo apt-get -y install awscli
|
||||
|
||||
- name: Build the Coordinator
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile.build
|
||||
outputs: .
|
||||
push: false
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
# This ugly bit is necessary if you don't want your cache to grow forever
|
||||
# till it hits GitHub's limit of 5GB.
|
||||
# Temp fix
|
||||
# https://github.com/docker/build-push-action/issues/252
|
||||
# https://github.com/moby/buildkit/issues/1896
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
|
||||
- name: Copy Coordinator to S3 if not exists
|
||||
id: copy
|
||||
run: >
|
||||
aws s3api head-object --bucket ${{ secrets.PUBLIC_BUCKET_NAME }} --key coordinator/$(ls | grep "coordinator-")
|
||||
|| (
|
||||
echo "::set-output name=coordinator-name::$(ls | grep "coordinator-")"
|
||||
&& aws s3 cp ${{ github.workspace }}/ s3://${{ secrets.PUBLIC_BUCKET_NAME }}/coordinator/ --exclude "*" --include "coordinator-*" --include "constellation" --recursive --quiet)
|
||||
shell: bash {0}
|
||||
|
||||
call-coreos:
|
||||
needs: build-coordinator
|
||||
if: startsWith(needs.build-coordinator.outputs.coordinator-name, 'coordinator-')
|
||||
uses: ./.github/workflows/build-coreos.yml
|
||||
with:
|
||||
coordinator-name: ${{ needs.build-coordinator.outputs.coordinator-name }}
|
||||
secrets:
|
||||
CI_GITHUB_REPOSITORY: ${{ secrets.CI_GITHUB_REPOSITORY }}
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
BUCKET_NAME: ${{ secrets.BUCKET_NAME }}
|
||||
PUBLIC_BUCKET_NAME: ${{ secrets.PUBLIC_BUCKET_NAME }}
|
||||
SSH_PUB_KEY: ${{ secrets.SSH_PUB_KEY }}
|
||||
SSH_PUB_KEY_PATH: ${{ secrets.SSH_PUB_KEY_PATH }}
|
||||
AZURE_CREDENTIALS: ${{ secrets.AZURE_CREDENTIALS }}
|
||||
|
||||
call-aws-enclave:
|
||||
needs: build-coordinator
|
||||
if: startsWith(needs.build-coordinator.outputs.coordinator-name, 'coordinator-')
|
||||
uses: ./.github/workflows/build-enclave.yml
|
||||
with:
|
||||
coordinator-name: ${{ needs.build-coordinator.outputs.coordinator-name }}
|
||||
secrets:
|
||||
CI_GITHUB_REPOSITORY: ${{ secrets.CI_GITHUB_REPOSITORY }}
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
BUCKET_NAME: ${{ secrets.BUCKET_NAME }}
|
||||
PUBLIC_BUCKET_NAME: ${{ secrets.PUBLIC_BUCKET_NAME }}
|
||||
SSH_PUB_KEY: ${{ secrets.SSH_PUB_KEY }}
|
||||
SSH_PUB_KEY_PATH: ${{ secrets.SSH_PUB_KEY_PATH }}
|
||||
|
||||
call-aws-ami:
|
||||
needs: call-aws-enclave
|
||||
uses: ./.github/workflows/build-ami.yml
|
||||
secrets:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
BUCKET_NAME: ${{ secrets.BUCKET_NAME }}
|
79
.github/workflows/build-coreos-debug.yml
vendored
Normal file
79
.github/workflows/build-coreos-debug.yml
vendored
Normal file
@ -0,0 +1,79 @@
|
||||
name: Build and Upload CoreOS debug image
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
on:
|
||||
workflow_dispatch:
|
||||
jobs:
|
||||
build-enclave:
|
||||
name: "Build CoreOS debug image using customized COSA"
|
||||
runs-on: [self-hosted, linux, nested-virt]
|
||||
permissions:
|
||||
contents: read
|
||||
packages: read
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
env:
|
||||
working-directory: ${{ github.workspace }}/images/fcos
|
||||
SHELL: /bin/bash
|
||||
GOPATH: /home/github-actions-runner-user/go
|
||||
GOCACHE: /home/github-actions-runner-user/.cache/go-build
|
||||
GOMODCACHE: /home/github-actions-runner-user/.cache/go-mod
|
||||
steps:
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: recursive
|
||||
token: ${{ secrets.CI_GITHUB_REPOSITORY }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
id: docker-login
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: "Install azure CLI"
|
||||
run: |
|
||||
# use pip since azure cli repository is not working as expected
|
||||
# https://github.com/Azure/azure-cli/issues/21532
|
||||
# curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y python3 python3-pip
|
||||
sudo pip install azure-cli
|
||||
wget -q https://aka.ms/downloadazcopy-v10-linux -O azcopy.tar.gz
|
||||
tar --strip-components 1 -xf azcopy.tar.gz
|
||||
rm azcopy.tar.gz
|
||||
echo "$(pwd)" >> $GITHUB_PATH
|
||||
|
||||
- uses: azure/login@v1
|
||||
with:
|
||||
creds: ${{ secrets.AZURE_CREDENTIALS }}
|
||||
|
||||
- name: Setup Go environment
|
||||
uses: actions/setup-go@v2.2.0
|
||||
with:
|
||||
go-version: "1.18"
|
||||
|
||||
- name: "Compile debugd"
|
||||
run: GOCACHE=/home/github-actions-runner-user/.cache/go-build GOPATH=/home/github-actions-runner-user/go GOPRIVATE=github.com/edgelesssys GOMODCACHE=/home/github-actions-runner-user/.cache/go-mod go build -o constellation-debugd debugd.go
|
||||
working-directory: ${{ github.workspace }}/debugd/debugd/cmd/debugd
|
||||
|
||||
- name: "Store GH token to be mounted by cosa"
|
||||
run: echo "machine github.com login api password ${{ secrets.CI_GITHUB_REPOSITORY }}" > /tmp/.netrc
|
||||
|
||||
- name: "Set image timestamp"
|
||||
run: |
|
||||
TIMESTAMP=$(date +%s)
|
||||
echo "TIMESTAMP=${TIMESTAMP}" >> $GITHUB_ENV
|
||||
echo "IMAGE_TIMESTAMP=constellation-coreos-debugd-${TIMESTAMP}" >> $GITHUB_ENV
|
||||
echo "IMAGE_VERSION=0.0.${TIMESTAMP}" >> $GITHUB_ENV
|
||||
|
||||
- name: "Build and Upload"
|
||||
run: >
|
||||
make -j$(nproc) CONTAINER_ENGINE=docker NETRC=/tmp/.netrc GCP_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}" AZURE_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}"
|
||||
AZURE_IMAGE_DEFINITION="constellation-coreos-debugd" AZURE_IMAGE_VERSION="${{env.IMAGE_VERSION }}" DOWNLOAD_COORDINATOR=n COORDINATOR_BINARY="${{ github.workspace }}/debugd/debugd/cmd/debugd/constellation-debugd"
|
||||
image-gcp image-azure upload-gcp upload-azure
|
||||
working-directory: ${{ env.working-directory }}
|
99
.github/workflows/build-coreos.yml
vendored
Normal file
99
.github/workflows/build-coreos.yml
vendored
Normal file
@ -0,0 +1,99 @@
|
||||
name: Build and Upload CoreOS
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
coordinator-name:
|
||||
description: Coordinator name
|
||||
required: true
|
||||
type: string
|
||||
|
||||
workflow_call:
|
||||
inputs:
|
||||
coordinator-name:
|
||||
required: true
|
||||
type: string
|
||||
|
||||
secrets:
|
||||
CI_GITHUB_REPOSITORY:
|
||||
required: true
|
||||
AWS_ACCESS_KEY_ID:
|
||||
required: true
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
required: true
|
||||
AWS_DEFAULT_REGION:
|
||||
required: true
|
||||
BUCKET_NAME:
|
||||
required: true
|
||||
PUBLIC_BUCKET_NAME:
|
||||
required: true
|
||||
SSH_PUB_KEY:
|
||||
required: true
|
||||
SSH_PUB_KEY_PATH:
|
||||
required: true
|
||||
AZURE_CREDENTIALS:
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
build-enclave:
|
||||
name: "Build CoreOS using customized COSA"
|
||||
runs-on: [self-hosted, linux, nested-virt]
|
||||
permissions:
|
||||
contents: read
|
||||
packages: read
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
env:
|
||||
working-directory: ${{ github.workspace }}/images/fcos
|
||||
SHELL: /bin/bash
|
||||
steps:
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: recursive
|
||||
token: ${{ secrets.CI_GITHUB_REPOSITORY }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
id: docker-login
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: "Install azure CLI"
|
||||
run: |
|
||||
# use pip since azure cli repository is not working as expected
|
||||
# https://github.com/Azure/azure-cli/issues/21532
|
||||
# curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y python3 python3-pip
|
||||
sudo pip install azure-cli
|
||||
wget -q https://aka.ms/downloadazcopy-v10-linux -O azcopy.tar.gz
|
||||
tar --strip-components 1 -xf azcopy.tar.gz
|
||||
rm azcopy.tar.gz
|
||||
echo "$(pwd)" >> $GITHUB_PATH
|
||||
|
||||
- uses: azure/login@v1
|
||||
with:
|
||||
creds: ${{ secrets.AZURE_CREDENTIALS }}
|
||||
|
||||
- name: "Store GH token to be mounted by cosa"
|
||||
run: echo "machine github.com login api password ${{ secrets.CI_GITHUB_REPOSITORY }}" > /tmp/.netrc
|
||||
|
||||
- name: "Set image timestamp"
|
||||
run: |
|
||||
TIMESTAMP=$(date +%s)
|
||||
echo "TIMESTAMP=${TIMESTAMP}" >> $GITHUB_ENV
|
||||
echo "IMAGE_TIMESTAMP=constellation-coreos-${TIMESTAMP}" >> $GITHUB_ENV
|
||||
echo "IMAGE_VERSION=0.0.${TIMESTAMP}" >> $GITHUB_ENV
|
||||
|
||||
- name: "Build and Upload"
|
||||
run: >
|
||||
make -j$(nproc) CONTAINER_ENGINE=docker NETRC=/tmp/.netrc GCP_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}" AZURE_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}"
|
||||
AZURE_IMAGE_DEFINITION="constellation-coreos" AZURE_IMAGE_VERSION="${{env.IMAGE_VERSION }}" COORDINATOR_URL="https://${{ secrets.PUBLIC_BUCKET_NAME }}.s3.us-east-2.amazonaws.com/coordinator/${{ inputs.coordinator-name }}"
|
||||
image-gcp image-azure upload-gcp upload-azure
|
||||
working-directory: ${{ env.working-directory }}
|
76
.github/workflows/build-enclave.yml
vendored
Normal file
76
.github/workflows/build-enclave.yml
vendored
Normal file
@ -0,0 +1,76 @@
|
||||
name: Build and Upload the Enclave Image File
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
coordinator-name:
|
||||
description: Coordinator name
|
||||
required: true
|
||||
type: string
|
||||
|
||||
workflow_call:
|
||||
inputs:
|
||||
coordinator-name:
|
||||
required: true
|
||||
type: string
|
||||
|
||||
secrets:
|
||||
CI_GITHUB_REPOSITORY:
|
||||
required: true
|
||||
AWS_ACCESS_KEY_ID:
|
||||
required: true
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
required: true
|
||||
AWS_DEFAULT_REGION:
|
||||
required: true
|
||||
BUCKET_NAME:
|
||||
required: true
|
||||
PUBLIC_BUCKET_NAME:
|
||||
required: true
|
||||
SSH_PUB_KEY:
|
||||
required: true
|
||||
SSH_PUB_KEY_PATH:
|
||||
required: true
|
||||
|
||||
|
||||
jobs:
|
||||
build-enclave:
|
||||
name: "Build the Enclave"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: recursive
|
||||
token: ${{ secrets.CI_GITHUB_REPOSITORY }}
|
||||
|
||||
- name: Install AWS CLI
|
||||
id: prepare
|
||||
run: sudo apt-get update && sudo apt-get -y install awscli
|
||||
|
||||
- name: Download bzImage, init and nsm.ko to AWS S3 Bucket
|
||||
id: download-artifacts
|
||||
run: aws s3 cp s3://${{ secrets.BUCKET_NAME }}/blobs/ ${{ github.workspace }}/images/aws/enclave/userland/dependencies/blobs/ --recursive
|
||||
|
||||
- name: Download Coordinator
|
||||
id: download-coordinator
|
||||
run: aws s3 cp s3://${{ secrets.PUBLIC_BUCKET_NAME }}/coordinator/${{ inputs.coordinator-name }} ${{ github.workspace }}/images/aws/enclave/userland/build/coordinator
|
||||
|
||||
- name: Write ssh public key to file
|
||||
run: echo $SSH_PUB_KEY >> ${{ env.SSH_PUB_KEY_PATH }} && chmod 644 ${{ env.SSH_PUB_KEY_PATH }}
|
||||
env:
|
||||
SSH_PUB_KEY: ${{ secrets.SSH_PUB_KEY }}
|
||||
SSH_PUB_KEY_PATH: ~/authorized_keys
|
||||
|
||||
- name: Build the eif file
|
||||
run: make -j$(nproc) SSH_DIR=~/ -C ${{ github.workspace }}/images/aws/enclave/
|
||||
|
||||
- name: Upload eif file to AWS S3 Bucket
|
||||
id: upload
|
||||
run: aws s3 cp ${{ github.workspace }}/images/aws/enclave/userland/build/ s3://${{ secrets.BUCKET_NAME }}/eif/ --recursive --exclude "*" --include "*.eif" --quiet
|
||||
|
||||
|
36
.github/workflows/build-kernel.yml
vendored
Normal file
36
.github/workflows/build-kernel.yml
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
name: Build the Kernel
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'kernel/**'
|
||||
workflow_dispatch:
|
||||
jobs:
|
||||
compile-and-upload-kernel:
|
||||
name: "Compile and upload the Kernel"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install build dependencies
|
||||
id: install
|
||||
run: sudo apt-get update && sudo apt-get install -y git build-essential fakeroot libncurses5-dev libssl-dev ccache bison flex libelf-dev dwarves
|
||||
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Compile using make
|
||||
id: compile
|
||||
run: make -C ${{ github.workspace }}/images/aws/kernel/
|
||||
|
||||
- name: Install AWS CLI
|
||||
id: prepare
|
||||
run: sudo apt-get -y install awscli
|
||||
|
||||
- name: Upload bzImage, init and nsm.ko to AWS S3 Bucket
|
||||
id: upload
|
||||
run: aws s3 cp ${{ github.workspace }}/images/aws/kernel/build/blobs/ s3://${{ secrets.BUCKET_NAME }}/blobs/ --recursive --quiet
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
47
.github/workflows/build-patched-gvisor-proxy.yml
vendored
Normal file
47
.github/workflows/build-patched-gvisor-proxy.yml
vendored
Normal file
@ -0,0 +1,47 @@
|
||||
name: Patch gvisor-tap-vsock and Upload to S3
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: "gvisor version"
|
||||
required: true
|
||||
default: 0.3.0
|
||||
jobs:
|
||||
build:
|
||||
name: "Build"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
|
||||
working-directory: ec2
|
||||
steps:
|
||||
- name: Checkout
|
||||
id: checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Prepare Download
|
||||
id: prepare
|
||||
run: sudo apt-get update && sudo apt-get -y install wget tar make
|
||||
|
||||
- name: Download and unpack sources
|
||||
id: unpack
|
||||
run: wget -c https://github.com/containers/gvisor-tap-vsock/archive/refs/tags/v${{ github.event.inputs.version }}.tar.gz -O - | tar xz
|
||||
working-directory: ${{ github.workspace }}
|
||||
|
||||
- name: Install go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: go1.17.6
|
||||
|
||||
- name: Patch source code
|
||||
run: patch --ignore-whitespace ${{ github.workspace }}/gvisor-tap-vsock-${{ github.event.inputs.version }}/pkg/services/forwarder/tcp.go < ${{ github.workspace }}/images/aws/ec2/patches/remove_link_local.patch
|
||||
working-directory: ${{ env.working-directory }}
|
||||
|
||||
- name: Build gvisor
|
||||
id: build
|
||||
run: make -C ${{ github.workspace }}/gvisor-tap-vsock-${{ github.event.inputs.version }}/
|
||||
|
||||
- name: Upload gvproxy
|
||||
id: upload_gvproxy
|
||||
run: aws s3 cp ${{ github.workspace }}/gvisor-tap-vsock-${{ github.event.inputs.version }}/bin/gvproxy s3://${{ secrets.BUCKET_NAME }}/gvproxy/gvproxy --quiet
|
22
.github/workflows/test-integration-etcdStore.yml
vendored
Normal file
22
.github/workflows/test-integration-etcdStore.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: Etcd Integration Test
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
integration-test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Setup Go environment
|
||||
uses: actions/setup-go@v2.1.4
|
||||
with:
|
||||
go-version: "1.18"
|
||||
|
||||
- name: Test Constellation etcd integration
|
||||
run: go test -v --race -cover -count=3 -tags integration
|
||||
working-directory: coordinator/store
|
23
.github/workflows/test-integration.yml
vendored
Normal file
23
.github/workflows/test-integration.yml
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
name: Integration Test
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
integration-test:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOPRIVATE: github.com/edgelesssys/*
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Setup Go environment
|
||||
uses: actions/setup-go@v2.1.4
|
||||
with:
|
||||
go-version: "1.18"
|
||||
|
||||
- name: Run Integration Test
|
||||
run: DEBUG=true go test -v -tags integration ./test/
|
23
.github/workflows/test-lint.yml
vendored
Normal file
23
.github/workflows/test-lint.yml
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
name: Golangci-lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
# Allow read access to pull request. Use with `only-new-issues` option.
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOPRIVATE: github.com/edgelesssys/*
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
only-new-issues: true
|
18
.github/workflows/test-shellcheck.yml
vendored
Normal file
18
.github/workflows/test-shellcheck.yml
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
name: Shellcheck
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
shellcheck:
|
||||
name: Shellcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Run ShellCheck
|
||||
uses: ludeeus/action-shellcheck@master
|
||||
with:
|
||||
severity: error
|
||||
ignore_names: merge_config.sh
|
27
.github/workflows/test-unittest.yml
vendored
Normal file
27
.github/workflows/test-unittest.yml
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOPRIVATE: github.com/edgelesssys/*
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.18
|
||||
|
||||
- name: Install Dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libcryptsetup-dev
|
||||
|
||||
- name: Test
|
||||
run: go test -race -count=3 ./...
|
38
.gitignore
vendored
Normal file
38
.gitignore
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
build
|
||||
admin.conf
|
||||
coordinatorConfig.json
|
||||
coordinator-*
|
||||
|
||||
# VS Code configuration folder
|
||||
.vscode
|
||||
# Debug and testing files
|
||||
debug/
|
||||
|
||||
# Images
|
||||
images/aws/kernel/build/*
|
||||
images/aws/kernel/sed*
|
||||
images/aws/enclave/userland/build/*
|
||||
images/aws/enclave/userland/privatekey
|
||||
images/aws/enclave/userland/publickey
|
||||
images/aws/enclave/.build-*
|
||||
images/*.ign
|
||||
images/fcos/build/*
|
||||
images/fcos/dependencies/coordinator
|
||||
images/fcos/images/*
|
||||
images/fcos/cosa.lock
|
49
.golangci.yml
Normal file
49
.golangci.yml
Normal file
@ -0,0 +1,49 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
|
||||
output:
|
||||
format: tab
|
||||
sort-results: true
|
||||
build-tags:
|
||||
- integration
|
||||
- aws
|
||||
- gcp
|
||||
|
||||
linters:
|
||||
enable:
|
||||
# Default linters
|
||||
- deadcode
|
||||
- errcheck
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- staticcheck
|
||||
- structcheck
|
||||
- typecheck
|
||||
- unused
|
||||
- varcheck
|
||||
# Additional linters
|
||||
- bodyclose
|
||||
- errname
|
||||
- exportloopref
|
||||
- ifshort
|
||||
- godot
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- misspell
|
||||
- noctx
|
||||
- tenv
|
||||
- unconvert
|
||||
- unparam
|
||||
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 20
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# List of functions to exclude from checking, where each entry is a single function to exclude.
|
||||
# See https://github.com/kisielk/errcheck#excluding-functions for details.
|
||||
exclude-functions:
|
||||
- (*go.uber.org/zap.Logger).Sync
|
||||
- (*google.golang.org/grpc.Server).Serve
|
1025
3rdparty/aws-nitro-enclaves-ffi/Cargo.lock
generated
vendored
Normal file
1025
3rdparty/aws-nitro-enclaves-ffi/Cargo.lock
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
17
3rdparty/aws-nitro-enclaves-ffi/Cargo.toml
vendored
Normal file
17
3rdparty/aws-nitro-enclaves-ffi/Cargo.toml
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "aws-nitro-enclaves-ffi"
|
||||
version = "0.1.0"
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
nsm-lib = { git = "https://github.com/aws/aws-nitro-enclaves-nsm-api", rev = "4f468c467583bbd55429935c4f09448dd43f48a0" }
|
||||
aws-nitro-enclaves-attestation-ffi = { git = "https://github.com/ppmag/aws-nitro-enclaves-attestation", rev = "83ca87233298c302973a5bdbbb394c36cd7eb6e6" }
|
||||
|
||||
[lib]
|
||||
name = "nitro"
|
||||
crate-type = ["staticlib"]
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
2
3rdparty/aws-nitro-enclaves-ffi/src/lib.rs
vendored
Normal file
2
3rdparty/aws-nitro-enclaves-ffi/src/lib.rs
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
pub use nitroattest::*;
|
||||
pub use nsm::*;
|
67
CMakeLists.txt
Normal file
67
CMakeLists.txt
Normal file
@ -0,0 +1,67 @@
|
||||
cmake_minimum_required(VERSION 3.11)
|
||||
project(coordinator LANGUAGES C VERSION 0.1.0)
|
||||
|
||||
enable_testing()
|
||||
option(COORDINATOR_STATIC_MUSL "use musl and compile coordinator statically")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Debug)
|
||||
endif()
|
||||
if(CMAKE_BUILD_TYPE STREQUAL Debug)
|
||||
set(CARGOTARGET debug)
|
||||
else()
|
||||
set(CARGOTARGET release)
|
||||
set(CARGOFLAGS --release)
|
||||
endif()
|
||||
|
||||
if(COORDINATOR_STATIC_MUSL)
|
||||
set(RUST_STATICLIB_LDFLAGS -static ${RUST_STATICLIB_LDFLAGS})
|
||||
set(RUSTTARGETTRIPLE x86_64-unknown-linux-musl)
|
||||
set(CARGOFLAGS ${CARGOFLAGS} "--target=${RUSTTARGETTRIPLE}")
|
||||
set(CARGOTARGET ${RUSTTARGETTRIPLE}/${CARGOTARGET})
|
||||
else()
|
||||
set(RUST_STATICLIB_LDFLAGS -ldl -lm -lrt ${RUST_STATICLIB_LDFLAGS})
|
||||
endif()
|
||||
|
||||
set(NITRO_CFLAGS '-I${CMAKE_BINARY_DIR}/nitro/${CARGOTARGET} -I${CMAKE_BINARY_DIR}/nitro/${CARGOTARGET}/headers')
|
||||
set(NITRO_LDFLAGS '${CMAKE_BINARY_DIR}/nitro/${CARGOTARGET}/libnitro.a ${RUST_STATICLIB_LDFLAGS}')
|
||||
|
||||
#
|
||||
# coordinator
|
||||
#
|
||||
|
||||
add_custom_target(nitro
|
||||
CARGO_TARGET_DIR=${CMAKE_BINARY_DIR}/nitro cargo build ${CARGOFLAGS}
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/3rdparty/aws-nitro-enclaves-ffi)
|
||||
|
||||
add_custom_target(coordinator ALL
|
||||
${CMAKE_COMMAND} -E env CGO_CFLAGS=${NITRO_CFLAGS}
|
||||
${CMAKE_COMMAND} -E env CGO_LDFLAGS=${NITRO_LDFLAGS}
|
||||
go build -o ${CMAKE_BINARY_DIR} -tags=aws,gcp -buildvcs=false -ldflags "-buildid='' -X main.version=${PROJECT_VERSION}"
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/coordinator/cmd/coordinator)
|
||||
|
||||
add_dependencies(coordinator nitro)
|
||||
|
||||
#
|
||||
# cli
|
||||
#
|
||||
|
||||
add_custom_target(cli ALL
|
||||
${CMAKE_COMMAND} -E env CGO_CFLAGS=${NITRO_CFLAGS}
|
||||
${CMAKE_COMMAND} -E env CGO_LDFLAGS=${NITRO_LDFLAGS}
|
||||
go build -o ${CMAKE_BINARY_DIR}/constellation -buildvcs=false -tags=aws,gcp -ldflags "-buildid='' -X github.com/edgelesssys/constellation/cli/defaults.Version=${PROJECT_VERSION}"
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/cli)
|
||||
|
||||
add_dependencies(cli nitro)
|
||||
|
||||
#
|
||||
# testing / debugging
|
||||
#
|
||||
|
||||
add_custom_target(debug_coordinator ALL
|
||||
go build -o ${CMAKE_BINARY_DIR}/debug_coordinator -buildvcs=false -ldflags "-buildid='' -X main.version=${PROJECT_VERSION}"
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/coordinator/cmd/coordinator)
|
||||
|
||||
add_test(NAME unittest COMMAND go test -race -count=3 ./... WORKING_DIRECTORY ${CMAKE_SOURCE_DIR})
|
||||
add_test(NAME integrationtest COMMAND go test -v -tags integration ./test/ WORKING_DIRECTORY ${CMAKE_SOURCE_DIR})
|
||||
add_test(NAME etcd-unittest COMMAND go test -v --race -cover -count=3 -tags integration WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/coordinator/store/)
|
64
CONTRIBUTING.md
Normal file
64
CONTRIBUTING.md
Normal file
@ -0,0 +1,64 @@
|
||||
## Testing
|
||||
|
||||
Run all unit tests locally with
|
||||
|
||||
```sh
|
||||
cd build
|
||||
cmake ..
|
||||
ctest
|
||||
```
|
||||
|
||||
### E2E Test
|
||||
|
||||
Requirement: Kernel WireGuard, Docker
|
||||
```sh
|
||||
docker build -f Dockerfile.e2e -t constellation-e2e .
|
||||
```
|
||||
For the AWS test run
|
||||
```sh
|
||||
docker run -it --cap-add=NET_ADMIN --env GITHUB_TOKEN="$(cat ~/.netrc)" --env BRANCH="main" --env aws_access_key_id=XXX --env aws_secret_access_key=XXX constellation-e2e /initiateAWS.sh
|
||||
```
|
||||
For the gcp test run
|
||||
```sh
|
||||
docker run -it --cap-add=NET_ADMIN --env GITHUB_TOKEN="$(cat ~/.netrc)" --env BRANCH="main" --env GCLOUD_CREDENTIALS="$(cat ./constellation-keyfile.json)" constellation-e2e /initiategcloud.sh
|
||||
```
|
||||
|
||||
## Linting
|
||||
|
||||
This projects uses [golangci-lint](https://golangci-lint.run/) for linting.
|
||||
You can [install golangci-lint](https://golangci-lint.run/usage/install/#linux-and-windows) locally,
|
||||
but there is also a CI action to ensure compliance.
|
||||
|
||||
To locally run all configured linters, execute
|
||||
|
||||
```
|
||||
golangci-lint run ./...
|
||||
```
|
||||
|
||||
It is also recommended to use golangci-lint (and [gofumpt](https://github.com/mvdan/gofumpt) as formatter) in your IDE, by adding the recommended VS Code Settings or by [configuring it yourself](https://golangci-lint.run/usage/integrations/#editor-integration)
|
||||
|
||||
|
||||
## Recommended VS Code Settings
|
||||
|
||||
The following can be added to your personal `settings.json`, but it is recommended to add it to
|
||||
the `<REPOSITORY>/.vscode/settings.json` repo, so the settings will only affect this repository.
|
||||
|
||||
```jsonc
|
||||
// Use gofumpt as formatter.
|
||||
"gopls": {
|
||||
"formatting.gofumpt": true,
|
||||
},
|
||||
// Use golangci-lint as linter. Make sure you've installed it.
|
||||
"go.lintTool":"golangci-lint",
|
||||
"go.lintFlags": ["--fast"],
|
||||
// You can easily show Go test coverage by running a package test.
|
||||
"go.coverageOptions": "showUncoveredCodeOnly",
|
||||
// Executing unit tests with race detection.
|
||||
// You can add preferences like "-v" or "-count=1"
|
||||
"go.testFlags": ["-race"],
|
||||
// Enable language features for files with build tags.
|
||||
// Attention! This leads to integration test being executed when
|
||||
// running a package test within a package containing integration
|
||||
// tests.
|
||||
"go.buildTags": "integration",
|
||||
```
|
37
Dockerfile.build
Normal file
37
Dockerfile.build
Normal file
@ -0,0 +1,37 @@
|
||||
FROM ubuntu@sha256:7cc0576c7c0ec2384de5cbf245f41567e922aab1b075f3e8ad565f508032df17 as build
|
||||
|
||||
ENV DEBIAN_FRONTEND="noninteractive"
|
||||
RUN apt-get update && apt-get install cmake iproute2 iputils-ping wget curl git jq libssl-dev musl-tools=1.1.24-1 -y
|
||||
|
||||
# Install Go
|
||||
ARG GO_VER=1.18
|
||||
RUN wget https://go.dev/dl/go${GO_VER}.linux-amd64.tar.gz
|
||||
RUN tar -C /usr/local -xzf go${GO_VER}.linux-amd64.tar.gz && rm go${GO_VER}.linux-amd64.tar.gz
|
||||
ENV PATH ${PATH}:/usr/local/go/bin
|
||||
|
||||
# Install Rust
|
||||
ARG RUST_VER=1.58.0
|
||||
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
ENV PATH /root/.cargo/bin:${PATH}
|
||||
RUN rustup install ${RUST_VER}
|
||||
RUN rustup override set ${RUST_VER}
|
||||
RUN rustup target add x86_64-unknown-linux-musl
|
||||
|
||||
# Download go dependencies
|
||||
WORKDIR /constellation/
|
||||
COPY go.mod ./
|
||||
COPY go.sum ./
|
||||
RUN go mod download all
|
||||
|
||||
# Copy Repo
|
||||
COPY . /constellation
|
||||
|
||||
# Build
|
||||
RUN mkdir -p /constellation/build
|
||||
WORKDIR /constellation/build
|
||||
RUN cmake -DCMAKE_BUILD_TYPE=Release -DCOORDINATOR_STATIC_MUSL=ON .. && make coordinator
|
||||
|
||||
RUN mv coordinator coordinator-$(sha512sum coordinator | cut -d " " -f 1)
|
||||
|
||||
FROM scratch AS export
|
||||
COPY --from=build /constellation/build/coordinator-* /
|
37
Dockerfile.e2e
Normal file
37
Dockerfile.e2e
Normal file
@ -0,0 +1,37 @@
|
||||
FROM ubuntu:20.04
|
||||
|
||||
ENV DEBIAN_FRONTEND="noninteractive"
|
||||
RUN apt-get update && apt-get install cmake iproute2 iputils-ping wget curl git libssl-dev -y
|
||||
|
||||
# Install kubectl
|
||||
RUN curl -fsSLo /usr/local/bin/kubectl https://dl.k8s.io/release/v1.23.0/bin/linux/amd64/kubectl && chmod +x /usr/local/bin/kubectl
|
||||
|
||||
# Install Go
|
||||
RUN wget https://go.dev/dl/go1.18.linux-amd64.tar.gz
|
||||
RUN tar -C /usr/local -xzf go1.18.linux-amd64.tar.gz && rm go1.18.linux-amd64.tar.gz
|
||||
ENV PATH ${PATH}:/usr/local/go/bin
|
||||
|
||||
# Install wireguard-tools
|
||||
RUN git clone -b v1.0.20210914 --depth=1 https://git.zx2c4.com/wireguard-tools && make -C wireguard-tools/src -j`nproc` && make -C wireguard-tools/src install
|
||||
|
||||
# Install Rust
|
||||
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
ENV PATH /root/.cargo/bin:${PATH}
|
||||
|
||||
# Setup CLI
|
||||
RUN wg genkey | (umask 0077 && tee /privatekey) | wg pubkey > /publickey
|
||||
RUN mkdir -p /root/.config/constellation && touch /root/.config/constellation/config.json
|
||||
|
||||
# Setup AWS config
|
||||
RUN mkdir -p /root/.aws && echo "[default]\nregion = us-east-2" > /root/.aws/config && echo "[default]" >> /root/.aws/credentials
|
||||
|
||||
# Setup gcloud config
|
||||
RUN mkdir -p /root/.config/gcloud
|
||||
|
||||
# Use local constellation state
|
||||
# COPY . /constellation
|
||||
# WORKDIR /constellation
|
||||
# RUN mkdir build && cd build && cmake .. && make -j`nproc` cli
|
||||
|
||||
COPY ./test/ /
|
||||
RUN chmod +x /initiateAWS.sh && chmod +x /initiategcloud.sh
|
327
README.md
Normal file
327
README.md
Normal file
@ -0,0 +1,327 @@
|
||||
# constellation-coordinator
|
||||
|
||||
## Prerequisites
|
||||
* Go 1.18
|
||||
|
||||
### Ubuntu 20.04
|
||||
```sh
|
||||
sudo apt install build-essential cmake libssl-dev
|
||||
curl https://sh.rustup.rs -sSf | sh
|
||||
```
|
||||
|
||||
### Amazon Linux
|
||||
```sh
|
||||
sudo yum install cmake3 gcc make
|
||||
curl https://sh.rustup.rs -sSf | sh
|
||||
```
|
||||
|
||||
## Build
|
||||
```sh
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
make -j`nproc`
|
||||
```
|
||||
|
||||
## CMake build options:
|
||||
|
||||
### Release build
|
||||
|
||||
This options leaves out debug symbols and turns on more compiler optimizations.
|
||||
|
||||
```sh
|
||||
cmake -DCMAKE_BUILD_TYPE=Release ..
|
||||
```
|
||||
|
||||
### Static build (coordinator as static binary, no dependencies on libc or other libraries)
|
||||
|
||||
Install the musl-toolchain
|
||||
|
||||
Ubuntu / Debian:
|
||||
```sh
|
||||
sudo apt install -y musl-tools
|
||||
rustup target add x86_64-unknown-linux-musl
|
||||
```
|
||||
|
||||
From source (Amazon-Linux):
|
||||
```sh
|
||||
wget https://musl.libc.org/releases/musl-1.2.2.tar.gz
|
||||
tar xfz musl-1.2.2.tar.gz
|
||||
cd musl-1.2.2
|
||||
./configure
|
||||
make -j `nproc`
|
||||
sudo make install
|
||||
rustup target add x86_64-unknown-linux-musl
|
||||
```
|
||||
Add `musl-gcc` to your PATH:
|
||||
```sh
|
||||
export PATH=$PATH:/usr/loca/musl/bin/
|
||||
```
|
||||
|
||||
Compile the coordinator
|
||||
```sh
|
||||
cmake -DCOORDINATOR_STATIC_MUSL=ON ..
|
||||
```
|
||||
|
||||
## Cloud credentials
|
||||
|
||||
Using the CLI or debug-CLI requires the user to make authorized API calls to the AWS or GCP API.
|
||||
|
||||
### Google Cloud Platform (GCP)
|
||||
|
||||
If you are running from within a Google VM, and the VM is allowed to access the necessary APIs, no further configuration is needed.
|
||||
|
||||
Otherwise you have a couple options:
|
||||
|
||||
1. Use the `gcloud` CLI tool
|
||||
|
||||
```shell
|
||||
gcloud auth application-default login
|
||||
```
|
||||
This will ask you to log into your Google account, and then create your credentials.
|
||||
The Constellation CLI will automatically load these credentials when needed.
|
||||
|
||||
2. Set up a service account and pass the credentials manually
|
||||
|
||||
Follow [Google's guide](https://cloud.google.com/docs/authentication/production#manually) for setting up your credentials.
|
||||
|
||||
### Amazon Web Services (AWS)
|
||||
|
||||
To use the CLI with an Constellation cluster on AWS configure the following files:
|
||||
|
||||
|
||||
```bash
|
||||
$ cat ~/.aws/credentials
|
||||
[default]
|
||||
aws_access_key_id = XXXXX
|
||||
aws_secret_access_key = XXXXX
|
||||
```
|
||||
|
||||
```bash
|
||||
$ cat ~/.aws/config
|
||||
[default]
|
||||
region = us-east-2
|
||||
```
|
||||
|
||||
### Azure
|
||||
|
||||
To use the CLI with an Constellation cluster on Azure execute:
|
||||
```bash
|
||||
az login
|
||||
```
|
||||
|
||||
### Deploying a locally compiled coordinator binary
|
||||
|
||||
By default, `constellation create ...` will spawn cloud provider instances with a pre-baked coordinator binary.
|
||||
For testing, you can use the constellation debug daemon (debugd) to upload your local coordinator binary to running instances and to obtain SSH access.
|
||||
See this introduction on how to install and setup `cdbg`: https://github.com/edgelesssys/constellation/debugd/#readme
|
||||
# constellation-debugd
|
||||
|
||||
## Prerequisites
|
||||
|
||||
* Go 1.18
|
||||
|
||||
## Build
|
||||
|
||||
```
|
||||
git clone https://github.com/edgelesssys/constellation/debugd
|
||||
cd constellation-debugd
|
||||
go build -o constellation-debugd debugd/cmd/debugd.go
|
||||
go build -o constellation-cdbg cdbg/cdbg.go
|
||||
```
|
||||
|
||||
## Install cdbg
|
||||
|
||||
```
|
||||
go install github.com/edgelesssys/constellation/debugd/cdbg@latest
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
With `cdbg` installed in your path:
|
||||
|
||||
1. Run `constellation --dev-config /path/to/dev-config create […]` while specifying a cloud-provider image with the debugd already included. See [Configuration](#configuration) for a dev-config with a custom image and firewall rules to allow incoming connection on the debugd default port 4000.
|
||||
2. Run `cdbg deploy --dev-config /path/to/dev-config`
|
||||
3. Run `constellation init […]` as usual
|
||||
|
||||
|
||||
|
||||
### GCP image
|
||||
|
||||
For GCP, run the following command to get a list of all constellation images, sorted by their creation date:
|
||||
```
|
||||
gcloud compute images list --filter="name~'constellation-.+'" --sort-by=~creationTimestamp
|
||||
```
|
||||
Choose the newest debugd image with the naming scheme `constellation-coreos-debugd-<timestamp>`.
|
||||
|
||||
### Azure Image
|
||||
|
||||
For Azure, run the following command to get a list of all constellation debugd images, sorted by their creation date:
|
||||
```
|
||||
az sig image-version list --resource-group constellation-images --gallery-name Constellation --gallery-image-definition constellation-coreos-debugd --query "sort_by([], &publishingProfile.publishedDate)[].id" -o table
|
||||
```
|
||||
Choose the newest debugd image and copy the full URI.
|
||||
|
||||
## Configuration
|
||||
|
||||
You should first locate the newest debugd image for your cloud provider ([GCP](#gcp-image), [Azure](#azure-image)).
|
||||
|
||||
This tool uses the dev-config file from `constellation-coordinator` and extends it with more fields.
|
||||
See this example on what the possible settings are and how to setup the constellation cli to use a cloud-provider image and firewall rules with support for debugd:
|
||||
```json
|
||||
{
|
||||
"cdbg":{
|
||||
"authorized_keys":[
|
||||
{
|
||||
"user":"my-username",
|
||||
"pubkey":"ssh-rsa AAAAB…LJuM="
|
||||
}
|
||||
],
|
||||
"coordinator_path":"/path/to/coordinator",
|
||||
"systemd_units":[
|
||||
{
|
||||
"name":"some-custom.service",
|
||||
"contents":"[Unit]\nDescription=…"
|
||||
}
|
||||
]
|
||||
},
|
||||
"provider": {
|
||||
"gcpconfig": {
|
||||
"image": "constellation-coreos-debugd-TIMESTAMP",
|
||||
"firewallinput": {
|
||||
"Ingress": [
|
||||
{
|
||||
"Name": "coordinator",
|
||||
"Description": "Coordinator default port",
|
||||
"Protocol": "tcp",
|
||||
"Port": 9000
|
||||
},
|
||||
{
|
||||
"Name": "wireguard",
|
||||
"Description": "WireGuard default port",
|
||||
"Protocol": "udp",
|
||||
"Port": 51820
|
||||
},
|
||||
{
|
||||
"Name": "ssh",
|
||||
"Description": "SSH",
|
||||
"Protocol": "tcp",
|
||||
"Port": 22
|
||||
},
|
||||
{
|
||||
"Name": "debugd",
|
||||
"Description": "debugd default port",
|
||||
"Protocol": "tcp",
|
||||
"Port": 4000
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"azureconfig": {
|
||||
"image": "/subscriptions/0d202bbb-4fa7-4af8-8125-58c269a05435/resourceGroups/CONSTELLATION-IMAGES/providers/Microsoft.Compute/galleries/Constellation/images/constellation-coreos-debugd/versions/0.0.TIMESTAMP",
|
||||
"networksecuritygroupinput": {
|
||||
"Ingress": [
|
||||
{
|
||||
"Name": "coordinator",
|
||||
"Description": "Coordinator default port",
|
||||
"Protocol": "tcp",
|
||||
"IPRange": "0.0.0.0/0",
|
||||
"Port": 9000
|
||||
},
|
||||
{
|
||||
"Name": "wireguard",
|
||||
"Description": "WireGuard default port",
|
||||
"Protocol": "udp",
|
||||
"IPRange": "0.0.0.0/0",
|
||||
"Port": 51820
|
||||
},
|
||||
{
|
||||
"Name": "ssh",
|
||||
"Description": "SSH",
|
||||
"Protocol": "tcp",
|
||||
"IPRange": "0.0.0.0/0",
|
||||
"Port": 22
|
||||
},
|
||||
{
|
||||
"Name": "debugd",
|
||||
"Description": "debugd default port",
|
||||
"Protocol": "tcp",
|
||||
"IPRange": "0.0.0.0/0",
|
||||
"Port": 4000
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
# constellation-kms-client
|
||||
|
||||
This library provides an interface for the key management services used with constellation.
|
||||
It's intendet for the Constellation CSI Plugins and the CLI.
|
||||
|
||||
## KMS
|
||||
|
||||
The Cloud KMS is where we store our key encryption key (KEK).
|
||||
It should be initiated by the CLI and provided with a key release policy.
|
||||
The CSP Plugin can request to encrypt data encryption keys (DEK) with the DEK to safely store them on persistent memory.
|
||||
The [kms](pkg/kms) package interacts with the Cloud KMS APIs.
|
||||
Currently planned are KMS are:
|
||||
|
||||
* AWS KMS
|
||||
* GCP CKM
|
||||
* Azure Key Vault
|
||||
|
||||
|
||||
## Storage
|
||||
|
||||
Storage is where the CSI Plugin stores the encrypted DEKs.
|
||||
Currently planned are:
|
||||
|
||||
* AWS S3, SSP
|
||||
* GCP GCS
|
||||
* Azure Blob
|
||||
# constellation-images
|
||||
# constellation-mount-utils
|
||||
Wrapper for https://github.com/kubernetes/mount-utils
|
||||
|
||||
|
||||
## Dependencies
|
||||
|
||||
This package uses the C library [`libcryptsetup`](https://gitlab.com/cryptsetup/cryptsetup/) for device mapping.
|
||||
|
||||
To install the required dependencies on Ubuntu run:
|
||||
```shell
|
||||
sudo apt install libcryptsetup-dev
|
||||
```
|
||||
|
||||
To install or upgrade `go.mod` dependencies from private repositories run:
|
||||
```
|
||||
GOPRIVATE=github.com/edgelesssys/constellation-coordinator go get github.com/edgelesssys/constellation-coordinator
|
||||
GOPRIVATE=github.com/edgelesssys/constellation-kms-client go get github.com/edgelesssys/constellation-kms-client
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
A small test programm is available in `test/main.go`.
|
||||
To build the programm run:
|
||||
```shell
|
||||
go build -o test/crypt ./test/
|
||||
```
|
||||
|
||||
Create a new crypt device for `/dev/sdX` and map it to `/dev/mapper/volume01`:
|
||||
```shell
|
||||
sudo test/crypt -source /dev/sdX -target volume01 -v 4
|
||||
```
|
||||
|
||||
You can now interact with the mapped volume as if it was an unformatted device:
|
||||
```shell
|
||||
sudo mkfs.ext4 /dev/mapper/volume01
|
||||
sudo mount /dev/mapper/volume01 /mnt/volume01
|
||||
```
|
||||
|
||||
Close the mapped volume:
|
||||
```shell
|
||||
sudo umount /mnt/volume01
|
||||
sudo test/crypt -c -target volume01 -v 4
|
||||
```
|
28
cli/README.md
Normal file
28
cli/README.md
Normal file
@ -0,0 +1,28 @@
|
||||
# CLI to spawn a confidential kubernetes cluster
|
||||
|
||||
## Usage
|
||||
|
||||
0. (optional) replace the responsible in `cli/cmd/defaults.go` with yourself.
|
||||
1. Build the CLI and authenticate with <AWS/Azure/GCP> according to the [README.md](https://github.com/edgelesssys/constellation-coordinator/blob/main/README.md#cloud-credentials).
|
||||
2. Execute `constellation create <aws/azure/gcp> 2 <4xlarge|n2d-standard-2>`.
|
||||
3. Execute `wg genkey | tee privatekey | wg pubkey > publickey` to generate a WireGuard keypair.
|
||||
4. Execute `constellation init --publickey publickey`. Since the CLI waits for all nodes to be ready, this step can take up to 5 minutes.
|
||||
5. Use the output from `constellation init` and the wireguard template below to create `/etc/wireguard/wg0.conf`, then execute `wg-quick up wg0`.
|
||||
6. Execute `export KUBECONFIG=<path/to/admin.conf>`.
|
||||
7. Use `kubectl get nodes` to inspect your cluster.
|
||||
8. Execute `constellation terminate` to terminate your Constellation.
|
||||
|
||||
```bash
|
||||
[Interface]
|
||||
Address = <address from the init output>
|
||||
PrivateKey = <your base64 encoded private key>
|
||||
ListenPort = 51820
|
||||
|
||||
[Peer]
|
||||
PublicKey = <public key from the init output>
|
||||
AllowedIPs = 10.118.0.1/32 # IP set on the peer's wg interface
|
||||
Endpoint = <public IPv4 address from the activated coordinator>:51820 # address where the peer listens on
|
||||
PersistentKeepalive = 10
|
||||
```
|
||||
|
||||
Note: Skip the manual configuration of WireGuard by executing Step 2 as root. Then, replace steps 4 and 5 with `sudo constellation init --privatekey <path/to/your/privatekey>`. This will automatically configure a new WireGuard interface named wg0 with the coordinator as peer.
|
203
cli/azure/client/activedirectory.go
Normal file
203
cli/azure/client/activedirectory.go
Normal file
@ -0,0 +1,203 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/Azure/go-autorest/autorest/to"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
|
||||
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
|
||||
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
|
||||
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
|
||||
virtualMachineContributorRoleDefinitionID = "9980e02c-c2be-4d73-94e8-173b1dc7cf3c" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#virtual-machine-contributor
|
||||
)
|
||||
|
||||
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
|
||||
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
|
||||
createAppRes, err := c.createADApplication(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c.adAppObjectID = createAppRes.ObjectID
|
||||
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ApplicationCredentials{
|
||||
ClientID: createAppRes.AppID,
|
||||
ClientSecret: clientSecret,
|
||||
}.ConvertToCloudServiceAccountURI(), nil
|
||||
}
|
||||
|
||||
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
|
||||
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
|
||||
if c.adAppObjectID == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
|
||||
return err
|
||||
}
|
||||
c.adAppObjectID = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// createADApplication creates a new azure AD app.
|
||||
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
|
||||
createParameters := graphrbac.ApplicationCreateParameters{
|
||||
AvailableToOtherTenants: to.BoolPtr(false),
|
||||
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
|
||||
}
|
||||
app, err := c.applicationsAPI.Create(ctx, createParameters)
|
||||
if err != nil {
|
||||
return createADApplicationOutput{}, err
|
||||
}
|
||||
if app.AppID == nil || app.ObjectID == nil {
|
||||
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
|
||||
}
|
||||
return createADApplicationOutput{
|
||||
AppID: *app.AppID,
|
||||
ObjectID: *app.ObjectID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createAppServicePrincipal creates a new service principal for an azure AD app.
|
||||
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
|
||||
createParameters := graphrbac.ServicePrincipalCreateParameters{
|
||||
AppID: &appID,
|
||||
AccountEnabled: to.BoolPtr(true),
|
||||
}
|
||||
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if servicePrincipal.ObjectID == nil {
|
||||
return "", errors.New("creating AD service principal did not result in a valid object id")
|
||||
}
|
||||
return *servicePrincipal.ObjectID, nil
|
||||
}
|
||||
|
||||
// updateAppCredentials sets app client-secret for authentication.
|
||||
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
|
||||
keyID := uuid.New().String()
|
||||
clientSecret, err := generateClientSecret()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generating client secret failed: %w", err)
|
||||
}
|
||||
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
|
||||
Value: &[]graphrbac.PasswordCredential{
|
||||
{
|
||||
StartDate: &date.Time{Time: time.Now()},
|
||||
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
|
||||
Value: to.StringPtr(clientSecret),
|
||||
KeyID: to.StringPtr(keyID),
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
// assignResourceGroupRole assigns the service principal a role at resource group scope.
|
||||
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
|
||||
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
|
||||
if err != nil || resourceGroup.ID == nil {
|
||||
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
|
||||
}
|
||||
roleAssignmentID := uuid.New().String()
|
||||
createParameters := authorization.RoleAssignmentCreateParameters{
|
||||
Properties: &authorization.RoleAssignmentProperties{
|
||||
PrincipalID: to.StringPtr(principalID),
|
||||
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
|
||||
},
|
||||
}
|
||||
|
||||
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
|
||||
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
|
||||
// proper fix: use API version 2018-09-01-preview or later
|
||||
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
|
||||
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
|
||||
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
|
||||
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
|
||||
var detailedErr autorest.DetailedError
|
||||
var ok bool
|
||||
if detailedErr, ok = err.(autorest.DetailedError); !ok {
|
||||
return err
|
||||
}
|
||||
var requestErr *azure.RequestError
|
||||
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
|
||||
return err
|
||||
}
|
||||
if requestErr.ServiceError.Code != "PrincipalNotFound" {
|
||||
return err
|
||||
}
|
||||
time.Sleep(c.adReplicationLagCheckInterval)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ApplicationCredentials is a set of Azure AD application credentials.
|
||||
// It is the equivalent of a service account key in other cloud providers.
|
||||
type ApplicationCredentials struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
// ConvertToCloudServiceAccountURI converts the ApplicationCredentials into a cloud service account URI.
|
||||
func (c ApplicationCredentials) ConvertToCloudServiceAccountURI() string {
|
||||
query := url.Values{}
|
||||
query.Add("client_id", c.ClientID)
|
||||
query.Add("client_secret", c.ClientSecret)
|
||||
uri := url.URL{
|
||||
Scheme: "serviceaccount",
|
||||
Host: "azure",
|
||||
RawQuery: query.Encode(),
|
||||
}
|
||||
return uri.String()
|
||||
}
|
||||
|
||||
type createADApplicationOutput struct {
|
||||
AppID string
|
||||
ObjectID string
|
||||
}
|
||||
|
||||
func generateClientSecret() (string, error) {
|
||||
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
pwLen := 64
|
||||
pw := make([]byte, 0, pwLen)
|
||||
for i := 0; i < pwLen; i++ {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pw = append(pw, letters[n.Int64()])
|
||||
}
|
||||
return string(pw), nil
|
||||
}
|
380
cli/azure/client/activedirectory_test.go
Normal file
380
cli/azure/client/activedirectory_test.go
Normal file
@ -0,0 +1,380 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestCreateServicePrincipal(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
applicationsAPI applicationsAPI
|
||||
servicePrincipalsAPI servicePrincipalsAPI
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"failed app create": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed service principal create": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed role assignment": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{someErr},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed update creds": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
updateCredentialsErr: someErr,
|
||||
},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "resource-group",
|
||||
applicationsAPI: tc.applicationsAPI,
|
||||
servicePrincipalsAPI: tc.servicePrincipalsAPI,
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
adReplicationLagCheckMaxRetries: 2,
|
||||
}
|
||||
|
||||
_, err := client.CreateServicePrincipal(ctx)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateServicePrincipal(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
appObjectID string
|
||||
applicationsAPI applicationsAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful terminate": {
|
||||
appObjectID: "object-id",
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
},
|
||||
"nothing to terminate": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
},
|
||||
"failed delete": {
|
||||
appObjectID: "object-id",
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
deleteErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "resource-group",
|
||||
adAppObjectID: tc.appObjectID,
|
||||
applicationsAPI: tc.applicationsAPI,
|
||||
}
|
||||
|
||||
err := client.TerminateServicePrincipal(ctx)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateADApplication(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
applicationsAPI applicationsAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
},
|
||||
"failed app create": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"app create returns invalid appid": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createApplication: &graphrbac.Application{
|
||||
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"app create returns invalid objectid": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createApplication: &graphrbac.Application{
|
||||
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
applicationsAPI: tc.applicationsAPI,
|
||||
}
|
||||
|
||||
appCredentials, err := client.createADApplication(ctx)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.NotNil(appCredentials)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAppServicePrincipal(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
servicePrincipalsAPI servicePrincipalsAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
},
|
||||
"failed service principal create": {
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"service principal create returns invalid objectid": {
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
||||
createServicePrincipal: &graphrbac.ServicePrincipal{},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
servicePrincipalsAPI: tc.servicePrincipalsAPI,
|
||||
}
|
||||
|
||||
_, err := client.createAppServicePrincipal(ctx, "app-id")
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignOwnerOfResourceGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful assign": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"failed role assignment": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{someErr},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed resource group get": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getErr: someErr,
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"resource group get returns invalid id": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"create returns PrincipalNotFound the first time": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{
|
||||
autorest.DetailedError{Original: &azure.RequestError{
|
||||
ServiceError: &azure.ServiceError{
|
||||
Code: "PrincipalNotFound",
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"create does not return request error": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{autorest.DetailedError{Original: someErr}},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"create service error code is unknown": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{
|
||||
autorest.DetailedError{Original: &azure.RequestError{
|
||||
ServiceError: &azure.ServiceError{
|
||||
Code: "some-unknown-error-code",
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "resource-group",
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
adReplicationLagCheckMaxRetries: 2,
|
||||
}
|
||||
|
||||
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToCloudServiceAccountURI(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
key := ApplicationCredentials{
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
|
||||
cloudServiceAccountURI := key.ConvertToCloudServiceAccountURI()
|
||||
uri, err := url.Parse(cloudServiceAccountURI)
|
||||
require.NoError(err)
|
||||
query := uri.Query()
|
||||
assert.Equal("serviceaccount", uri.Scheme)
|
||||
assert.Equal("azure", uri.Host)
|
||||
assert.Equal(url.Values{
|
||||
"client_id": []string{"client-id"},
|
||||
"client_secret": []string{"client-secret"},
|
||||
}, query)
|
||||
}
|
129
cli/azure/client/api.go
Normal file
129
cli/azure/client/api.go
Normal file
@ -0,0 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
)
|
||||
|
||||
type virtualNetworksCreateOrUpdatePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error)
|
||||
}
|
||||
|
||||
type networksAPI interface {
|
||||
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
|
||||
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
|
||||
virtualNetworksCreateOrUpdatePollerResponse, error)
|
||||
}
|
||||
|
||||
type networkSecurityGroupsCreateOrUpdatePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error)
|
||||
}
|
||||
|
||||
type networkSecurityGroupsAPI interface {
|
||||
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
|
||||
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
|
||||
networkSecurityGroupsCreateOrUpdatePollerResponse, error)
|
||||
}
|
||||
|
||||
type virtualMachineScaleSetsCreateOrUpdatePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error)
|
||||
}
|
||||
|
||||
type scaleSetsAPI interface {
|
||||
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
|
||||
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
|
||||
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error)
|
||||
}
|
||||
|
||||
type publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager interface {
|
||||
NextPage(ctx context.Context) bool
|
||||
PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type publicIPAddressesClientCreateOrUpdatePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error)
|
||||
}
|
||||
|
||||
type publicIPAddressesAPI interface {
|
||||
ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
|
||||
virtualMachineScaleSetName string, virtualmachineIndex string,
|
||||
networkInterfaceName string, ipConfigurationName string,
|
||||
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
|
||||
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
|
||||
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
|
||||
publicIPAddressesClientCreateOrUpdatePollerResponse, error)
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
|
||||
armnetwork.PublicIPAddressesClientGetResponse, error)
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type interfacesClientCreateOrUpdatePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.InterfacesClientCreateOrUpdateResponse, error)
|
||||
}
|
||||
|
||||
type networkInterfacesAPI interface {
|
||||
GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
|
||||
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
|
||||
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
|
||||
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
|
||||
// TODO: deprecate as soon as scale sets are available
|
||||
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
|
||||
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
|
||||
interfacesClientCreateOrUpdatePollerResponse, error)
|
||||
}
|
||||
|
||||
type resourceGroupsDeletePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armresources.ResourceGroupsClientDeleteResponse, error)
|
||||
}
|
||||
|
||||
type resourceGroupAPI interface {
|
||||
CreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
parameters armresources.ResourceGroup,
|
||||
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
|
||||
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
|
||||
BeginDelete(ctx context.Context, resourceGroupName string,
|
||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
||||
resourceGroupsDeletePollerResponse, error)
|
||||
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
|
||||
}
|
||||
|
||||
type applicationsAPI interface {
|
||||
Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error)
|
||||
Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error)
|
||||
UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error)
|
||||
}
|
||||
|
||||
type servicePrincipalsAPI interface {
|
||||
Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error)
|
||||
}
|
||||
|
||||
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
|
||||
// TODO: switch to "armauthorization.RoleAssignmentsClient" if issue is resolved.
|
||||
type roleAssignmentsAPI interface {
|
||||
Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error)
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type virtualMachinesClientCreateOrUpdatePollerResponse interface {
|
||||
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachinesClientCreateOrUpdateResponse, error)
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type virtualMachinesAPI interface {
|
||||
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
|
||||
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions) (virtualMachinesClientCreateOrUpdatePollerResponse, error)
|
||||
}
|
388
cli/azure/client/api_test.go
Normal file
388
cli/azure/client/api_test.go
Normal file
@ -0,0 +1,388 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
)
|
||||
|
||||
type stubNetworksAPI struct {
|
||||
createErr error
|
||||
stubResponse stubVirtualNetworksCreateOrUpdatePollerResponse
|
||||
}
|
||||
|
||||
type stubVirtualNetworksCreateOrUpdatePollerResponse struct {
|
||||
armnetwork.VirtualNetworksClientCreateOrUpdatePollerResponse
|
||||
pollerErr error
|
||||
}
|
||||
|
||||
func (r stubVirtualNetworksCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
|
||||
) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error) {
|
||||
return armnetwork.VirtualNetworksClientCreateOrUpdateResponse{
|
||||
VirtualNetworksClientCreateOrUpdateResult: armnetwork.VirtualNetworksClientCreateOrUpdateResult{
|
||||
VirtualNetwork: armnetwork.VirtualNetwork{
|
||||
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
|
||||
Subnets: []*armnetwork.Subnet{
|
||||
{
|
||||
ID: to.StringPtr("virtual-network-subnet-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, r.pollerErr
|
||||
}
|
||||
|
||||
func (a stubNetworksAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
|
||||
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
|
||||
virtualNetworksCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return a.stubResponse, a.createErr
|
||||
}
|
||||
|
||||
type stubNetworkSecurityGroupsCreateOrUpdatePollerResponse struct {
|
||||
armnetwork.SecurityGroupsClientCreateOrUpdatePollerResponse
|
||||
pollerErr error
|
||||
}
|
||||
|
||||
func (r stubNetworkSecurityGroupsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
|
||||
) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error) {
|
||||
return armnetwork.SecurityGroupsClientCreateOrUpdateResponse{
|
||||
SecurityGroupsClientCreateOrUpdateResult: armnetwork.SecurityGroupsClientCreateOrUpdateResult{
|
||||
SecurityGroup: armnetwork.SecurityGroup{
|
||||
ID: to.StringPtr("network-security-group-id"),
|
||||
},
|
||||
},
|
||||
}, r.pollerErr
|
||||
}
|
||||
|
||||
type stubNetworkSecurityGroupsAPI struct {
|
||||
createErr error
|
||||
stubPoller stubNetworkSecurityGroupsCreateOrUpdatePollerResponse
|
||||
}
|
||||
|
||||
func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
|
||||
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
|
||||
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return a.stubPoller, a.createErr
|
||||
}
|
||||
|
||||
type stubResourceGroupAPI struct {
|
||||
terminateErr error
|
||||
createErr error
|
||||
getErr error
|
||||
getResourceGroup armresources.ResourceGroup
|
||||
stubResponse stubResourceGroupsDeletePollerResponse
|
||||
}
|
||||
|
||||
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
parameters armresources.ResourceGroup,
|
||||
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
|
||||
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
|
||||
) {
|
||||
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
|
||||
}
|
||||
|
||||
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
|
||||
return armresources.ResourceGroupsClientGetResponse{
|
||||
ResourceGroupsClientGetResult: armresources.ResourceGroupsClientGetResult{
|
||||
ResourceGroup: a.getResourceGroup,
|
||||
},
|
||||
}, a.getErr
|
||||
}
|
||||
|
||||
type stubResourceGroupsDeletePollerResponse struct {
|
||||
armresources.ResourceGroupsClientDeletePollerResponse
|
||||
pollerErr error
|
||||
}
|
||||
|
||||
func (r stubResourceGroupsDeletePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
|
||||
armresources.ResourceGroupsClientDeleteResponse, error,
|
||||
) {
|
||||
return armresources.ResourceGroupsClientDeleteResponse{}, r.pollerErr
|
||||
}
|
||||
|
||||
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
|
||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
||||
resourceGroupsDeletePollerResponse, error,
|
||||
) {
|
||||
return a.stubResponse, a.terminateErr
|
||||
}
|
||||
|
||||
type stubScaleSetsAPI struct {
|
||||
createErr error
|
||||
stubResponse stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse
|
||||
}
|
||||
|
||||
type stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse struct {
|
||||
pollResponse armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse
|
||||
pollErr error
|
||||
}
|
||||
|
||||
func (r stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
|
||||
armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error,
|
||||
) {
|
||||
return r.pollResponse, r.pollErr
|
||||
}
|
||||
|
||||
func (a stubScaleSetsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
|
||||
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
|
||||
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return a.stubResponse, a.createErr
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type stubPublicIPAddressesAPI struct {
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
createErr error
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
getErr error
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
stubCreateResponse stubPublicIPAddressesClientCreateOrUpdatePollerResponse
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type stubPublicIPAddressesClientCreateOrUpdatePollerResponse struct {
|
||||
armnetwork.PublicIPAddressesClientCreateOrUpdatePollerResponse
|
||||
pollErr error
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (r stubPublicIPAddressesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
|
||||
armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error,
|
||||
) {
|
||||
return armnetwork.PublicIPAddressesClientCreateOrUpdateResponse{
|
||||
PublicIPAddressesClientCreateOrUpdateResult: armnetwork.PublicIPAddressesClientCreateOrUpdateResult{
|
||||
PublicIPAddress: armnetwork.PublicIPAddress{
|
||||
ID: to.StringPtr("pubIP-id"),
|
||||
},
|
||||
},
|
||||
}, r.pollErr
|
||||
}
|
||||
|
||||
type stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager struct {
|
||||
pagesCounter int
|
||||
PagesMax int
|
||||
}
|
||||
|
||||
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) NextPage(ctx context.Context) bool {
|
||||
p.pagesCounter++
|
||||
return p.pagesCounter <= p.PagesMax
|
||||
}
|
||||
|
||||
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse {
|
||||
return armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse{
|
||||
PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult: armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult{
|
||||
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
|
||||
Value: []*armnetwork.PublicIPAddress{
|
||||
{
|
||||
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
|
||||
IPAddress: to.StringPtr("192.0.2.1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a stubPublicIPAddressesAPI) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
|
||||
virtualMachineScaleSetName string, virtualmachineIndex string,
|
||||
networkInterfaceName string, ipConfigurationName string,
|
||||
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
|
||||
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
|
||||
return &stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager{pagesCounter: 0, PagesMax: 1}
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (a stubPublicIPAddressesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
|
||||
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
|
||||
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return a.stubCreateResponse, a.createErr
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (a stubPublicIPAddressesAPI) Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
|
||||
armnetwork.PublicIPAddressesClientGetResponse, error,
|
||||
) {
|
||||
return armnetwork.PublicIPAddressesClientGetResponse{
|
||||
PublicIPAddressesClientGetResult: armnetwork.PublicIPAddressesClientGetResult{
|
||||
PublicIPAddress: armnetwork.PublicIPAddress{
|
||||
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
|
||||
IPAddress: to.StringPtr("192.0.2.1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, a.getErr
|
||||
}
|
||||
|
||||
type stubNetworkInterfacesAPI struct {
|
||||
getErr error
|
||||
// TODO: deprecate as soon as scale sets are available
|
||||
createErr error
|
||||
// TODO: deprecate as soon as scale sets are available
|
||||
stubResp stubInterfacesClientCreateOrUpdatePollerResponse
|
||||
}
|
||||
|
||||
func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
|
||||
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
|
||||
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
|
||||
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) {
|
||||
if a.getErr != nil {
|
||||
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{}, a.getErr
|
||||
}
|
||||
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{
|
||||
InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult: armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult{
|
||||
Interface: armnetwork.Interface{
|
||||
Properties: &armnetwork.InterfacePropertiesFormat{
|
||||
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
|
||||
{
|
||||
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
|
||||
PrivateIPAddress: to.StringPtr("192.0.2.1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type stubInterfacesClientCreateOrUpdatePollerResponse struct {
|
||||
pollErr error
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (r stubInterfacesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
|
||||
armnetwork.InterfacesClientCreateOrUpdateResponse, error,
|
||||
) {
|
||||
return armnetwork.InterfacesClientCreateOrUpdateResponse{
|
||||
InterfacesClientCreateOrUpdateResult: armnetwork.InterfacesClientCreateOrUpdateResult{
|
||||
Interface: armnetwork.Interface{
|
||||
ID: to.StringPtr("interface-id"),
|
||||
Properties: &armnetwork.InterfacePropertiesFormat{
|
||||
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
|
||||
{
|
||||
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
|
||||
PrivateIPAddress: to.StringPtr("192.0.2.1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, r.pollErr
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (a stubNetworkInterfacesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
|
||||
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
|
||||
interfacesClientCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return a.stubResp, a.createErr
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type stubVirtualMachinesAPI struct {
|
||||
stubResponse stubVirtualMachinesClientCreateOrUpdatePollerResponse
|
||||
createErr error
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type stubVirtualMachinesClientCreateOrUpdatePollerResponse struct {
|
||||
pollResponse armcompute.VirtualMachinesClientCreateOrUpdateResponse
|
||||
pollErr error
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (r stubVirtualMachinesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
|
||||
armcompute.VirtualMachinesClientCreateOrUpdateResponse, error,
|
||||
) {
|
||||
return r.pollResponse, r.pollErr
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (a stubVirtualMachinesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
|
||||
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
|
||||
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
|
||||
return a.stubResponse, a.createErr
|
||||
}
|
||||
|
||||
type stubApplicationsAPI struct {
|
||||
createErr error
|
||||
deleteErr error
|
||||
updateCredentialsErr error
|
||||
createApplication *graphrbac.Application
|
||||
}
|
||||
|
||||
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
|
||||
if a.createErr != nil {
|
||||
return graphrbac.Application{}, a.createErr
|
||||
}
|
||||
if a.createApplication != nil {
|
||||
return *a.createApplication, nil
|
||||
}
|
||||
return graphrbac.Application{
|
||||
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
|
||||
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000001"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
|
||||
if a.deleteErr != nil {
|
||||
return autorest.Response{}, a.deleteErr
|
||||
}
|
||||
return autorest.Response{}, nil
|
||||
}
|
||||
|
||||
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
|
||||
if a.updateCredentialsErr != nil {
|
||||
return autorest.Response{}, a.updateCredentialsErr
|
||||
}
|
||||
return autorest.Response{}, nil
|
||||
}
|
||||
|
||||
type stubServicePrincipalsAPI struct {
|
||||
createErr error
|
||||
createServicePrincipal *graphrbac.ServicePrincipal
|
||||
}
|
||||
|
||||
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
|
||||
if a.createErr != nil {
|
||||
return graphrbac.ServicePrincipal{}, a.createErr
|
||||
}
|
||||
if a.createServicePrincipal != nil {
|
||||
return *a.createServicePrincipal, nil
|
||||
}
|
||||
return graphrbac.ServicePrincipal{
|
||||
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
|
||||
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000002"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type stubRoleAssignmentsAPI struct {
|
||||
createCounter int
|
||||
createErrors []error
|
||||
}
|
||||
|
||||
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
|
||||
a.createCounter += 1
|
||||
if len(a.createErrors) == 0 {
|
||||
return authorization.RoleAssignment{}, nil
|
||||
}
|
||||
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
|
||||
}
|
136
cli/azure/client/azurewrappers.go
Normal file
136
cli/azure/client/azurewrappers.go
Normal file
@ -0,0 +1,136 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
)
|
||||
|
||||
type networksClient struct {
|
||||
*armnetwork.VirtualNetworksClient
|
||||
}
|
||||
|
||||
func (c *networksClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
|
||||
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
|
||||
virtualNetworksCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return c.VirtualNetworksClient.BeginCreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, parameters, options)
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type networkInterfacesClient struct {
|
||||
*armnetwork.InterfacesClient
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *networkInterfacesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
|
||||
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions,
|
||||
) (interfacesClientCreateOrUpdatePollerResponse, error) {
|
||||
return c.InterfacesClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkInterfaceName, parameters, options)
|
||||
}
|
||||
|
||||
type networkSecurityGroupsClient struct {
|
||||
*armnetwork.SecurityGroupsClient
|
||||
}
|
||||
|
||||
func (c *networkSecurityGroupsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
|
||||
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
|
||||
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return c.SecurityGroupsClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkSecurityGroupName, parameters, options)
|
||||
}
|
||||
|
||||
type publicIPAddressesClient struct {
|
||||
*armnetwork.PublicIPAddressesClient
|
||||
}
|
||||
|
||||
func (c *publicIPAddressesClient) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
|
||||
virtualMachineScaleSetName string, virtualmachineIndex string,
|
||||
networkInterfaceName string, ipConfigurationName string,
|
||||
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
|
||||
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
|
||||
return c.PublicIPAddressesClient.ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName, virtualMachineScaleSetName,
|
||||
virtualmachineIndex, networkInterfaceName, ipConfigurationName, options)
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *publicIPAddressesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
|
||||
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
|
||||
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return c.PublicIPAddressesClient.BeginCreateOrUpdate(ctx, resourceGroupName, publicIPAddressName, parameters, options)
|
||||
}
|
||||
|
||||
type virtualMachineScaleSetsClient struct {
|
||||
*armcompute.VirtualMachineScaleSetsClient
|
||||
}
|
||||
|
||||
func (c *virtualMachineScaleSetsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
|
||||
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
|
||||
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
|
||||
) {
|
||||
return c.VirtualMachineScaleSetsClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmScaleSetName, parameters, options)
|
||||
}
|
||||
|
||||
type resourceGroupsClient struct {
|
||||
*armresources.ResourceGroupsClient
|
||||
}
|
||||
|
||||
func (c *resourceGroupsClient) BeginDelete(ctx context.Context, resourceGroupName string,
|
||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
||||
resourceGroupsDeletePollerResponse, error,
|
||||
) {
|
||||
return c.ResourceGroupsClient.BeginDelete(ctx, resourceGroupName, options)
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type virtualMachinesClient struct {
|
||||
*armcompute.VirtualMachinesClient
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *virtualMachinesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
|
||||
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
|
||||
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
|
||||
return c.VirtualMachinesClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmName, parameters, options)
|
||||
}
|
||||
|
||||
type applicationsClient struct {
|
||||
*graphrbac.ApplicationsClient
|
||||
}
|
||||
|
||||
func (c *applicationsClient) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
|
||||
return c.ApplicationsClient.Create(ctx, parameters)
|
||||
}
|
||||
|
||||
func (c *applicationsClient) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
|
||||
return c.ApplicationsClient.Delete(ctx, applicationObjectID)
|
||||
}
|
||||
|
||||
func (c *applicationsClient) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
|
||||
return c.ApplicationsClient.UpdatePasswordCredentials(ctx, objectID, parameters)
|
||||
}
|
||||
|
||||
type servicePrincipalsClient struct {
|
||||
*graphrbac.ServicePrincipalsClient
|
||||
}
|
||||
|
||||
func (c *servicePrincipalsClient) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
|
||||
return c.ServicePrincipalsClient.Create(ctx, parameters)
|
||||
}
|
||||
|
||||
type roleAssignmentsClient struct {
|
||||
*authorization.RoleAssignmentsClient
|
||||
}
|
||||
|
||||
func (c *roleAssignmentsClient) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
|
||||
return c.RoleAssignmentsClient.Create(ctx, scope, roleAssignmentName, parameters)
|
||||
}
|
277
cli/azure/client/client.go
Normal file
277
cli/azure/client/client.go
Normal file
@ -0,0 +1,277 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
const (
|
||||
graphAPIResource = "https://graph.windows.net"
|
||||
managementAPIResource = "https://management.azure.com"
|
||||
)
|
||||
|
||||
// Client is a client for Azure.
|
||||
type Client struct {
|
||||
networksAPI
|
||||
networkSecurityGroupsAPI
|
||||
resourceGroupAPI
|
||||
scaleSetsAPI
|
||||
publicIPAddressesAPI
|
||||
networkInterfacesAPI
|
||||
virtualMachinesAPI
|
||||
applicationsAPI
|
||||
servicePrincipalsAPI
|
||||
roleAssignmentsAPI
|
||||
|
||||
adReplicationLagCheckInterval time.Duration
|
||||
adReplicationLagCheckMaxRetries int
|
||||
|
||||
nodes azure.Instances
|
||||
coordinators azure.Instances
|
||||
|
||||
name string
|
||||
uid string
|
||||
resourceGroup string
|
||||
location string
|
||||
subscriptionID string
|
||||
tenantID string
|
||||
subnetID string
|
||||
coordinatorsScaleSet string
|
||||
nodesScaleSet string
|
||||
networkSecurityGroup string
|
||||
adAppObjectID string
|
||||
}
|
||||
|
||||
// NewFromDefault creates a client with initialized clients.
|
||||
func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||
cred, err := azidentity.NewDefaultAzureCredential(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
graphAuthorizer, err := getAuthorizer(graphAPIResource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
managementAuthorizer, err := getAuthorizer(managementAPIResource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
netAPI := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
|
||||
netSecGrpAPI := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
|
||||
resGroupAPI := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
|
||||
scaleSetAPI := armcompute.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
|
||||
publicIPAddressesAPI := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
|
||||
networkInterfacesAPI := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
|
||||
virtualMachinesAPI := armcompute.NewVirtualMachinesClient(subscriptionID, cred, nil)
|
||||
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
|
||||
applicationsAPI.Authorizer = graphAuthorizer
|
||||
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
|
||||
servicePrincipalsAPI.Authorizer = graphAuthorizer
|
||||
roleAssignmentsAPI := authorization.NewRoleAssignmentsClient(subscriptionID)
|
||||
roleAssignmentsAPI.Authorizer = managementAuthorizer
|
||||
|
||||
return &Client{
|
||||
networksAPI: &networksClient{netAPI},
|
||||
networkSecurityGroupsAPI: &networkSecurityGroupsClient{netSecGrpAPI},
|
||||
resourceGroupAPI: &resourceGroupsClient{resGroupAPI},
|
||||
scaleSetsAPI: &virtualMachineScaleSetsClient{scaleSetAPI},
|
||||
publicIPAddressesAPI: &publicIPAddressesClient{publicIPAddressesAPI},
|
||||
networkInterfacesAPI: &networkInterfacesClient{networkInterfacesAPI},
|
||||
applicationsAPI: &applicationsClient{&applicationsAPI},
|
||||
servicePrincipalsAPI: &servicePrincipalsClient{&servicePrincipalsAPI},
|
||||
roleAssignmentsAPI: &roleAssignmentsClient{&roleAssignmentsAPI},
|
||||
virtualMachinesAPI: &virtualMachinesClient{virtualMachinesAPI},
|
||||
subscriptionID: subscriptionID,
|
||||
tenantID: tenantID,
|
||||
nodes: azure.Instances{},
|
||||
coordinators: azure.Instances{},
|
||||
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
|
||||
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
|
||||
// of the Constellation.
|
||||
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
|
||||
client, err := NewFromDefault(subscriptionID, tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = client.init(location, name)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// init initializes the client.
|
||||
func (c *Client) init(location, name string) error {
|
||||
c.location = location
|
||||
c.name = name
|
||||
uid, err := c.generateUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.uid = uid
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetState returns the state of the client as ConstellationState.
|
||||
func (c *Client) GetState() (state.ConstellationState, error) {
|
||||
var stat state.ConstellationState
|
||||
stat.CloudProvider = cloudprovider.Azure.String()
|
||||
if len(c.resourceGroup) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no resource group")
|
||||
}
|
||||
stat.AzureResourceGroup = c.resourceGroup
|
||||
if c.name == "" {
|
||||
return state.ConstellationState{}, errors.New("client has no name")
|
||||
}
|
||||
stat.Name = c.name
|
||||
if len(c.uid) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no uid")
|
||||
}
|
||||
stat.UID = c.uid
|
||||
if len(c.location) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no location")
|
||||
}
|
||||
stat.AzureLocation = c.location
|
||||
if len(c.subscriptionID) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no subscription")
|
||||
}
|
||||
stat.AzureSubscription = c.subscriptionID
|
||||
if len(c.tenantID) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no tenant")
|
||||
}
|
||||
stat.AzureTenant = c.tenantID
|
||||
if len(c.subnetID) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no subnet")
|
||||
}
|
||||
stat.AzureSubnet = c.subnetID
|
||||
if len(c.networkSecurityGroup) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no network security group")
|
||||
}
|
||||
stat.AzureNetworkSecurityGroup = c.networkSecurityGroup
|
||||
// TODO: un-deprecate as soon as scale sets are available
|
||||
// if len(c.nodesScaleSet) == 0 {
|
||||
// return state.ConstellationState{}, errors.New("client has no nodes scale set")
|
||||
// }
|
||||
// stat.AzureNodesScaleSet = c.nodesScaleSet
|
||||
// if len(c.coordinatorsScaleSet) == 0 {
|
||||
// return state.ConstellationState{}, errors.New("client has no coordinators scale set")
|
||||
// }
|
||||
// stat.AzureCoordinatorsScaleSet = c.coordinatorsScaleSet
|
||||
if len(c.nodes) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no nodes")
|
||||
}
|
||||
stat.AzureNodes = c.nodes
|
||||
if len(c.coordinators) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no coordinators")
|
||||
}
|
||||
stat.AzureCoordinators = c.coordinators
|
||||
// AD App Object ID does not have to be set at all times
|
||||
stat.AzureADAppObjectID = c.adAppObjectID
|
||||
|
||||
return stat, nil
|
||||
}
|
||||
|
||||
// SetState sets the state of the client to the handed ConstellationState.
|
||||
func (c *Client) SetState(stat state.ConstellationState) error {
|
||||
if stat.CloudProvider != cloudprovider.Azure.String() {
|
||||
return errors.New("state is not azure state")
|
||||
}
|
||||
if len(stat.AzureResourceGroup) == 0 {
|
||||
return errors.New("state has no resource group")
|
||||
}
|
||||
c.resourceGroup = stat.AzureResourceGroup
|
||||
if stat.Name == "" {
|
||||
return errors.New("state has no name")
|
||||
}
|
||||
c.name = stat.Name
|
||||
if len(stat.UID) == 0 {
|
||||
return errors.New("state has no uuid")
|
||||
}
|
||||
c.uid = stat.UID
|
||||
if len(stat.AzureLocation) == 0 {
|
||||
return errors.New("state has no location")
|
||||
}
|
||||
c.location = stat.AzureLocation
|
||||
if len(stat.AzureSubscription) == 0 {
|
||||
return errors.New("state has no subscription")
|
||||
}
|
||||
c.subscriptionID = stat.AzureSubscription
|
||||
if len(stat.AzureTenant) == 0 {
|
||||
return errors.New("state has no tenant")
|
||||
}
|
||||
c.tenantID = stat.AzureTenant
|
||||
if len(stat.AzureSubnet) == 0 {
|
||||
return errors.New("state has no subnet")
|
||||
}
|
||||
c.subnetID = stat.AzureSubnet
|
||||
if len(stat.AzureNetworkSecurityGroup) == 0 {
|
||||
return errors.New("state has no subnet")
|
||||
}
|
||||
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
|
||||
// TODO: un-deprecate as soon as scale sets are available
|
||||
//if len(stat.AzureNodesScaleSet) == 0 {
|
||||
// return errors.New("state has no nodes scale set")
|
||||
//}
|
||||
//c.nodesScaleSet = stat.AzureNodesScaleSet
|
||||
//if len(stat.AzureCoordinatorsScaleSet) == 0 {
|
||||
// return errors.New("state has no nodes scale set")
|
||||
//}
|
||||
//c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
|
||||
if len(stat.AzureNodes) == 0 {
|
||||
return errors.New("state has no coordinator scale set")
|
||||
}
|
||||
c.nodes = stat.AzureNodes
|
||||
if len(stat.AzureCoordinators) == 0 {
|
||||
return errors.New("state has no coordinators")
|
||||
}
|
||||
c.coordinators = stat.AzureCoordinators
|
||||
// AD App Object ID does not have to be set at all times
|
||||
c.adAppObjectID = stat.AzureADAppObjectID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) generateUID() (string, error) {
|
||||
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
const uidLen = 5
|
||||
uid := make([]byte, uidLen)
|
||||
for i := 0; i < uidLen; i++ {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
uid[i] = letters[n.Int64()]
|
||||
}
|
||||
return string(uid), nil
|
||||
}
|
||||
|
||||
// getAuthorizer creates an autorest.Authorizer for different Azure AD APIs using either environment variables or azure cli credentials.
|
||||
func getAuthorizer(resource string) (autorest.Authorizer, error) {
|
||||
authorizer, cliErr := auth.NewAuthorizerFromCLIWithResource(resource)
|
||||
if cliErr == nil {
|
||||
return authorizer, nil
|
||||
}
|
||||
authorizer, envErr := auth.NewAuthorizerFromEnvironmentWithResource(resource)
|
||||
if envErr == nil {
|
||||
return authorizer, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unable to create authorizer from env or cli: %v %v", envErr, cliErr)
|
||||
}
|
486
cli/azure/client/client_test.go
Normal file
486
cli/azure/client/client_test.go
Normal file
@ -0,0 +1,486 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetGetState(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
state state.ConstellationState
|
||||
errExpected bool
|
||||
}{
|
||||
"valid state": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
// TODO: un-deprecate as soon as scale sets are available
|
||||
// AzureNodesScaleSet: "node-scale-set",
|
||||
// AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
},
|
||||
"missing nodes": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing coordinator": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing name": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing uid": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing resource group": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing location": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing subscription": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureTenant: "tenant",
|
||||
AzureLocation: "location",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing tenant": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureSubscription: "subscription",
|
||||
AzureLocation: "location",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing subnet": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"missing network security group": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureTenant: "tenant",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
// TODO: un-deprecate as soon as scale sets are available
|
||||
// "missing node scale set": {
|
||||
// state: state.ConstellationState{
|
||||
// CloudProvider: cloudprovider.Azure.String(),
|
||||
// AzureNodes: azure.Instances{
|
||||
// "0": {
|
||||
// PublicIP: "ip1",
|
||||
// PrivateIP: "ip2",
|
||||
// },
|
||||
// },
|
||||
// AzureCoordinators: azure.Instances{
|
||||
// "0": {
|
||||
// PublicIP: "ip3",
|
||||
// PrivateIP: "ip4",
|
||||
// },
|
||||
// },
|
||||
// Name: "name",
|
||||
// UID: "uid",
|
||||
// AzureResourceGroup: "resource-group",
|
||||
// AzureLocation: "location",
|
||||
// AzureSubscription: "subscription",
|
||||
// AzureTenant: "tenant",
|
||||
// AzureSubnet: "azure-subnet",
|
||||
// AzureNetworkSecurityGroup: "network-security-group",
|
||||
// AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
// },
|
||||
// errExpected: true,
|
||||
// },
|
||||
// "missing coordinator scale set": {
|
||||
// state: state.ConstellationState{
|
||||
// CloudProvider: cloudprovider.Azure.String(),
|
||||
// AzureNodes: azure.Instances{
|
||||
// "0": {
|
||||
// PublicIP: "ip1",
|
||||
// PrivateIP: "ip2",
|
||||
// },
|
||||
// },
|
||||
// AzureCoordinators: azure.Instances{
|
||||
// "0": {
|
||||
// PublicIP: "ip3",
|
||||
// PrivateIP: "ip4",
|
||||
// },
|
||||
// },
|
||||
// Name: "name",
|
||||
// UID: "uid",
|
||||
// AzureResourceGroup: "resource-group",
|
||||
// AzureLocation: "location",
|
||||
// AzureSubscription: "subscription",
|
||||
// AzureTenant: "tenant",
|
||||
// AzureSubnet: "azure-subnet",
|
||||
// AzureNetworkSecurityGroup: "network-security-group",
|
||||
// AzureNodesScaleSet: "node-scale-set",
|
||||
// },
|
||||
// errExpected: true,
|
||||
// },
|
||||
}
|
||||
|
||||
t.Run("SetState", func(t *testing.T) {
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := Client{}
|
||||
if tc.errExpected {
|
||||
assert.Error(client.SetState(tc.state))
|
||||
} else {
|
||||
assert.NoError(client.SetState(tc.state))
|
||||
assert.Equal(tc.state.AzureNodes, client.nodes)
|
||||
assert.Equal(tc.state.AzureCoordinators, client.coordinators)
|
||||
assert.Equal(tc.state.Name, client.name)
|
||||
assert.Equal(tc.state.UID, client.uid)
|
||||
assert.Equal(tc.state.AzureResourceGroup, client.resourceGroup)
|
||||
assert.Equal(tc.state.AzureLocation, client.location)
|
||||
assert.Equal(tc.state.AzureSubscription, client.subscriptionID)
|
||||
assert.Equal(tc.state.AzureTenant, client.tenantID)
|
||||
assert.Equal(tc.state.AzureSubnet, client.subnetID)
|
||||
assert.Equal(tc.state.AzureNetworkSecurityGroup, client.networkSecurityGroup)
|
||||
assert.Equal(tc.state.AzureNodesScaleSet, client.nodesScaleSet)
|
||||
assert.Equal(tc.state.AzureCoordinatorsScaleSet, client.coordinatorsScaleSet)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetState", func(t *testing.T) {
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := Client{
|
||||
nodes: tc.state.AzureNodes,
|
||||
coordinators: tc.state.AzureCoordinators,
|
||||
name: tc.state.Name,
|
||||
uid: tc.state.UID,
|
||||
resourceGroup: tc.state.AzureResourceGroup,
|
||||
location: tc.state.AzureLocation,
|
||||
subscriptionID: tc.state.AzureSubscription,
|
||||
tenantID: tc.state.AzureTenant,
|
||||
subnetID: tc.state.AzureSubnet,
|
||||
networkSecurityGroup: tc.state.AzureNetworkSecurityGroup,
|
||||
nodesScaleSet: tc.state.AzureNodesScaleSet,
|
||||
coordinatorsScaleSet: tc.state.AzureCoordinatorsScaleSet,
|
||||
}
|
||||
if tc.errExpected {
|
||||
_, err := client.GetState()
|
||||
assert.Error(err)
|
||||
} else {
|
||||
state, err := client.GetState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.state, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetStateCloudProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := Client{}
|
||||
stateMissingCloudProvider := state.ConstellationState{
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
}
|
||||
assert.Error(client.SetState(stateMissingCloudProvider))
|
||||
stateIncorrectCloudProvider := state.ConstellationState{
|
||||
CloudProvider: "incorrect",
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip1",
|
||||
PrivateIP: "ip2",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "ip3",
|
||||
PrivateIP: "ip4",
|
||||
},
|
||||
},
|
||||
Name: "name",
|
||||
UID: "uid",
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureLocation: "location",
|
||||
AzureSubscription: "subscription",
|
||||
AzureSubnet: "azure-subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "node-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinator-scale-set",
|
||||
}
|
||||
assert.Error(client.SetState(stateIncorrectCloudProvider))
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
client := Client{}
|
||||
require.NoError(client.init("location", "name"))
|
||||
assert.Equal("location", client.location)
|
||||
assert.Equal("name", client.name)
|
||||
assert.NotEmpty(client.uid)
|
||||
}
|
281
cli/azure/client/compute.go
Normal file
281
cli/azure/client/compute.go
Normal file
@ -0,0 +1,281 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
)
|
||||
|
||||
func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error {
|
||||
// Create nodes scale set
|
||||
createNodesInput := CreateScaleSetInput{
|
||||
Name: "constellation-scale-set-nodes-" + c.uid,
|
||||
NamePrefix: c.name + "-worker-" + c.uid + "-",
|
||||
Count: input.Count - 1,
|
||||
InstanceType: input.InstanceType,
|
||||
Image: input.Image,
|
||||
UserAssingedIdentity: input.UserAssingedIdentity,
|
||||
}
|
||||
|
||||
if err := c.createScaleSet(ctx, createNodesInput); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.nodesScaleSet = createNodesInput.Name
|
||||
|
||||
// Create coordinator scale set
|
||||
createCoordinatorsInput := CreateScaleSetInput{
|
||||
Name: "constellation-scale-set-coordinators-" + c.uid,
|
||||
NamePrefix: c.name + "-control-plane-" + c.uid + "-",
|
||||
Count: 1,
|
||||
InstanceType: input.InstanceType,
|
||||
Image: input.Image,
|
||||
UserAssingedIdentity: input.UserAssingedIdentity,
|
||||
}
|
||||
|
||||
if err := c.createScaleSet(ctx, createCoordinatorsInput); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get nodes IPs
|
||||
instances, err := c.getInstanceIPs(ctx, createNodesInput.Name, createNodesInput.Count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.nodes = instances
|
||||
|
||||
// Get coordinators IPs
|
||||
c.coordinatorsScaleSet = createCoordinatorsInput.Name
|
||||
instances, err = c.getInstanceIPs(ctx, createCoordinatorsInput.Name, createCoordinatorsInput.Count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.coordinators = instances
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateInstancesInput is the input for a CreateInstances operation.
|
||||
type CreateInstancesInput struct {
|
||||
Count int
|
||||
InstanceType string
|
||||
Image string
|
||||
UserAssingedIdentity string
|
||||
}
|
||||
|
||||
// CreateInstancesVMs creates instances based on standalone VMs.
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *Client) CreateInstancesVMs(ctx context.Context, input CreateInstancesInput) error {
|
||||
pw, err := azure.GeneratePassword()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vm := azure.VMInstance{
|
||||
Name: c.name + "-control-plane-" + c.uid,
|
||||
Username: "constell",
|
||||
Password: pw,
|
||||
Location: c.location,
|
||||
InstanceType: input.InstanceType,
|
||||
Image: input.Image,
|
||||
}
|
||||
instance, err := c.createInstanceVM(ctx, vm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.coordinators = azure.Instances{"0": instance}
|
||||
|
||||
for i := 0; i < input.Count-1; i++ {
|
||||
vm := azure.VMInstance{
|
||||
Name: c.name + "-node-" + strconv.Itoa(i) + c.uid,
|
||||
Username: "constell",
|
||||
Password: pw,
|
||||
Location: c.location,
|
||||
InstanceType: input.InstanceType,
|
||||
Image: input.Image,
|
||||
}
|
||||
instance, err := c.createInstanceVM(ctx, vm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.nodes[strconv.Itoa(i)] = instance
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createInstanceVM creates a single VM with a public IP address
|
||||
// and a network interface.
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *Client) createInstanceVM(ctx context.Context, input azure.VMInstance) (azure.Instance, error) {
|
||||
pubIPName := input.Name + "-pubIP"
|
||||
pubIPID, err := c.createPublicIPAddress(ctx, pubIPName)
|
||||
if err != nil {
|
||||
return azure.Instance{}, err
|
||||
}
|
||||
|
||||
nicName := input.Name + "-NIC"
|
||||
privIP, nicID, err := c.createNIC(ctx, nicName, pubIPID)
|
||||
if err != nil {
|
||||
return azure.Instance{}, err
|
||||
}
|
||||
|
||||
input.NIC = nicID
|
||||
|
||||
poller, err := c.virtualMachinesAPI.BeginCreateOrUpdate(ctx, c.resourceGroup, input.Name, input.Azure(), nil)
|
||||
if err != nil {
|
||||
return azure.Instance{}, err
|
||||
}
|
||||
|
||||
vm, err := poller.PollUntilDone(ctx, 30*time.Second)
|
||||
if err != nil {
|
||||
return azure.Instance{}, err
|
||||
}
|
||||
|
||||
if vm.Identity == nil || vm.Identity.PrincipalID == nil {
|
||||
return azure.Instance{}, errors.New("virtual machine was created without system managed identity")
|
||||
}
|
||||
|
||||
if err := c.assignResourceGroupRole(ctx, *vm.Identity.PrincipalID, virtualMachineContributorRoleDefinitionID); err != nil {
|
||||
return azure.Instance{}, err
|
||||
}
|
||||
|
||||
res, err := c.publicIPAddressesAPI.Get(ctx, c.resourceGroup, pubIPName, nil)
|
||||
if err != nil {
|
||||
return azure.Instance{}, err
|
||||
}
|
||||
|
||||
return azure.Instance{PublicIP: *res.PublicIPAddressesClientGetResult.PublicIPAddress.Properties.IPAddress, PrivateIP: privIP}, nil
|
||||
}
|
||||
|
||||
func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) error {
|
||||
// TODO: Generating a random password to be able
|
||||
// to create the scale set. This is a temporary fix.
|
||||
// We need to think about azure access at some point.
|
||||
pw, err := azure.GeneratePassword()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
scaleSet := azure.ScaleSet{
|
||||
Name: input.Name,
|
||||
NamePrefix: input.NamePrefix,
|
||||
Location: c.location,
|
||||
InstanceType: input.InstanceType,
|
||||
Count: int64(input.Count),
|
||||
Username: "constellation",
|
||||
SubnetID: c.subnetID,
|
||||
NetworkSecurityGroup: c.networkSecurityGroup,
|
||||
Image: input.Image,
|
||||
Password: pw,
|
||||
UserAssignedIdentity: input.UserAssingedIdentity,
|
||||
}.Azure()
|
||||
|
||||
poller, err := c.scaleSetsAPI.BeginCreateOrUpdate(
|
||||
ctx, c.resourceGroup, input.Name,
|
||||
scaleSet,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = poller.PollUntilDone(ctx, 30*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getInstanceIPs(ctx context.Context, scaleSet string, count int) (azure.Instances, error) {
|
||||
instances := azure.Instances{}
|
||||
for i := 0; i < count; i++ {
|
||||
// get public ip address
|
||||
var publicIPAddress string
|
||||
pager := c.publicIPAddressesAPI.ListVirtualMachineScaleSetVMPublicIPAddresses(
|
||||
c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, scaleSet, nil)
|
||||
|
||||
// We always need one pager.NextPage, since calling
|
||||
// pager.PageResponse() directly return no result.
|
||||
// We expect to get one page with one entry for each VM.
|
||||
for pager.NextPage(ctx) {
|
||||
for _, v := range pager.PageResponse().Value {
|
||||
if v.Properties != nil && v.Properties.IPAddress != nil {
|
||||
publicIPAddress = *v.Properties.IPAddress
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get private ip address
|
||||
var privateIPAddress string
|
||||
res, err := c.networkInterfacesAPI.GetVirtualMachineScaleSetNetworkInterface(
|
||||
ctx, c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configs := res.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult.Interface.Properties.IPConfigurations
|
||||
for _, config := range configs {
|
||||
privateIPAddress = *config.Properties.PrivateIPAddress
|
||||
break
|
||||
}
|
||||
|
||||
instance := azure.Instance{
|
||||
PrivateIP: privateIPAddress,
|
||||
PublicIP: publicIPAddress,
|
||||
}
|
||||
instances[strconv.Itoa(i)] = instance
|
||||
}
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// CreateScaleSetInput is the input for a CreateScaleSet operation.
|
||||
type CreateScaleSetInput struct {
|
||||
Name string
|
||||
NamePrefix string
|
||||
Count int
|
||||
InstanceType string
|
||||
Image string
|
||||
UserAssingedIdentity string
|
||||
}
|
||||
|
||||
// CreateResourceGroup creates a resource group.
|
||||
func (c *Client) CreateResourceGroup(ctx context.Context) error {
|
||||
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
|
||||
armresources.ResourceGroup{
|
||||
Location: &c.location,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.resourceGroup = c.name + "-" + c.uid
|
||||
return nil
|
||||
}
|
||||
|
||||
// TerminateResourceGroup terminates a resource group.
|
||||
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
|
||||
if c.resourceGroup == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = poller.PollUntilDone(ctx, 30*time.Second); err != nil {
|
||||
return err
|
||||
}
|
||||
c.nodes = nil
|
||||
c.coordinators = nil
|
||||
c.resourceGroup = ""
|
||||
c.subnetID = ""
|
||||
c.networkSecurityGroup = ""
|
||||
c.nodesScaleSet = ""
|
||||
c.coordinatorsScaleSet = ""
|
||||
return nil
|
||||
}
|
374
cli/azure/client/compute_test.go
Normal file
374
cli/azure/client/compute_test.go
Normal file
@ -0,0 +1,374 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateResourceGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{},
|
||||
},
|
||||
"failed create": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
}
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(client.CreateResourceGroup(ctx))
|
||||
} else {
|
||||
assert.NoError(client.CreateResourceGroup(ctx))
|
||||
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateResourceGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
clientWithResourceGroup := Client{
|
||||
resourceGroup: "name",
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
subnetID: "subnet",
|
||||
nodesScaleSet: "node-scale-set",
|
||||
coordinatorsScaleSet: "coordinator-scale-set",
|
||||
nodes: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
coordinators: azure.Instances{
|
||||
"0": {
|
||||
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
testCases := map[string]struct {
|
||||
resourceGroup string
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
client Client
|
||||
errExpected bool
|
||||
}{
|
||||
"successful terminate": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{},
|
||||
client: clientWithResourceGroup,
|
||||
},
|
||||
"no resource group to terminate": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{},
|
||||
client: Client{},
|
||||
resourceGroup: "",
|
||||
},
|
||||
"failed terminate": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
|
||||
client: clientWithResourceGroup,
|
||||
errExpected: true,
|
||||
},
|
||||
"failed to poll terminate response": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{stubResponse: stubResourceGroupsDeletePollerResponse{pollerErr: someErr}},
|
||||
client: clientWithResourceGroup,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
tc.client.resourceGroupAPI = tc.resourceGroupAPI
|
||||
ctx := context.Background()
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(tc.client.TerminateResourceGroup(ctx))
|
||||
return
|
||||
}
|
||||
assert.NoError(tc.client.TerminateResourceGroup(ctx))
|
||||
assert.Empty(tc.client.resourceGroup)
|
||||
assert.Empty(tc.client.subnetID)
|
||||
assert.Empty(tc.client.nodes)
|
||||
assert.Empty(tc.client.coordinators)
|
||||
assert.Empty(tc.client.nodesScaleSet)
|
||||
assert.Empty(tc.client.coordinatorsScaleSet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstances(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
publicIPAddressesAPI publicIPAddressesAPI
|
||||
networkInterfacesAPI networkInterfacesAPI
|
||||
scaleSetsAPI scaleSetsAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
createInstancesInput CreateInstancesInput
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
scaleSetsAPI: stubScaleSetsAPI{
|
||||
stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{
|
||||
pollResponse: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse{
|
||||
VirtualMachineScaleSetsClientCreateOrUpdateResult: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResult{
|
||||
VirtualMachineScaleSet: armcompute.VirtualMachineScaleSet{Identity: &armcompute.VirtualMachineScaleSetIdentity{PrincipalID: to.StringPtr("principal-id")}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
UserAssingedIdentity: "identity",
|
||||
},
|
||||
},
|
||||
"error when creating scale set": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
UserAssingedIdentity: "identity",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"error when polling create scale set response": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
scaleSetsAPI: stubScaleSetsAPI{stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{pollErr: someErr}},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
UserAssingedIdentity: "identity",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"error when retrieving private IPs": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
|
||||
scaleSetsAPI: stubScaleSetsAPI{},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
UserAssingedIdentity: "identity",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "name",
|
||||
publicIPAddressesAPI: tc.publicIPAddressesAPI,
|
||||
networkInterfacesAPI: tc.networkInterfacesAPI,
|
||||
scaleSetsAPI: tc.scaleSetsAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
}
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(client.CreateInstances(ctx, tc.createInstancesInput))
|
||||
} else {
|
||||
assert.NoError(client.CreateInstances(ctx, tc.createInstancesInput))
|
||||
assert.Equal(1, len(client.coordinators))
|
||||
assert.Equal(tc.createInstancesInput.Count-1, len(client.nodes))
|
||||
assert.NotEmpty(client.nodes["0"].PrivateIP)
|
||||
assert.NotEmpty(client.nodes["0"].PublicIP)
|
||||
assert.NotEmpty(client.coordinators["0"].PrivateIP)
|
||||
assert.NotEmpty(client.coordinators["0"].PublicIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func TestCreateInstancesVMs(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
publicIPAddressesAPI publicIPAddressesAPI
|
||||
networkInterfacesAPI networkInterfacesAPI
|
||||
virtualMachinesAPI virtualMachinesAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
createInstancesInput CreateInstancesInput
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
virtualMachinesAPI: stubVirtualMachinesAPI{
|
||||
stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{
|
||||
pollResponse: armcompute.VirtualMachinesClientCreateOrUpdateResponse{VirtualMachinesClientCreateOrUpdateResult: armcompute.VirtualMachinesClientCreateOrUpdateResult{
|
||||
VirtualMachine: armcompute.VirtualMachine{
|
||||
Identity: &armcompute.VirtualMachineIdentity{PrincipalID: to.StringPtr("principal-id")},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
},
|
||||
},
|
||||
"error when creating scale set": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
virtualMachinesAPI: stubVirtualMachinesAPI{createErr: someErr},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"error when polling create scale set response": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
virtualMachinesAPI: stubVirtualMachinesAPI{stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"error when creating NIC": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
|
||||
virtualMachinesAPI: stubVirtualMachinesAPI{},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"error when creating public IP": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
virtualMachinesAPI: stubVirtualMachinesAPI{},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"error when retrieving public IP": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{getErr: someErr},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
virtualMachinesAPI: stubVirtualMachinesAPI{},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
Count: 3,
|
||||
InstanceType: "type",
|
||||
Image: "image",
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "name",
|
||||
publicIPAddressesAPI: tc.publicIPAddressesAPI,
|
||||
networkInterfacesAPI: tc.networkInterfacesAPI,
|
||||
virtualMachinesAPI: tc.virtualMachinesAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
}
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
|
||||
assert.Equal(1, len(client.coordinators))
|
||||
assert.Equal(tc.createInstancesInput.Count-1, len(client.nodes))
|
||||
assert.NotEmpty(client.nodes["0"].PrivateIP)
|
||||
assert.NotEmpty(client.nodes["0"].PublicIP)
|
||||
assert.NotEmpty(client.coordinators["0"].PrivateIP)
|
||||
assert.NotEmpty(client.coordinators["0"].PublicIP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
|
||||
return &stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.StringPtr("resource-group-id"),
|
||||
},
|
||||
}
|
||||
}
|
166
cli/azure/client/network.go
Normal file
166
cli/azure/client/network.go
Normal file
@ -0,0 +1,166 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
|
||||
)
|
||||
|
||||
type createNetworkInput struct {
|
||||
name string
|
||||
location string
|
||||
addressSpace string
|
||||
}
|
||||
|
||||
// CreateVirtualNetwork creates a virtual network.
|
||||
func (c *Client) CreateVirtualNetwork(ctx context.Context) error {
|
||||
createNetworkInput := createNetworkInput{
|
||||
name: "constellation-" + c.uid,
|
||||
location: c.location,
|
||||
addressSpace: "172.20.0.0/16",
|
||||
}
|
||||
|
||||
poller, err := c.networksAPI.BeginCreateOrUpdate(
|
||||
ctx, c.resourceGroup, createNetworkInput.name,
|
||||
armnetwork.VirtualNetwork{
|
||||
Name: to.StringPtr(createNetworkInput.name), // this is supposed to be read-only
|
||||
Location: to.StringPtr(createNetworkInput.location),
|
||||
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
|
||||
AddressSpace: &armnetwork.AddressSpace{
|
||||
AddressPrefixes: []*string{
|
||||
to.StringPtr(createNetworkInput.addressSpace),
|
||||
},
|
||||
},
|
||||
Subnets: []*armnetwork.Subnet{
|
||||
{
|
||||
Name: to.StringPtr("default"),
|
||||
Properties: &armnetwork.SubnetPropertiesFormat{
|
||||
AddressPrefix: to.StringPtr(createNetworkInput.addressSpace),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := poller.PollUntilDone(ctx, 30*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.subnetID = *resp.VirtualNetworksClientCreateOrUpdateResult.VirtualNetwork.Properties.Subnets[0].ID
|
||||
return nil
|
||||
}
|
||||
|
||||
type createNetworkSecurityGroupInput struct {
|
||||
name string
|
||||
location string
|
||||
rules []*armnetwork.SecurityRule
|
||||
}
|
||||
|
||||
// CreateSecurityGroup creates a security group containing firewall rules.
|
||||
func (c *Client) CreateSecurityGroup(ctx context.Context, input NetworkSecurityGroupInput) error {
|
||||
rules := input.Ingress.Azure()
|
||||
|
||||
createNetworkSecurityGroupInput := createNetworkSecurityGroupInput{
|
||||
name: "constellation-security-group-" + c.uid,
|
||||
location: c.location,
|
||||
rules: rules,
|
||||
}
|
||||
|
||||
poller, err := c.networkSecurityGroupsAPI.BeginCreateOrUpdate(
|
||||
ctx, c.resourceGroup, createNetworkSecurityGroupInput.name,
|
||||
armnetwork.SecurityGroup{
|
||||
Name: to.StringPtr(createNetworkSecurityGroupInput.name),
|
||||
Location: to.StringPtr(createNetworkSecurityGroupInput.location),
|
||||
Properties: &armnetwork.SecurityGroupPropertiesFormat{
|
||||
SecurityRules: createNetworkSecurityGroupInput.rules,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.networkSecurityGroup = *pollerResp.SecurityGroupsClientCreateOrUpdateResult.SecurityGroup.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
// createNIC creates a network interface that references a public IP address.
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *Client) createNIC(ctx context.Context, name, publicIPAddressID string) (ip string, id string, err error) {
|
||||
poller, err := c.networkInterfacesAPI.BeginCreateOrUpdate(
|
||||
ctx, c.resourceGroup, name,
|
||||
armnetwork.Interface{
|
||||
Location: to.StringPtr(c.location),
|
||||
Properties: &armnetwork.InterfacePropertiesFormat{
|
||||
NetworkSecurityGroup: &armnetwork.SecurityGroup{
|
||||
ID: to.StringPtr(c.networkSecurityGroup),
|
||||
},
|
||||
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
|
||||
{
|
||||
Name: to.StringPtr(name),
|
||||
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
|
||||
Subnet: &armnetwork.Subnet{
|
||||
ID: to.StringPtr(c.subnetID),
|
||||
},
|
||||
PublicIPAddress: &armnetwork.PublicIPAddress{
|
||||
ID: to.StringPtr(publicIPAddressID),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
netInterface := pollerResp.InterfacesClientCreateOrUpdateResult.Interface
|
||||
|
||||
return *netInterface.Properties.IPConfigurations[0].Properties.PrivateIPAddress,
|
||||
*netInterface.ID,
|
||||
nil
|
||||
}
|
||||
|
||||
// createPublicIPAddress creates a public IP address.
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *Client) createPublicIPAddress(ctx context.Context, name string) (string, error) {
|
||||
poller, err := c.publicIPAddressesAPI.BeginCreateOrUpdate(
|
||||
ctx, c.resourceGroup, name,
|
||||
armnetwork.PublicIPAddress{
|
||||
Location: to.StringPtr(c.location),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return *pollerResp.PublicIPAddressesClientCreateOrUpdateResult.PublicIPAddress.ID, nil
|
||||
}
|
||||
|
||||
// NetworkSecurityGroupInput defines firewall rules to be set.
|
||||
type NetworkSecurityGroupInput struct {
|
||||
Ingress cloudtypes.Firewall
|
||||
Egress cloudtypes.Firewall
|
||||
}
|
220
cli/azure/client/network_test.go
Normal file
220
cli/azure/client/network_test.go
Normal file
@ -0,0 +1,220 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateVirtualNetwork(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
networksAPI networksAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
networksAPI: stubNetworksAPI{},
|
||||
},
|
||||
"failed to get response from successful create": {
|
||||
networksAPI: stubNetworksAPI{stubResponse: stubVirtualNetworksCreateOrUpdatePollerResponse{pollerErr: someErr}},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed create": {
|
||||
networksAPI: stubNetworksAPI{createErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
client := Client{
|
||||
resourceGroup: "resource-group",
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
networksAPI: tc.networksAPI,
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
}
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(client.CreateVirtualNetwork(ctx))
|
||||
} else {
|
||||
assert.NoError(client.CreateVirtualNetwork(ctx))
|
||||
assert.NotEmpty(client.subnetID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSecurityGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testNetworkSecurityGroupInput := NetworkSecurityGroupInput{
|
||||
Ingress: cloudtypes.Firewall{
|
||||
{
|
||||
Name: "test-1",
|
||||
Description: "test-1 description",
|
||||
Protocol: "tcp",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 9000,
|
||||
},
|
||||
{
|
||||
Name: "test-2",
|
||||
Description: "test-2 description",
|
||||
Protocol: "udp",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 51820,
|
||||
},
|
||||
},
|
||||
Egress: cloudtypes.Firewall{},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
networkSecurityGroupsAPI networkSecurityGroupsAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{},
|
||||
},
|
||||
"failed to get response from successful create": {
|
||||
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{stubPoller: stubNetworkSecurityGroupsCreateOrUpdatePollerResponse{pollerErr: someErr}},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed create": {
|
||||
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{createErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
client := Client{
|
||||
resourceGroup: "resource-group",
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
networkSecurityGroupsAPI: tc.networkSecurityGroupsAPI,
|
||||
}
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
|
||||
} else {
|
||||
assert.NoError(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
|
||||
assert.Equal("network-security-group-id", client.networkSecurityGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func TestCreateNIC(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
networkInterfacesAPI networkInterfacesAPI
|
||||
name string
|
||||
publicIPAddressID string
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
name: "nic-name",
|
||||
publicIPAddressID: "pubIP-id",
|
||||
},
|
||||
"failed to get response from successful create": {
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{stubResp: stubInterfacesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed create": {
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
client := Client{
|
||||
resourceGroup: "resource-group",
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
networkInterfacesAPI: tc.networkInterfacesAPI,
|
||||
}
|
||||
|
||||
ip, id, err := client.createNIC(ctx, tc.name, tc.publicIPAddressID)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.NotEmpty(ip)
|
||||
assert.NotEmpty(id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func TestCreatePublicIPAddress(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
publicIPAddressesAPI publicIPAddressesAPI
|
||||
name string
|
||||
errExpected bool
|
||||
}{
|
||||
"successful create": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
name: "nic-name",
|
||||
},
|
||||
"failed to get response from successful create": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{stubCreateResponse: stubPublicIPAddressesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
|
||||
errExpected: true,
|
||||
},
|
||||
"failed create": {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
client := Client{
|
||||
resourceGroup: "resource-group",
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
nodes: make(azure.Instances),
|
||||
coordinators: make(azure.Instances),
|
||||
publicIPAddressesAPI: tc.publicIPAddressesAPI,
|
||||
}
|
||||
|
||||
id, err := client.createPublicIPAddress(ctx, tc.name)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.NotEmpty(id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
135
cli/azure/instances.go
Normal file
135
cli/azure/instances.go
Normal file
@ -0,0 +1,135 @@
|
||||
package azure
|
||||
|
||||
// copy of ec2/instances.go
|
||||
|
||||
// TODO(katexochen): refactor into mulitcloud package.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
)
|
||||
|
||||
// Instance is a azure instance.
|
||||
type Instance struct {
|
||||
PublicIP string
|
||||
PrivateIP string
|
||||
}
|
||||
|
||||
// Instances is a map of azure Instances. The ID of an instance is used as key.
|
||||
type Instances map[string]Instance
|
||||
|
||||
// IDs returns the IDs of all instances of the Constellation.
|
||||
func (i Instances) IDs() []string {
|
||||
var ids []string
|
||||
for id := range i {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// PublicIPs returns the public IPs of all the instances of the Constellation.
|
||||
func (i Instances) PublicIPs() []string {
|
||||
var ips []string
|
||||
for _, instance := range i {
|
||||
ips = append(ips, instance.PublicIP)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// PrivateIPs returns the private IPs of all the instances of the Constellation.
|
||||
func (i Instances) PrivateIPs() []string {
|
||||
var ips []string
|
||||
for _, instance := range i {
|
||||
ips = append(ips, instance.PrivateIP)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// GetOne return anyone instance out of the instances and its ID.
|
||||
func (i Instances) GetOne() (string, Instance, error) {
|
||||
for id, instance := range i {
|
||||
return id, instance, nil
|
||||
}
|
||||
return "", Instance{}, errors.New("map is empty")
|
||||
}
|
||||
|
||||
// GetOthers returns all instances but the one with the handed ID.
|
||||
func (i Instances) GetOthers(id string) Instances {
|
||||
others := make(Instances)
|
||||
for key, instance := range i {
|
||||
if key != id {
|
||||
others[key] = instance
|
||||
}
|
||||
}
|
||||
return others
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
type VMInstance struct {
|
||||
Name string
|
||||
Location string
|
||||
InstanceType string
|
||||
Username string
|
||||
Password string
|
||||
NIC string
|
||||
Image string
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (i VMInstance) Azure() armcompute.VirtualMachine {
|
||||
return armcompute.VirtualMachine{
|
||||
Name: to.StringPtr(i.Name),
|
||||
Location: to.StringPtr(i.Location),
|
||||
Properties: &armcompute.VirtualMachineProperties{
|
||||
HardwareProfile: &armcompute.HardwareProfile{
|
||||
VMSize: (*armcompute.VirtualMachineSizeTypes)(to.StringPtr(i.InstanceType)),
|
||||
},
|
||||
OSProfile: &armcompute.OSProfile{
|
||||
ComputerName: to.StringPtr(i.Name),
|
||||
AdminPassword: to.StringPtr(i.Password),
|
||||
AdminUsername: to.StringPtr(i.Username),
|
||||
},
|
||||
SecurityProfile: &armcompute.SecurityProfile{
|
||||
UefiSettings: &armcompute.UefiSettings{
|
||||
SecureBootEnabled: to.BoolPtr(true),
|
||||
VTpmEnabled: to.BoolPtr(true),
|
||||
},
|
||||
SecurityType: armcompute.SecurityTypesConfidentialVM.ToPtr(),
|
||||
},
|
||||
NetworkProfile: &armcompute.NetworkProfile{
|
||||
NetworkInterfaces: []*armcompute.NetworkInterfaceReference{
|
||||
{
|
||||
ID: to.StringPtr(i.NIC),
|
||||
},
|
||||
},
|
||||
},
|
||||
StorageProfile: &armcompute.StorageProfile{
|
||||
OSDisk: &armcompute.OSDisk{
|
||||
CreateOption: armcompute.DiskCreateOptionTypesFromImage.ToPtr(),
|
||||
ManagedDisk: &armcompute.ManagedDiskParameters{
|
||||
StorageAccountType: armcompute.StorageAccountTypesPremiumLRS.ToPtr(),
|
||||
SecurityProfile: &armcompute.VMDiskSecurityProfile{
|
||||
SecurityEncryptionType: armcompute.SecurityEncryptionTypesVMGuestStateOnly.ToPtr(),
|
||||
},
|
||||
},
|
||||
},
|
||||
ImageReference: &armcompute.ImageReference{
|
||||
Publisher: to.StringPtr("0001-com-ubuntu-confidential-vm-focal"),
|
||||
Offer: to.StringPtr("canonical"),
|
||||
SKU: to.StringPtr("20_04-lts-gen2"),
|
||||
Version: to.StringPtr("latest"),
|
||||
},
|
||||
},
|
||||
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
|
||||
BootDiagnostics: &armcompute.BootDiagnostics{
|
||||
Enabled: to.BoolPtr(true),
|
||||
},
|
||||
},
|
||||
},
|
||||
Identity: &armcompute.VirtualMachineIdentity{
|
||||
Type: armcompute.ResourceIdentityTypeSystemAssigned.ToPtr(),
|
||||
},
|
||||
}
|
||||
}
|
71
cli/azure/instances_test.go
Normal file
71
cli/azure/instances_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIDs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
|
||||
assert.ElementsMatch(expectedIDs, testState.IDs())
|
||||
}
|
||||
|
||||
func TestPublicIPs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
|
||||
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
|
||||
}
|
||||
|
||||
func TestPrivateIPs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
|
||||
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
|
||||
}
|
||||
|
||||
func TestGetOne(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
id, instance, err := testState.GetOne()
|
||||
assert.NoError(err)
|
||||
assert.Contains(testState, id)
|
||||
assert.Equal(testState[id], instance)
|
||||
}
|
||||
|
||||
func TestGetOthers(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testCases := testInstances().IDs()
|
||||
|
||||
for _, id := range testCases {
|
||||
others := testInstances().GetOthers(id)
|
||||
assert.NotContains(others, id)
|
||||
expectedInstances := testInstances()
|
||||
delete(expectedInstances, id)
|
||||
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
|
||||
}
|
||||
}
|
||||
|
||||
func testInstances() Instances {
|
||||
return Instances{
|
||||
"id-9": {
|
||||
PublicIP: "192.0.2.1",
|
||||
PrivateIP: "192.0.2.2",
|
||||
},
|
||||
"id-10": {
|
||||
PublicIP: "192.0.2.3",
|
||||
PrivateIP: "192.0.2.4",
|
||||
},
|
||||
"id-11": {
|
||||
PublicIP: "192.0.2.5",
|
||||
PrivateIP: "192.0.2.6",
|
||||
},
|
||||
"id-12": {
|
||||
PublicIP: "192.0.2.7",
|
||||
PrivateIP: "192.0.2.8",
|
||||
},
|
||||
}
|
||||
}
|
13
cli/azure/instancetypes.go
Normal file
13
cli/azure/instancetypes.go
Normal file
@ -0,0 +1,13 @@
|
||||
package azure
|
||||
|
||||
import "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
|
||||
// InstanceTypes are valid Azure instance types.
|
||||
// Normally, this would be string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
|
||||
// but currently needed instances are not in SDK.
|
||||
var InstanceTypes = []string{
|
||||
string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
|
||||
"Standard_DC2as_v5",
|
||||
"Standard_DC4as_v5",
|
||||
"Standard_DC8as_v5",
|
||||
}
|
123
cli/azure/scaleset.go
Normal file
123
cli/azure/scaleset.go
Normal file
@ -0,0 +1,123 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
)
|
||||
|
||||
// ScaleSet defines a Azure scale set.
|
||||
type ScaleSet struct {
|
||||
Name string
|
||||
NamePrefix string
|
||||
Location string
|
||||
InstanceType string
|
||||
Count int64
|
||||
Username string
|
||||
SubnetID string
|
||||
NetworkSecurityGroup string
|
||||
Password string
|
||||
Image string
|
||||
UserAssignedIdentity string
|
||||
}
|
||||
|
||||
// Azure returns the Azure representation of ScaleSet.
|
||||
func (s ScaleSet) Azure() armcompute.VirtualMachineScaleSet {
|
||||
return armcompute.VirtualMachineScaleSet{
|
||||
Name: to.StringPtr(s.Name),
|
||||
Location: to.StringPtr(s.Location),
|
||||
SKU: &armcompute.SKU{
|
||||
Name: to.StringPtr(s.InstanceType),
|
||||
Capacity: to.Int64Ptr(s.Count),
|
||||
},
|
||||
Properties: &armcompute.VirtualMachineScaleSetProperties{
|
||||
Overprovision: to.BoolPtr(false),
|
||||
UpgradePolicy: &armcompute.UpgradePolicy{
|
||||
Mode: armcompute.UpgradeModeManual.ToPtr(),
|
||||
AutomaticOSUpgradePolicy: &armcompute.AutomaticOSUpgradePolicy{
|
||||
EnableAutomaticOSUpgrade: to.BoolPtr(false),
|
||||
DisableAutomaticRollback: to.BoolPtr(false),
|
||||
},
|
||||
},
|
||||
VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{
|
||||
OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{
|
||||
ComputerNamePrefix: to.StringPtr(s.NamePrefix),
|
||||
AdminUsername: to.StringPtr(s.Username),
|
||||
AdminPassword: to.StringPtr(s.Password),
|
||||
LinuxConfiguration: &armcompute.LinuxConfiguration{},
|
||||
},
|
||||
StorageProfile: &armcompute.VirtualMachineScaleSetStorageProfile{
|
||||
ImageReference: &armcompute.ImageReference{
|
||||
ID: to.StringPtr(s.Image),
|
||||
},
|
||||
},
|
||||
NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{
|
||||
NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{
|
||||
{
|
||||
Name: to.StringPtr(s.Name),
|
||||
Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{
|
||||
Primary: to.BoolPtr(true),
|
||||
EnableIPForwarding: to.BoolPtr(true),
|
||||
IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{
|
||||
{
|
||||
Name: to.StringPtr(s.Name),
|
||||
Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{
|
||||
Subnet: &armcompute.APIEntityReference{
|
||||
ID: to.StringPtr(s.SubnetID),
|
||||
},
|
||||
PublicIPAddressConfiguration: &armcompute.VirtualMachineScaleSetPublicIPAddressConfiguration{
|
||||
Name: to.StringPtr(s.Name),
|
||||
Properties: &armcompute.VirtualMachineScaleSetPublicIPAddressConfigurationProperties{
|
||||
IdleTimeoutInMinutes: to.Int32Ptr(15), // default per https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/virtual-machine-scale-sets-networking#creating-a-scale-set-with-public-ip-per-virtual-machine
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkSecurityGroup: &armcompute.SubResource{
|
||||
ID: to.StringPtr(s.NetworkSecurityGroup),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
SecurityProfile: &armcompute.SecurityProfile{
|
||||
SecurityType: armcompute.SecurityTypesTrustedLaunch.ToPtr(),
|
||||
UefiSettings: &armcompute.UefiSettings{VTpmEnabled: to.BoolPtr(true)},
|
||||
},
|
||||
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
|
||||
BootDiagnostics: &armcompute.BootDiagnostics{
|
||||
Enabled: to.BoolPtr(true),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Identity: &armcompute.VirtualMachineScaleSetIdentity{
|
||||
Type: armcompute.ResourceIdentityTypeUserAssigned.ToPtr(),
|
||||
UserAssignedIdentities: map[string]*armcompute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
|
||||
s.UserAssignedIdentity: {},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GeneratePassword is a helper function to generate a random password
|
||||
// for Azure's scale set.
|
||||
func GeneratePassword() (string, error) {
|
||||
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
pwLen := 16
|
||||
pw := make([]byte, 0, pwLen)
|
||||
for i := 0; i < pwLen; i++ {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pw = append(pw, letters[n.Int64()])
|
||||
}
|
||||
// bypass password rules
|
||||
pw = append(pw, []byte("Aa1!")...)
|
||||
return string(pw), nil
|
||||
}
|
111
cli/azure/scaleset_test.go
Normal file
111
cli/azure/scaleset_test.go
Normal file
@ -0,0 +1,111 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFirewallPermissions(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
scaleSet := ScaleSet{
|
||||
Name: "name",
|
||||
NamePrefix: "constellation-",
|
||||
Location: "UK South",
|
||||
InstanceType: "Standard_D2s_v3",
|
||||
Count: 3,
|
||||
Username: "constellation",
|
||||
SubnetID: "subnet-id",
|
||||
NetworkSecurityGroup: "network-security-group",
|
||||
Password: "password",
|
||||
Image: "image",
|
||||
UserAssignedIdentity: "user-identity",
|
||||
}
|
||||
|
||||
scaleSetAzure := scaleSet.Azure()
|
||||
|
||||
require.NotNil(scaleSetAzure.Name)
|
||||
assert.Equal(scaleSet.Name, *scaleSetAzure.Name)
|
||||
require.NotNil(scaleSetAzure.Location)
|
||||
assert.Equal(scaleSet.Location, *scaleSetAzure.Location)
|
||||
|
||||
require.NotNil(scaleSetAzure.SKU)
|
||||
require.NotNil(scaleSetAzure.SKU.Name)
|
||||
assert.Equal(scaleSet.InstanceType, *scaleSetAzure.SKU.Name)
|
||||
|
||||
require.NotNil(scaleSetAzure.SKU.Capacity)
|
||||
assert.Equal(scaleSet.Count, *scaleSetAzure.SKU.Capacity)
|
||||
|
||||
require.NotNil(scaleSetAzure.Properties)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile)
|
||||
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
|
||||
assert.Equal(scaleSet.NamePrefix, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
|
||||
assert.Equal(scaleSet.Username, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
|
||||
assert.Equal(scaleSet.Password, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
|
||||
|
||||
// Verify image
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference)
|
||||
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
|
||||
assert.Equal(scaleSet.Image, *scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
|
||||
|
||||
// Verify network
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile)
|
||||
require.Len(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations, 1)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0])
|
||||
|
||||
networkConfig := scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0]
|
||||
|
||||
require.NotNil(networkConfig.Name)
|
||||
assert.Equal(scaleSet.Name, *networkConfig.Name)
|
||||
|
||||
require.NotNil(networkConfig.Properties)
|
||||
require.Len(networkConfig.Properties.IPConfigurations, 1)
|
||||
require.NotNil(networkConfig.Properties.IPConfigurations[0])
|
||||
|
||||
ipConfig := networkConfig.Properties.IPConfigurations[0]
|
||||
|
||||
require.NotNil(ipConfig.Name)
|
||||
assert.Equal(scaleSet.Name, *ipConfig.Name)
|
||||
|
||||
require.NotNil(ipConfig.Properties)
|
||||
require.NotNil(ipConfig.Properties.Subnet)
|
||||
|
||||
require.NotNil(ipConfig.Properties.Subnet.ID)
|
||||
assert.Equal(scaleSet.SubnetID, *ipConfig.Properties.Subnet.ID)
|
||||
|
||||
require.NotNil(networkConfig.Properties.NetworkSecurityGroup)
|
||||
assert.Equal(scaleSet.NetworkSecurityGroup, *networkConfig.Properties.NetworkSecurityGroup.ID)
|
||||
|
||||
// Verify vTPM
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
|
||||
assert.Equal(armcompute.SecurityTypesTrustedLaunch, *scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings)
|
||||
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
|
||||
assert.True(*scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
|
||||
|
||||
// Verify UserAssignedIdentity
|
||||
require.NotNil(scaleSetAzure.Identity)
|
||||
require.NotNil(scaleSetAzure.Identity.Type)
|
||||
assert.Equal(armcompute.ResourceIdentityTypeUserAssigned, *scaleSetAzure.Identity.Type)
|
||||
require.Len(scaleSetAzure.Identity.UserAssignedIdentities, 1)
|
||||
assert.Contains(scaleSetAzure.Identity.UserAssignedIdentities, scaleSet.UserAssignedIdentity)
|
||||
}
|
||||
|
||||
func TestGeneratePassword(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
pw, err := GeneratePassword()
|
||||
require.NoError(err)
|
||||
assert.Len(pw, 20)
|
||||
}
|
88
cli/cloud/cloudtypes/firewall.go
Normal file
88
cli/cloud/cloudtypes/firewall.go
Normal file
@ -0,0 +1,88 @@
|
||||
package cloudtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type FirewallRule struct {
|
||||
Name string
|
||||
Description string
|
||||
Protocol string
|
||||
IPRange string
|
||||
Port int
|
||||
}
|
||||
|
||||
type Firewall []FirewallRule
|
||||
|
||||
func (f Firewall) GCP() []*computepb.Firewall {
|
||||
var fw []*computepb.Firewall
|
||||
for _, rule := range f {
|
||||
var destRange []string = nil
|
||||
if rule.IPRange != "" {
|
||||
destRange = append(destRange, rule.IPRange)
|
||||
}
|
||||
|
||||
fw = append(fw, &computepb.Firewall{
|
||||
Allowed: []*computepb.Allowed{
|
||||
{
|
||||
IPProtocol: proto.String(rule.Protocol),
|
||||
Ports: []string{fmt.Sprint(rule.Port)},
|
||||
},
|
||||
},
|
||||
Description: proto.String(rule.Description),
|
||||
DestinationRanges: destRange,
|
||||
Name: proto.String(rule.Name),
|
||||
})
|
||||
}
|
||||
return fw
|
||||
}
|
||||
|
||||
func (f Firewall) Azure() []*armnetwork.SecurityRule {
|
||||
var fw []*armnetwork.SecurityRule
|
||||
for i, rule := range f {
|
||||
// format string according to armnetwork.SecurityRuleProtocol specification
|
||||
protocol := strings.Title(strings.ToLower(rule.Protocol))
|
||||
|
||||
fw = append(fw, &armnetwork.SecurityRule{
|
||||
Name: proto.String(rule.Name),
|
||||
Properties: &armnetwork.SecurityRulePropertiesFormat{
|
||||
Description: proto.String(rule.Description),
|
||||
Protocol: (*armnetwork.SecurityRuleProtocol)(proto.String(protocol)),
|
||||
SourceAddressPrefix: proto.String(rule.IPRange),
|
||||
SourcePortRange: proto.String("*"),
|
||||
DestinationAddressPrefix: proto.String(rule.IPRange),
|
||||
DestinationPortRange: proto.String(strconv.Itoa(rule.Port)),
|
||||
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
|
||||
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
|
||||
// Each security role needs a unique priority
|
||||
Priority: proto.Int32(int32(100 * (i + 1))),
|
||||
},
|
||||
})
|
||||
}
|
||||
return fw
|
||||
}
|
||||
|
||||
func (f Firewall) AWS() []ec2types.IpPermission {
|
||||
var fw []ec2types.IpPermission
|
||||
for _, rule := range f {
|
||||
fw = append(fw, ec2types.IpPermission{
|
||||
FromPort: proto.Int32(int32(rule.Port)),
|
||||
ToPort: proto.Int32(int32(rule.Port)),
|
||||
IpProtocol: proto.String(rule.Protocol),
|
||||
IpRanges: []ec2types.IpRange{
|
||||
{
|
||||
CidrIp: proto.String(rule.IPRange),
|
||||
Description: proto.String(rule.Description),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return fw
|
||||
}
|
188
cli/cloud/cloudtypes/firewall_test.go
Normal file
188
cli/cloud/cloudtypes/firewall_test.go
Normal file
@ -0,0 +1,188 @@
|
||||
package cloudtypes
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestFirewallGCP(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
testFw := Firewall{
|
||||
{
|
||||
Name: "test-1",
|
||||
Description: "This is the Test-1 Permission",
|
||||
Protocol: "tcp",
|
||||
IPRange: "",
|
||||
Port: 9000,
|
||||
},
|
||||
{
|
||||
Name: "test-2",
|
||||
Description: "This is the Test-2 Permission",
|
||||
Protocol: "udp",
|
||||
IPRange: "",
|
||||
Port: 51820,
|
||||
},
|
||||
}
|
||||
|
||||
firewalls := testFw.GCP()
|
||||
assert.Equal(2, len(firewalls))
|
||||
|
||||
// Check permissions
|
||||
for i := 0; i < len(testFw); i++ {
|
||||
firewall1 := firewalls[i]
|
||||
actualPermission1 := firewall1.Allowed[0]
|
||||
|
||||
actualPort, err := strconv.Atoi(actualPermission1.GetPorts()[0])
|
||||
require.NoError(err)
|
||||
assert.Equal(testFw[i].Port, actualPort)
|
||||
assert.Equal(testFw[i].Protocol, actualPermission1.GetIPProtocol())
|
||||
|
||||
assert.Equal(testFw[i].Name, firewall1.GetName())
|
||||
assert.Equal(testFw[i].Description, firewall1.GetDescription())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallAzure(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
input := Firewall{
|
||||
{
|
||||
Name: "perm1",
|
||||
Description: "perm1 description",
|
||||
Protocol: "TCP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 22,
|
||||
},
|
||||
{
|
||||
Name: "perm2",
|
||||
Description: "perm2 description",
|
||||
Protocol: "udp",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4433,
|
||||
},
|
||||
{
|
||||
Name: "perm3",
|
||||
Description: "perm3 description",
|
||||
Protocol: "tcp",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4433,
|
||||
},
|
||||
}
|
||||
expectedOutput := []*armnetwork.SecurityRule{
|
||||
{
|
||||
Name: proto.String("perm1"),
|
||||
Properties: &armnetwork.SecurityRulePropertiesFormat{
|
||||
Description: proto.String("perm1 description"),
|
||||
Protocol: armnetwork.SecurityRuleProtocolTCP.ToPtr(),
|
||||
SourceAddressPrefix: proto.String("192.0.2.0/24"),
|
||||
SourcePortRange: proto.String("*"),
|
||||
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
|
||||
DestinationPortRange: proto.String("22"),
|
||||
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
|
||||
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
|
||||
Priority: proto.Int32(100),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: proto.String("perm2"),
|
||||
Properties: &armnetwork.SecurityRulePropertiesFormat{
|
||||
Description: proto.String("perm2 description"),
|
||||
Protocol: armnetwork.SecurityRuleProtocolUDP.ToPtr(),
|
||||
SourceAddressPrefix: proto.String("192.0.2.0/24"),
|
||||
SourcePortRange: proto.String("*"),
|
||||
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
|
||||
DestinationPortRange: proto.String("4433"),
|
||||
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
|
||||
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
|
||||
Priority: proto.Int32(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: proto.String("perm3"),
|
||||
Properties: &armnetwork.SecurityRulePropertiesFormat{
|
||||
Description: proto.String("perm3 description"),
|
||||
Protocol: armnetwork.SecurityRuleProtocolTCP.ToPtr(),
|
||||
SourceAddressPrefix: proto.String("192.0.2.0/24"),
|
||||
SourcePortRange: proto.String("*"),
|
||||
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
|
||||
DestinationPortRange: proto.String("4433"),
|
||||
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
|
||||
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
|
||||
Priority: proto.Int32(300),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out := input.Azure()
|
||||
assert.Equal(expectedOutput, out)
|
||||
}
|
||||
|
||||
func TestIPPermissonsToAWS(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
input := Firewall{
|
||||
{
|
||||
Description: "perm1",
|
||||
Protocol: "TCP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 22,
|
||||
},
|
||||
{
|
||||
Description: "perm2",
|
||||
Protocol: "UDP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4433,
|
||||
},
|
||||
{
|
||||
Description: "perm3",
|
||||
Protocol: "TCP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4433,
|
||||
},
|
||||
}
|
||||
expectedOutput := []ec2types.IpPermission{
|
||||
{
|
||||
FromPort: proto.Int32(int32(22)),
|
||||
ToPort: proto.Int32(int32(22)),
|
||||
IpProtocol: proto.String("TCP"),
|
||||
IpRanges: []ec2types.IpRange{
|
||||
{
|
||||
CidrIp: proto.String("192.0.2.0/24"),
|
||||
Description: proto.String("perm1"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FromPort: proto.Int32(int32(4433)),
|
||||
ToPort: proto.Int32(int32(4433)),
|
||||
IpProtocol: proto.String("UDP"),
|
||||
IpRanges: []ec2types.IpRange{
|
||||
{
|
||||
CidrIp: proto.String("192.0.2.0/24"),
|
||||
Description: proto.String("perm2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FromPort: proto.Int32(int32(4433)),
|
||||
ToPort: proto.Int32(int32(4433)),
|
||||
IpProtocol: proto.String("TCP"),
|
||||
IpRanges: []ec2types.IpRange{
|
||||
{
|
||||
CidrIp: proto.String("192.0.2.0/24"),
|
||||
Description: proto.String("perm3"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out := input.AWS()
|
||||
assert.Equal(expectedOutput, out)
|
||||
}
|
13
cli/cloudprovider/cloudprovider.go
Normal file
13
cli/cloudprovider/cloudprovider.go
Normal file
@ -0,0 +1,13 @@
|
||||
package cloudprovider
|
||||
|
||||
//go:generate stringer -type=CloudProvider
|
||||
|
||||
// CloudProvider is cloud provider used by the CLI.
|
||||
type CloudProvider uint32
|
||||
|
||||
const (
|
||||
Unknown CloudProvider = iota
|
||||
AWS
|
||||
Azure
|
||||
GCP
|
||||
)
|
26
cli/cloudprovider/cloudprovider_string.go
Normal file
26
cli/cloudprovider/cloudprovider_string.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Code generated by "stringer -type=CloudProvider"; DO NOT EDIT.
|
||||
|
||||
package cloudprovider
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Unknown-0]
|
||||
_ = x[AWS-1]
|
||||
_ = x[Azure-2]
|
||||
_ = x[GCP-3]
|
||||
}
|
||||
|
||||
const _CloudProvider_name = "UnknownAWSAzureGCP"
|
||||
|
||||
var _CloudProvider_index = [...]uint8{0, 7, 10, 15, 18}
|
||||
|
||||
func (i CloudProvider) String() string {
|
||||
if i >= CloudProvider(len(_CloudProvider_index)-1) {
|
||||
return "CloudProvider(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _CloudProvider_name[_CloudProvider_index[i]:_CloudProvider_index[i+1]]
|
||||
}
|
22
cli/cmd/azureclient.go
Normal file
22
cli/cmd/azureclient.go
Normal file
@ -0,0 +1,22 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure/client"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type azureclient interface {
|
||||
GetState() (state.ConstellationState, error)
|
||||
SetState(state.ConstellationState) error
|
||||
CreateResourceGroup(ctx context.Context) error
|
||||
CreateVirtualNetwork(ctx context.Context) error
|
||||
CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error
|
||||
CreateInstances(ctx context.Context, input client.CreateInstancesInput) error
|
||||
// TODO: deprecate as soon as scale sets are available
|
||||
CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error
|
||||
CreateServicePrincipal(ctx context.Context) (string, error)
|
||||
TerminateResourceGroup(ctx context.Context) error
|
||||
TerminateServicePrincipal(ctx context.Context) error
|
||||
}
|
194
cli/cmd/azureclient_test.go
Normal file
194
cli/cmd/azureclient_test.go
Normal file
@ -0,0 +1,194 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/azure/client"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type fakeAzureClient struct {
|
||||
nodes azure.Instances
|
||||
coordinators azure.Instances
|
||||
|
||||
resourceGroup string
|
||||
name string
|
||||
uid string
|
||||
location string
|
||||
subscriptionID string
|
||||
tenantID string
|
||||
subnetID string
|
||||
coordinatorsScaleSet string
|
||||
nodesScaleSet string
|
||||
networkSecurityGroup string
|
||||
adAppObjectID string
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) GetState() (state.ConstellationState, error) {
|
||||
stat := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: c.nodes,
|
||||
AzureCoordinators: c.coordinators,
|
||||
Name: c.name,
|
||||
UID: c.uid,
|
||||
AzureResourceGroup: c.resourceGroup,
|
||||
AzureLocation: c.location,
|
||||
AzureSubscription: c.subscriptionID,
|
||||
AzureTenant: c.tenantID,
|
||||
AzureSubnet: c.subnetID,
|
||||
AzureNetworkSecurityGroup: c.networkSecurityGroup,
|
||||
AzureNodesScaleSet: c.nodesScaleSet,
|
||||
AzureCoordinatorsScaleSet: c.coordinatorsScaleSet,
|
||||
AzureADAppObjectID: c.adAppObjectID,
|
||||
}
|
||||
return stat, nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) SetState(stat state.ConstellationState) error {
|
||||
c.nodes = stat.AzureNodes
|
||||
c.coordinators = stat.AzureCoordinators
|
||||
c.name = stat.Name
|
||||
c.uid = stat.UID
|
||||
c.resourceGroup = stat.AzureResourceGroup
|
||||
c.location = stat.AzureLocation
|
||||
c.subscriptionID = stat.AzureSubscription
|
||||
c.tenantID = stat.AzureTenant
|
||||
c.subnetID = stat.AzureSubnet
|
||||
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
|
||||
c.nodesScaleSet = stat.AzureNodesScaleSet
|
||||
c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
|
||||
c.adAppObjectID = stat.AzureADAppObjectID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
|
||||
c.resourceGroup = "resource-group"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
||||
c.subnetID = "subnet"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error {
|
||||
c.networkSecurityGroup = "network-security-group"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
|
||||
c.coordinatorsScaleSet = "coordinators-scale-set"
|
||||
c.nodesScaleSet = "nodes-scale-set"
|
||||
c.nodes = make(azure.Instances)
|
||||
for i := 0; i < input.Count-1; i++ {
|
||||
id := strconv.Itoa(i)
|
||||
c.nodes[id] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
|
||||
}
|
||||
c.coordinators = make(azure.Instances)
|
||||
c.coordinators["0"] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *fakeAzureClient) CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error {
|
||||
c.nodes = make(azure.Instances)
|
||||
for i := 0; i < input.Count-1; i++ {
|
||||
id := strconv.Itoa(i)
|
||||
c.nodes[id] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
|
||||
}
|
||||
c.coordinators = make(azure.Instances)
|
||||
c.coordinators["0"] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, error) {
|
||||
c.adAppObjectID = "00000000-0000-0000-0000-000000000001"
|
||||
return client.ApplicationCredentials{
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
}.ConvertToCloudServiceAccountURI(), nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error {
|
||||
if c.resourceGroup == "" {
|
||||
return nil
|
||||
}
|
||||
c.nodes = nil
|
||||
c.coordinators = nil
|
||||
c.resourceGroup = ""
|
||||
c.subnetID = ""
|
||||
c.networkSecurityGroup = ""
|
||||
c.nodesScaleSet = ""
|
||||
c.coordinatorsScaleSet = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
||||
if c.adAppObjectID == "" {
|
||||
return nil
|
||||
}
|
||||
c.adAppObjectID = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubAzureClient struct {
|
||||
terminateResourceGroupCalled bool
|
||||
|
||||
getStateErr error
|
||||
setStateErr error
|
||||
createResourceGroupErr error
|
||||
createVirtualNetworkErr error
|
||||
createSecurityGroupErr error
|
||||
createInstancesErr error
|
||||
createServicePrincipalErr error
|
||||
terminateResourceGroupErr error
|
||||
terminateServicePrincipalErr error
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) GetState() (state.ConstellationState, error) {
|
||||
return state.ConstellationState{}, c.getStateErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) SetState(state.ConstellationState) error {
|
||||
return c.setStateErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
|
||||
return c.createResourceGroupErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
||||
return c.createVirtualNetworkErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error {
|
||||
return c.createSecurityGroupErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
|
||||
return c.createInstancesErr
|
||||
}
|
||||
|
||||
// TODO: deprecate as soon as scale sets are available.
|
||||
func (c *stubAzureClient) CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error {
|
||||
return c.createInstancesErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, error) {
|
||||
return client.ApplicationCredentials{
|
||||
ClientID: "00000000-0000-0000-0000-000000000000",
|
||||
ClientSecret: "secret",
|
||||
}.ConvertToCloudServiceAccountURI(), c.createServicePrincipalErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error {
|
||||
c.terminateResourceGroupCalled = true
|
||||
return c.terminateResourceGroupErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
||||
return c.terminateServicePrincipalErr
|
||||
}
|
13
cli/cmd/constants.go
Normal file
13
cli/cmd/constants.go
Normal file
@ -0,0 +1,13 @@
|
||||
package cmd
|
||||
|
||||
// wireguardKeyLength is the length of a WireGuard key in byte.
|
||||
const wireguardKeyLength = 32
|
||||
|
||||
// masterSecretLengthDefault is the default length in bytes for CLI generated master secrets.
|
||||
const masterSecretLengthDefault = 32
|
||||
|
||||
// masterSecretLengthMin is the minimal length in bytes for user provided master secrets.
|
||||
const masterSecretLengthMin = 16
|
||||
|
||||
// constellationNameLength is the maximum length of a Constellation's name.
|
||||
const constellationNameLength = 37
|
41
cli/cmd/create.go
Normal file
41
cli/cmd/create.go
Normal file
@ -0,0 +1,41 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newCreateCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create aws|gcp|azure",
|
||||
Short: "Create instances on a cloud platform for your Constellation.",
|
||||
Long: "Create instances on a cloud platform for your Constellation.",
|
||||
}
|
||||
cmd.PersistentFlags().String("name", "constell", "Set this flag to create the Constellation with the specified name.")
|
||||
cmd.PersistentFlags().BoolP("yes", "y", false, "Set this flag to create the Constellation without further confirmation.")
|
||||
|
||||
cmd.AddCommand(newCreateAWSCmd())
|
||||
cmd.AddCommand(newCreateGCPCmd())
|
||||
cmd.AddCommand(newCreateAzureCmd())
|
||||
return cmd
|
||||
}
|
||||
|
||||
// checkDirClean checks if files of a previous Constellation are left in the current working dir.
|
||||
func checkDirClean(fileHandler file.Handler, config *config.Config) error {
|
||||
if _, err := fileHandler.Stat(*config.StatePath); !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", *config.StatePath)
|
||||
}
|
||||
if _, err := fileHandler.Stat(*config.AdminConfPath); !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", *config.AdminConfPath)
|
||||
}
|
||||
if _, err := fileHandler.Stat(*config.MasterSecretPath); !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("file '%s' already exists in working directory, clean it up first", *config.MasterSecretPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
138
cli/cmd/create_aws.go
Normal file
138
cli/cmd/create_aws.go
Normal file
@ -0,0 +1,138 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/ec2/client"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
)
|
||||
|
||||
func newCreateAWSCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "aws NUMBER SIZE",
|
||||
Short: "Create a Constellation of NUMBER nodes of SIZE on AWS.",
|
||||
Long: "Create a Constellation of NUMBER nodes of SIZE on AWS.",
|
||||
Example: "aws 4 2xlarge",
|
||||
Args: cobra.MatchAll(
|
||||
cobra.ExactArgs(2),
|
||||
isIntGreaterArg(0, 1),
|
||||
isEC2InstanceType(1),
|
||||
),
|
||||
ValidArgsFunction: createAWSCompletion,
|
||||
RunE: runCreateAWS,
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runCreateAWS runs the create command.
|
||||
func runCreateAWS(cmd *cobra.Command, args []string) error {
|
||||
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
|
||||
size := strings.ToLower(args[1])
|
||||
|
||||
name, err := cmd.Flags().GetString("name")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := client.NewFromDefault(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return createAWS(cmd, client, fileHandler, config, size, name, count)
|
||||
}
|
||||
|
||||
// createAWS uses the given client to create 'count' instances of 'size'.
|
||||
// After the instances are running, they are tagged with the default tags.
|
||||
// On success, the state of the client is saved to the state file.
|
||||
func createAWS(cmd *cobra.Command, cl ec2client, fileHandler file.Handler, config *config.Config, size, name string, count int) (retErr error) {
|
||||
if err := checkDirClean(fileHandler, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const maxLength = 255
|
||||
if len(name) > maxLength {
|
||||
return fmt.Errorf("name for constellation too long, maximum length is %d: %s", maxLength, name)
|
||||
}
|
||||
ec2Tags := append([]ec2.Tag{}, *config.Provider.EC2.Tags...)
|
||||
ec2Tags = append(ec2Tags, ec2.Tag{Key: "Name", Value: name})
|
||||
|
||||
ok, err := cmd.Flags().GetBool("yes")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// Ask user to confirm action.
|
||||
cmd.Printf("The following Constellation will be created:\n")
|
||||
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
|
||||
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
cmd.Println("The creation of the Constellation was aborted.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerAWS{client: cl})
|
||||
if err := cl.CreateSecurityGroup(cmd.Context(), *config.Provider.EC2.SecurityGroupInput); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createInput := client.CreateInput{
|
||||
ImageId: *config.Provider.EC2.Image,
|
||||
InstanceType: size,
|
||||
Count: count,
|
||||
Tags: ec2Tags,
|
||||
}
|
||||
if err := cl.CreateInstances(cmd.Context(), createInput); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stat, err := cl.GetState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Your Constellation was created successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAWSCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func createAWSCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return []string{}, cobra.ShellCompDirectiveNoFileComp
|
||||
case 1:
|
||||
return []string{
|
||||
"4xlarge",
|
||||
"8xlarge",
|
||||
"12xlarge",
|
||||
"16xlarge",
|
||||
"24xlarge",
|
||||
}, cobra.ShellCompDirectiveDefault
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
205
cli/cmd/create_aws_test.go
Normal file
205
cli/cmd/create_aws_test.go
Normal file
@ -0,0 +1,205 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateAWSCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"valid size 4XL": {[]string{"5", "4xlarge"}, false},
|
||||
"valid size 8XL": {[]string{"4", "8xlarge"}, false},
|
||||
"valid size 12XL": {[]string{"3", "12xlarge"}, false},
|
||||
"valid size 16XL": {[]string{"2", "16xlarge"}, false},
|
||||
"valid size 24XL": {[]string{"2", "24xlarge"}, false},
|
||||
"valid short 12XL": {[]string{"4", "12xl"}, false},
|
||||
"valid short 24XL": {[]string{"2", "24xl"}, false},
|
||||
"valid capitalized": {[]string{"3", "24XlARge"}, false},
|
||||
"valid short capitalized": {[]string{"4", "16XL"}, false},
|
||||
"invalid to many arguments": {[]string{"2", "4xl", "2xl"}, true},
|
||||
"invalid to many arguments 2": {[]string{"2", "4xl", "2"}, true},
|
||||
"invalidOnlyOneInstance": {[]string{"1", "4xl"}, true},
|
||||
"invalid first is no int": {[]string{"xl", "4xl"}, true},
|
||||
"invalid second is no size": {[]string{"2", "2"}, true},
|
||||
"invalid wrong order": {[]string{"4xl", "2"}, true},
|
||||
}
|
||||
|
||||
cmd := newCreateAWSCmd()
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAWS(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.AWS.String(),
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-2": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
config := config.Default()
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState *state.ConstellationState
|
||||
client ec2client
|
||||
interactive bool
|
||||
interactiveStdin string
|
||||
stateExpected state.ConstellationState
|
||||
errExpected bool
|
||||
}{
|
||||
"create some instances": {
|
||||
client: &fakeEc2Client{},
|
||||
stateExpected: testState,
|
||||
errExpected: false,
|
||||
},
|
||||
"state already exists": {
|
||||
existingState: &testState,
|
||||
client: &fakeEc2Client{},
|
||||
errExpected: true,
|
||||
},
|
||||
"create some instances interactive": {
|
||||
client: &fakeEc2Client{},
|
||||
interactive: true,
|
||||
interactiveStdin: "y\n",
|
||||
stateExpected: testState,
|
||||
errExpected: false,
|
||||
},
|
||||
"fail CreateSecurityGroup": {
|
||||
client: &stubEc2Client{createSecurityGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail CreateInstances": {
|
||||
client: &stubEc2Client{createInstancesErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail GetState": {
|
||||
client: &stubEc2Client{getStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"error on rollback": {
|
||||
client: &stubEc2Client{createInstancesErr: someErr, deleteSecurityGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newCreateAWSCmd()
|
||||
cmd.Flags().BoolP("yes", "y", false, "")
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
in := bytes.NewBufferString(tc.interactiveStdin)
|
||||
cmd.SetIn(in)
|
||||
|
||||
if !tc.interactive {
|
||||
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
|
||||
}
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
if tc.existingState != nil {
|
||||
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
|
||||
}
|
||||
|
||||
err := createAWS(cmd, tc.client, fileHandler, config, "xlarge", "name", 3)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
if stubClient, ok := tc.client.(*stubEc2Client); ok {
|
||||
// Should have made a rollback on error.
|
||||
assert.True(stubClient.terminateInstancesCalled)
|
||||
assert.True(stubClient.deleteSecurityGroupCalled)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
var stat state.ConstellationState
|
||||
err := fileHandler.ReadJSON(*config.StatePath, &stat)
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.stateExpected, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAWSCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
resultExpected []string
|
||||
shellCDExpected cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "21",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"second arg": {
|
||||
args: []string{"23"},
|
||||
toComplete: "4xl",
|
||||
resultExpected: []string{
|
||||
"4xlarge",
|
||||
"8xlarge",
|
||||
"12xlarge",
|
||||
"16xlarge",
|
||||
"24xlarge",
|
||||
},
|
||||
shellCDExpected: cobra.ShellCompDirectiveDefault,
|
||||
},
|
||||
"third arg": {
|
||||
args: []string{"23", "4xlarge"},
|
||||
toComplete: "xl",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := createAWSCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.resultExpected, result)
|
||||
assert.Equal(tc.shellCDExpected, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
137
cli/cmd/create_azure.go
Normal file
137
cli/cmd/create_azure.go
Normal file
@ -0,0 +1,137 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/azure/client"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newCreateAzureCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "azure",
|
||||
Short: "Create a Constellation of NUMBER nodes of SIZE on Azure.",
|
||||
Long: "Create a Constellation of NUMBER nodes of SIZE on Azure.",
|
||||
Args: cobra.MatchAll(
|
||||
cobra.ExactArgs(2),
|
||||
isIntGreaterArg(0, 1),
|
||||
isAzureInstanceType(1),
|
||||
),
|
||||
ValidArgsFunction: createAzureCompletion,
|
||||
RunE: runCreateAzure,
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runCreateAzure runs the create command.
|
||||
func runCreateAzure(cmd *cobra.Command, args []string) error {
|
||||
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
|
||||
size := strings.ToLower(args[1])
|
||||
subscriptionID := "0d202bbb-4fa7-4af8-8125-58c269a05435" // TODO: This will be user input
|
||||
tenantID := "adb650a8-5da3-4b15-b4b0-3daf65ff7626" // TODO: This will be user input
|
||||
location := "North Europe" // TODO: This will be user input
|
||||
|
||||
name, err := cmd.Flags().GetString("name")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(name) > constellationNameLength {
|
||||
return fmt.Errorf("name for constellation too long, maximum length is %d got %d: %s", constellationNameLength, len(name), name)
|
||||
}
|
||||
|
||||
client, err := client.NewInitialized(
|
||||
subscriptionID,
|
||||
tenantID,
|
||||
name,
|
||||
location,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return createAzure(cmd, client, fileHandler, config, size, count)
|
||||
}
|
||||
|
||||
func createAzure(cmd *cobra.Command, cl azureclient, fileHandler file.Handler, config *config.Config, size string, count int) (retErr error) {
|
||||
if err := checkDirClean(fileHandler, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ok, err := cmd.Flags().GetBool("yes")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// Ask user to confirm action.
|
||||
cmd.Printf("The following Constellation will be created:\n")
|
||||
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
|
||||
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
cmd.Println("The creation of the Constellation was aborted.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create all azure resources
|
||||
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerAzure{client: cl})
|
||||
if err := cl.CreateResourceGroup(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.CreateVirtualNetwork(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.CreateSecurityGroup(cmd.Context(), *config.Provider.Azure.NetworkSecurityGroupInput); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.CreateInstances(cmd.Context(), client.CreateInstancesInput{
|
||||
Count: count,
|
||||
InstanceType: size,
|
||||
Image: *config.Provider.Azure.Image,
|
||||
UserAssingedIdentity: *config.Provider.Azure.UserAssignedIdentity,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stat, err := cl.GetState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Your Constellation was created successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAzureCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func createAzureCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return []string{}, cobra.ShellCompDirectiveNoFileComp
|
||||
case 1:
|
||||
return azure.InstanceTypes, cobra.ShellCompDirectiveDefault
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
206
cli/cmd/create_azure_test.go
Normal file
206
cli/cmd/create_azure_test.go
Normal file
@ -0,0 +1,206 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateAzureCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"valid create 1": {[]string{"3", "Standard_DC2as_v5"}, false},
|
||||
"valid create 2": {[]string{"7", "Standard_DC4as_v5"}, false},
|
||||
"valid create 3": {[]string{"2", "Standard_DC8as_v5"}, false},
|
||||
"invalid to many arguments": {[]string{"2", "Standard_DC2as_v5", "Standard_DC2as_v5"}, true},
|
||||
"invalid to many arguments 2": {[]string{"2", "Standard_DC2as_v5", "2"}, true},
|
||||
"invalidOnlyOneInstance": {[]string{"1", "Standard_DC2as_v5"}, true},
|
||||
"invalid first is no int": {[]string{"Standard_DC2as_v5", "Standard_DC2as_v5"}, true},
|
||||
"invalid second is no size": {[]string{"2", "2"}, true},
|
||||
"invalid wrong order": {[]string{"Standard_DC2as_v5", "2"}, true},
|
||||
}
|
||||
|
||||
cmd := newCreateAzureCmd()
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAzure(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureSubnet: "subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureNodesScaleSet: "nodes-scale-set",
|
||||
AzureCoordinatorsScaleSet: "coordinators-scale-set",
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
config := config.Default()
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState *state.ConstellationState
|
||||
client azureclient
|
||||
interactive bool
|
||||
interactiveStdin string
|
||||
stateExpected state.ConstellationState
|
||||
errExpected bool
|
||||
}{
|
||||
"create some instances": {
|
||||
client: &fakeAzureClient{},
|
||||
stateExpected: testState,
|
||||
},
|
||||
"state already exists": {
|
||||
existingState: &testState,
|
||||
client: &fakeAzureClient{},
|
||||
errExpected: true,
|
||||
},
|
||||
"create some instances interactive": {
|
||||
client: &fakeAzureClient{},
|
||||
interactive: true,
|
||||
interactiveStdin: "y\n",
|
||||
stateExpected: testState,
|
||||
errExpected: false,
|
||||
},
|
||||
"fail getState": {
|
||||
client: &stubAzureClient{getStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createVirtualNetwork": {
|
||||
client: &stubAzureClient{createVirtualNetworkErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createSecurityGroup": {
|
||||
client: &stubAzureClient{createSecurityGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createInstances": {
|
||||
client: &stubAzureClient{createInstancesErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createResourceGroup": {
|
||||
client: &stubAzureClient{createResourceGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"error on rollback": {
|
||||
client: &stubAzureClient{createInstancesErr: someErr, terminateResourceGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newCreateAzureCmd()
|
||||
cmd.Flags().BoolP("yes", "y", false, "")
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
in := bytes.NewBufferString(tc.interactiveStdin)
|
||||
cmd.SetIn(in)
|
||||
if !tc.interactive {
|
||||
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
|
||||
}
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
if tc.existingState != nil {
|
||||
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
|
||||
}
|
||||
|
||||
err := createAzure(cmd, tc.client, fileHandler, config, "Standard_D2s_v3", 3)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
if stubClient, ok := tc.client.(*stubAzureClient); ok {
|
||||
// Should have made a rollback on error.
|
||||
assert.True(stubClient.terminateResourceGroupCalled)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
var state state.ConstellationState
|
||||
err := fileHandler.ReadJSON(*config.StatePath, &state)
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.stateExpected, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAzureCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
resultExpected []string
|
||||
shellCDExpected cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "21",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"second arg": {
|
||||
args: []string{"23"},
|
||||
toComplete: "Standard_D",
|
||||
resultExpected: azure.InstanceTypes,
|
||||
shellCDExpected: cobra.ShellCompDirectiveDefault,
|
||||
},
|
||||
"third arg": {
|
||||
args: []string{"23", "Standard_D2s_v3"},
|
||||
toComplete: "Standard_D",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := createAzureCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.resultExpected, result)
|
||||
assert.Equal(tc.shellCDExpected, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
132
cli/cmd/create_gcp.go
Normal file
132
cli/cmd/create_gcp.go
Normal file
@ -0,0 +1,132 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/cli/gcp/client"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newCreateGCPCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "gcp",
|
||||
Short: "Create a Constellation of NUMBER nodes of SIZE on Google Cloud Platform.",
|
||||
Long: "Create a Constellation of NUMBER nodes of SIZE on Google Cloud Platform.",
|
||||
Args: cobra.MatchAll(
|
||||
cobra.ExactArgs(2),
|
||||
isIntGreaterArg(0, 1),
|
||||
isGCPInstanceType(1),
|
||||
),
|
||||
ValidArgsFunction: createGCPCompletion,
|
||||
RunE: runCreateGCP,
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runCreateGCP runs the create command.
|
||||
func runCreateGCP(cmd *cobra.Command, args []string) error {
|
||||
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
|
||||
size := strings.ToLower(args[1])
|
||||
project := "constellation-331613" // TODO: This will be user input
|
||||
zone := "us-central1-c" // TODO: This will be user input
|
||||
region := "us-central1" // TODO: This will be user input
|
||||
|
||||
name, err := cmd.Flags().GetString("name")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(name) > constellationNameLength {
|
||||
return fmt.Errorf("name for constellation too long, maximum length is %d got %d: %s", constellationNameLength, len(name), name)
|
||||
}
|
||||
|
||||
client, err := client.NewInitialized(cmd.Context(), project, zone, region, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return createGCP(cmd, client, fileHandler, config, size, count)
|
||||
}
|
||||
|
||||
func createGCP(cmd *cobra.Command, cl gcpclient, fileHandler file.Handler, config *config.Config, size string, count int) (retErr error) {
|
||||
if err := checkDirClean(fileHandler, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createInput := client.CreateInstancesInput{
|
||||
Count: count,
|
||||
ImageId: *config.Provider.GCP.Image,
|
||||
InstanceType: size,
|
||||
KubeEnv: gcp.KubeEnv,
|
||||
DisableCVM: *config.Provider.GCP.DisableCVM,
|
||||
}
|
||||
|
||||
ok, err := cmd.Flags().GetBool("yes")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// Ask user to confirm action.
|
||||
cmd.Printf("The following Constellation will be created:\n")
|
||||
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
|
||||
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
cmd.Println("The creation of the Constellation was aborted.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create all gcp resources
|
||||
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerGCP{client: cl})
|
||||
if err := cl.CreateVPCs(cmd.Context(), *config.Provider.GCP.VPCsInput); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.CreateFirewall(cmd.Context(), *config.Provider.GCP.FirewallInput); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.CreateInstances(cmd.Context(), createInput); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stat, err := cl.GetState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Your Constellation was created successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createGCPCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func createGCPCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return []string{}, cobra.ShellCompDirectiveNoFileComp
|
||||
case 1:
|
||||
return gcp.InstanceTypes, cobra.ShellCompDirectiveDefault
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
206
cli/cmd/create_gcp_test.go
Normal file
206
cli/cmd/create_gcp_test.go
Normal file
@ -0,0 +1,206 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateGCPCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"valid create 1": {[]string{"3", "n2d-standard-2"}, false},
|
||||
"valid create 2": {[]string{"7", "n2d-standard-16"}, false},
|
||||
"valid create 3": {[]string{"2", "n2d-standard-96"}, false},
|
||||
"invalid to many arguments": {[]string{"2", "n2d-standard-2", "n2d-standard-2"}, true},
|
||||
"invalid to many arguments 2": {[]string{"2", "n2d-standard-2", "2"}, true},
|
||||
"invalidOnlyOneInstance": {[]string{"1", "n2d-standard-2"}, true},
|
||||
"invalid first is no int": {[]string{"n2d-standard-2", "n2d-standard-2"}, true},
|
||||
"invalid second is no size": {[]string{"2", "2"}, true},
|
||||
"invalid wrong order": {[]string{"n2d-standard-2", "2"}, true},
|
||||
}
|
||||
|
||||
cmd := newCreateGCPCmd()
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGCP(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.GCP.String(),
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
GCPCoordinators: gcp.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
GCPNodeInstanceGroup: "nodes-group",
|
||||
GCPCoordinatorInstanceGroup: "coordinator-group",
|
||||
GCPNodeInstanceTemplate: "node-template",
|
||||
GCPCoordinatorInstanceTemplate: "coordinator-template",
|
||||
GCPNetwork: "network",
|
||||
GCPSubnetwork: "subnetwork",
|
||||
GCPFirewalls: []string{"coordinator", "wireguard", "ssh"},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
config := config.Default()
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState *state.ConstellationState
|
||||
client gcpclient
|
||||
interactive bool
|
||||
interactiveStdin string
|
||||
stateExpected state.ConstellationState
|
||||
errExpected bool
|
||||
}{
|
||||
"create some instances": {
|
||||
client: &fakeGcpClient{},
|
||||
stateExpected: testState,
|
||||
},
|
||||
"state already exists": {
|
||||
existingState: &testState,
|
||||
client: &fakeGcpClient{},
|
||||
errExpected: true,
|
||||
},
|
||||
"create some instances interactive": {
|
||||
client: &fakeGcpClient{},
|
||||
interactive: true,
|
||||
interactiveStdin: "y\n",
|
||||
stateExpected: testState,
|
||||
errExpected: false,
|
||||
},
|
||||
"fail getState": {
|
||||
client: &stubGcpClient{getStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createVPCs": {
|
||||
client: &stubGcpClient{createVPCsErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createFirewall": {
|
||||
client: &stubGcpClient{createFirewallErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail createInstances": {
|
||||
client: &stubGcpClient{createInstancesErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"error on rollback": {
|
||||
client: &stubGcpClient{createInstancesErr: someErr, terminateVPCsErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newCreateGCPCmd()
|
||||
cmd.Flags().BoolP("yes", "y", false, "")
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
in := bytes.NewBufferString(tc.interactiveStdin)
|
||||
cmd.SetIn(in)
|
||||
if !tc.interactive {
|
||||
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
|
||||
}
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
if tc.existingState != nil {
|
||||
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
|
||||
}
|
||||
|
||||
err := createGCP(cmd, tc.client, fileHandler, config, "n2d-standard-2", 3)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
if stubClient, ok := tc.client.(*stubGcpClient); ok {
|
||||
// Should have made a rollback on error.
|
||||
assert.True(stubClient.terminateFirewallCalled)
|
||||
assert.True(stubClient.terminateInstancesCalled)
|
||||
assert.True(stubClient.terminateVPCsCalled)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
var stat state.ConstellationState
|
||||
err := fileHandler.ReadJSON(*config.StatePath, &stat)
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.stateExpected, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGCPCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
resultExpected []string
|
||||
shellCDExpected cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "21",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"second arg": {
|
||||
args: []string{"23"},
|
||||
toComplete: "n2d-stan",
|
||||
resultExpected: gcp.InstanceTypes,
|
||||
shellCDExpected: cobra.ShellCompDirectiveDefault,
|
||||
},
|
||||
"third arg": {
|
||||
args: []string{"23", "n2d-standard-2"},
|
||||
toComplete: "n2d-stan",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := createGCPCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.resultExpected, result)
|
||||
assert.Equal(tc.shellCDExpected, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
64
cli/cmd/create_test.go
Normal file
64
cli/cmd/create_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCheckDirClean(t *testing.T) {
|
||||
config := config.Default()
|
||||
|
||||
testCases := map[string]struct {
|
||||
fileHandler file.Handler
|
||||
existingFiles []string
|
||||
wantErr bool
|
||||
}{
|
||||
"no file exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
},
|
||||
"adminconf exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{*config.AdminConfPath},
|
||||
wantErr: true,
|
||||
},
|
||||
"master secret exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{*config.MasterSecretPath},
|
||||
wantErr: true,
|
||||
},
|
||||
"state file exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{*config.StatePath},
|
||||
wantErr: true,
|
||||
},
|
||||
"multiple exist": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{*config.AdminConfPath, *config.MasterSecretPath, *config.StatePath},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
for _, f := range tc.existingFiles {
|
||||
require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, false))
|
||||
}
|
||||
|
||||
err := checkDirClean(tc.fileHandler, config)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
17
cli/cmd/ec2client.go
Normal file
17
cli/cmd/ec2client.go
Normal file
@ -0,0 +1,17 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/ec2/client"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type ec2client interface {
|
||||
GetState() (state.ConstellationState, error)
|
||||
SetState(stat state.ConstellationState) error
|
||||
CreateInstances(ctx context.Context, input client.CreateInput) error
|
||||
TerminateInstances(ctx context.Context) error
|
||||
CreateSecurityGroup(ctx context.Context, input client.SecurityGroupInput) error
|
||||
DeleteSecurityGroup(ctx context.Context) error
|
||||
}
|
139
cli/cmd/ec2client_test.go
Normal file
139
cli/cmd/ec2client_test.go
Normal file
@ -0,0 +1,139 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/ec2/client"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type fakeEc2Client struct {
|
||||
instances ec2.Instances
|
||||
securityGroup string
|
||||
ec2state []fakeEc2Instance
|
||||
}
|
||||
|
||||
func (c *fakeEc2Client) GetState() (state.ConstellationState, error) {
|
||||
if len(c.instances) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no instances")
|
||||
}
|
||||
stat := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.AWS.String(),
|
||||
EC2Instances: c.instances,
|
||||
EC2SecurityGroup: c.securityGroup,
|
||||
}
|
||||
for id, instance := range c.instances {
|
||||
instance.PrivateIP = "192.0.2.1"
|
||||
instance.PublicIP = "192.0.2.2"
|
||||
c.instances[id] = instance
|
||||
}
|
||||
return stat, nil
|
||||
}
|
||||
|
||||
func (c *fakeEc2Client) SetState(stat state.ConstellationState) error {
|
||||
if len(stat.EC2Instances) == 0 {
|
||||
return errors.New("state has no instances")
|
||||
}
|
||||
c.instances = stat.EC2Instances
|
||||
c.securityGroup = stat.EC2SecurityGroup
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeEc2Client) CreateInstances(_ context.Context, input client.CreateInput) error {
|
||||
if c.securityGroup == "" {
|
||||
return errors.New("client has no security group")
|
||||
}
|
||||
if c.instances == nil {
|
||||
c.instances = make(ec2.Instances)
|
||||
}
|
||||
for i := 0; i < input.Count; i++ {
|
||||
id := "id-" + strconv.Itoa(len(c.ec2state))
|
||||
c.ec2state = append(c.ec2state, fakeEc2Instance{
|
||||
state: running,
|
||||
instanceID: id,
|
||||
securityGroup: c.securityGroup,
|
||||
tags: input.Tags,
|
||||
})
|
||||
c.instances[id] = ec2.Instance{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeEc2Client) TerminateInstances(_ context.Context) error {
|
||||
if len(c.instances) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, instance := range c.ec2state {
|
||||
instance.state = terminated
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeEc2Client) CreateSecurityGroup(_ context.Context, input client.SecurityGroupInput) error {
|
||||
if c.securityGroup != "" {
|
||||
return errors.New("client already has a security group")
|
||||
}
|
||||
c.securityGroup = "sg-test"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeEc2Client) DeleteSecurityGroup(_ context.Context) error {
|
||||
c.securityGroup = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
type ec2InstanceState int
|
||||
|
||||
const (
|
||||
running = iota
|
||||
terminated
|
||||
)
|
||||
|
||||
type fakeEc2Instance struct {
|
||||
state ec2InstanceState
|
||||
instanceID string
|
||||
tags ec2.Tags
|
||||
securityGroup string
|
||||
}
|
||||
|
||||
type stubEc2Client struct {
|
||||
terminateInstancesCalled bool
|
||||
deleteSecurityGroupCalled bool
|
||||
|
||||
getStateErr error
|
||||
setStateErr error
|
||||
createInstancesErr error
|
||||
terminateInstancesErr error
|
||||
createSecurityGroupErr error
|
||||
deleteSecurityGroupErr error
|
||||
}
|
||||
|
||||
func (c *stubEc2Client) GetState() (state.ConstellationState, error) {
|
||||
return state.ConstellationState{}, c.getStateErr
|
||||
}
|
||||
|
||||
func (c *stubEc2Client) SetState(stat state.ConstellationState) error {
|
||||
return c.setStateErr
|
||||
}
|
||||
|
||||
func (c *stubEc2Client) CreateInstances(_ context.Context, input client.CreateInput) error {
|
||||
return c.createInstancesErr
|
||||
}
|
||||
|
||||
func (c *stubEc2Client) TerminateInstances(_ context.Context) error {
|
||||
c.terminateInstancesCalled = true
|
||||
return c.terminateInstancesErr
|
||||
}
|
||||
|
||||
func (c *stubEc2Client) CreateSecurityGroup(_ context.Context, input client.SecurityGroupInput) error {
|
||||
return c.createSecurityGroupErr
|
||||
}
|
||||
|
||||
func (c *stubEc2Client) DeleteSecurityGroup(_ context.Context) error {
|
||||
c.deleteSecurityGroupCalled = true
|
||||
return c.deleteSecurityGroupErr
|
||||
}
|
22
cli/cmd/gcpclient.go
Normal file
22
cli/cmd/gcpclient.go
Normal file
@ -0,0 +1,22 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/gcp/client"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type gcpclient interface {
|
||||
GetState() (state.ConstellationState, error)
|
||||
SetState(state.ConstellationState) error
|
||||
CreateVPCs(ctx context.Context, input client.VPCsInput) error
|
||||
CreateFirewall(ctx context.Context, input client.FirewallInput) error
|
||||
CreateInstances(ctx context.Context, input client.CreateInstancesInput) error
|
||||
CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error)
|
||||
TerminateFirewall(ctx context.Context) error
|
||||
TerminateVPCs(context.Context) error
|
||||
TerminateInstances(context.Context) error
|
||||
TerminateServiceAccount(ctx context.Context) error
|
||||
Close() error
|
||||
}
|
219
cli/cmd/gcpclient_test.go
Normal file
219
cli/cmd/gcpclient_test.go
Normal file
@ -0,0 +1,219 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/cli/gcp/client"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type fakeGcpClient struct {
|
||||
nodes gcp.Instances
|
||||
coordinators gcp.Instances
|
||||
|
||||
nodesInstanceGroup string
|
||||
coordinatorInstanceGroup string
|
||||
coordinatorTemplate string
|
||||
nodeTemplate string
|
||||
network string
|
||||
subnetwork string
|
||||
firewalls []string
|
||||
project string
|
||||
uid string
|
||||
name string
|
||||
zone string
|
||||
serviceAccount string
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) GetState() (state.ConstellationState, error) {
|
||||
stat := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.GCP.String(),
|
||||
GCPNodes: c.nodes,
|
||||
GCPCoordinators: c.coordinators,
|
||||
GCPNodeInstanceGroup: c.nodesInstanceGroup,
|
||||
GCPCoordinatorInstanceGroup: c.coordinatorInstanceGroup,
|
||||
GCPNodeInstanceTemplate: c.nodeTemplate,
|
||||
GCPCoordinatorInstanceTemplate: c.coordinatorTemplate,
|
||||
GCPNetwork: c.network,
|
||||
GCPSubnetwork: c.subnetwork,
|
||||
GCPFirewalls: c.firewalls,
|
||||
GCPProject: c.project,
|
||||
Name: c.name,
|
||||
UID: c.uid,
|
||||
GCPZone: c.zone,
|
||||
GCPServiceAccount: c.serviceAccount,
|
||||
}
|
||||
return stat, nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) SetState(stat state.ConstellationState) error {
|
||||
c.nodes = stat.GCPNodes
|
||||
c.coordinators = stat.GCPCoordinators
|
||||
c.nodesInstanceGroup = stat.GCPNodeInstanceGroup
|
||||
c.coordinatorInstanceGroup = stat.GCPCoordinatorInstanceGroup
|
||||
c.nodeTemplate = stat.GCPNodeInstanceTemplate
|
||||
c.coordinatorTemplate = stat.GCPCoordinatorInstanceTemplate
|
||||
c.network = stat.GCPNetwork
|
||||
c.subnetwork = stat.GCPSubnetwork
|
||||
c.firewalls = stat.GCPFirewalls
|
||||
c.project = stat.GCPProject
|
||||
c.name = stat.Name
|
||||
c.uid = stat.UID
|
||||
c.zone = stat.GCPZone
|
||||
c.serviceAccount = stat.GCPServiceAccount
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) CreateVPCs(ctx context.Context, input client.VPCsInput) error {
|
||||
c.network = "network"
|
||||
c.subnetwork = "subnetwork"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) CreateFirewall(ctx context.Context, input client.FirewallInput) error {
|
||||
if c.network == "" {
|
||||
return errors.New("client has not network")
|
||||
}
|
||||
var firewalls []string
|
||||
for _, rule := range input.Ingress {
|
||||
firewalls = append(firewalls, rule.Name)
|
||||
}
|
||||
c.firewalls = firewalls
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
|
||||
c.coordinatorInstanceGroup = "coordinator-group"
|
||||
c.nodesInstanceGroup = "nodes-group"
|
||||
c.nodeTemplate = "node-template"
|
||||
c.coordinatorTemplate = "coordinator-template"
|
||||
c.nodes = make(gcp.Instances)
|
||||
for i := 0; i < input.Count-1; i++ {
|
||||
id := "id-" + strconv.Itoa(len(c.nodes))
|
||||
c.nodes[id] = gcp.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
|
||||
}
|
||||
c.coordinators = make(gcp.Instances)
|
||||
c.coordinators["id-c"] = gcp.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error) {
|
||||
c.serviceAccount = "service-account@" + c.project + ".iam.gserviceaccount.com"
|
||||
return client.ServiceAccountKey{
|
||||
Type: "service_account",
|
||||
ProjectID: c.project,
|
||||
PrivateKeyID: "key-id",
|
||||
PrivateKey: "-----BEGIN PRIVATE KEY-----\nprivate-key\n-----END PRIVATE KEY-----\n",
|
||||
ClientEmail: c.serviceAccount,
|
||||
ClientID: "client-id",
|
||||
AuthURI: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURI: "https://accounts.google.com/o/oauth2/token",
|
||||
AuthProviderX509CertURL: "https://www.googleapis.com/oauth2/v1/certs",
|
||||
ClientX509CertURL: "https://www.googleapis.com/robot/v1/metadata/x509/service-account-email",
|
||||
}.ConvertToCloudServiceAccountURI(), nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) TerminateFirewall(ctx context.Context) error {
|
||||
if len(c.firewalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.firewalls = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) TerminateVPCs(context.Context) error {
|
||||
if len(c.firewalls) != 0 {
|
||||
return errors.New("client has firewalls, which must be deleted first")
|
||||
}
|
||||
c.network = ""
|
||||
c.subnetwork = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) TerminateInstances(context.Context) error {
|
||||
c.nodeTemplate = ""
|
||||
c.coordinatorTemplate = ""
|
||||
c.nodesInstanceGroup = ""
|
||||
c.coordinatorInstanceGroup = ""
|
||||
c.nodes = nil
|
||||
c.coordinators = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) TerminateServiceAccount(context.Context) error {
|
||||
c.serviceAccount = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeGcpClient) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubGcpClient struct {
|
||||
terminateFirewallCalled bool
|
||||
terminateInstancesCalled bool
|
||||
terminateVPCsCalled bool
|
||||
|
||||
getStateErr error
|
||||
setStateErr error
|
||||
createVPCsErr error
|
||||
createFirewallErr error
|
||||
createInstancesErr error
|
||||
createServiceAccountErr error
|
||||
terminateFirewallErr error
|
||||
terminateVPCsErr error
|
||||
terminateInstancesErr error
|
||||
terminateServiceAccountErr error
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) GetState() (state.ConstellationState, error) {
|
||||
return state.ConstellationState{}, c.getStateErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) SetState(state.ConstellationState) error {
|
||||
return c.setStateErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) CreateVPCs(ctx context.Context, input client.VPCsInput) error {
|
||||
return c.createVPCsErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) CreateFirewall(ctx context.Context, input client.FirewallInput) error {
|
||||
return c.createFirewallErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
|
||||
return c.createInstancesErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error) {
|
||||
return client.ServiceAccountKey{}.ConvertToCloudServiceAccountURI(), c.createServiceAccountErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) TerminateFirewall(ctx context.Context) error {
|
||||
c.terminateFirewallCalled = true
|
||||
return c.terminateFirewallErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) TerminateVPCs(context.Context) error {
|
||||
c.terminateVPCsCalled = true
|
||||
return c.terminateVPCsErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) TerminateInstances(context.Context) error {
|
||||
c.terminateInstancesCalled = true
|
||||
return c.terminateInstancesErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) TerminateServiceAccount(context.Context) error {
|
||||
return c.terminateServiceAccountErr
|
||||
}
|
||||
|
||||
func (c *stubGcpClient) Close() error {
|
||||
return c.closeErr
|
||||
}
|
484
cli/cmd/init.go
Normal file
484
cli/cmd/init.go
Normal file
@ -0,0 +1,484 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/cli/proto"
|
||||
"github.com/edgelesssys/constellation/cli/status"
|
||||
"github.com/edgelesssys/constellation/cli/vpn"
|
||||
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
func newInitCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Initialize the Constellation. Start your confidential Kubernetes cluster.",
|
||||
Long: "Initialize the Constellation. Start your confidential Kubernetes cluster.",
|
||||
ValidArgsFunction: initCompletion,
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: runInitialize,
|
||||
}
|
||||
cmd.Flags().String("privatekey", "", "path to your private key.")
|
||||
cmd.Flags().String("publickey", "", "path to your public key.")
|
||||
cmd.Flags().String("master-secret", "", "path to base64 encoded master secret.")
|
||||
cmd.Flags().Bool("autoscale", false, "enable kubernetes cluster-autoscaler")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runInitialize runs the initialize command.
|
||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
|
||||
defer protoClient.Close()
|
||||
vpnClient, err := vpn.NewWithDefaults()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We have to parse the context separately, since cmd.Context()
|
||||
// returns nil during the tests otherwise.
|
||||
return initialize(cmd.Context(), cmd, protoClient, vpnClient, serviceAccountClient{}, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs))
|
||||
}
|
||||
|
||||
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
|
||||
// themself activate the other peers as nodes.
|
||||
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpnCl vpnConfigurer, serviceAccountCr serviceAccountCreator,
|
||||
fileHandler file.Handler, config *config.Config, waiter statusWaiter,
|
||||
) error {
|
||||
flagArgs, err := evalFlagArgs(cmd, fileHandler, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var stat state.ConstellationState
|
||||
err = fileHandler.ReadJSON(*config.StatePath, &stat)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("nothing to initialize: %w", err)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch stat.CloudProvider {
|
||||
case "GCP":
|
||||
if err := warnAboutPCRs(cmd, *config.Provider.GCP.PCRs, true); err != nil {
|
||||
return err
|
||||
}
|
||||
case "Azure":
|
||||
if err := warnAboutPCRs(cmd, *config.Provider.Azure.PCRs, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
serviceAccount, stat, err := serviceAccountCr.createServiceAccount(ctx, stat, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileHandler.WriteJSON(*config.StatePath, stat, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
coordinators, nodes, err := getScalingGroupsFromConfig(stat, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
|
||||
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil {
|
||||
return fmt.Errorf("failed to wait for peer status: %w", err)
|
||||
}
|
||||
|
||||
var autoscalingNodeGroups []string
|
||||
if flagArgs.autoscale {
|
||||
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
|
||||
}
|
||||
|
||||
input := activationInput{
|
||||
coordinatorPubIP: coordinators.PublicIPs()[0],
|
||||
pubKey: flagArgs.userPubKey,
|
||||
masterSecret: flagArgs.masterSecret,
|
||||
nodePrivIPs: nodes.PrivateIPs(),
|
||||
autoscalingNodeGroups: autoscalingNodeGroups,
|
||||
cloudServiceAccountURI: serviceAccount,
|
||||
}
|
||||
result, err := activate(ctx, cmd, protCl, input, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = result.writeOutput(cmd.OutOrStdout(), fileHandler, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if flagArgs.autoconfigureWG {
|
||||
if err := configureVpn(vpnCl, result.clientVpnIP, result.coordinatorPubKey, result.coordinatorPubIP, flagArgs.userPrivKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
|
||||
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, ipsToEndpoints(input.nodePrivIPs, *config.CoordinatorPort), input.autoscalingNodeGroups, input.cloudServiceAccountURI)
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
if err := respCl.WriteLogStream(cmd.OutOrStdout()); err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
clientVpnIp, err := respCl.GetClientVpnIp()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
coordinatorPubKey, err := respCl.GetCoordinatorVpnKey()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
kubeconfig, err := respCl.GetKubeconfig()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
ownerID, err := respCl.GetOwnerID()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
clusterID, err := respCl.GetClusterID()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
return activationResult{
|
||||
clientVpnIP: clientVpnIp,
|
||||
coordinatorPubKey: coordinatorPubKey,
|
||||
coordinatorPubIP: input.coordinatorPubIP,
|
||||
kubeconfig: kubeconfig,
|
||||
ownerID: ownerID,
|
||||
clusterID: clusterID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type activationInput struct {
|
||||
coordinatorPubIP string
|
||||
pubKey []byte
|
||||
masterSecret []byte
|
||||
nodePrivIPs []string
|
||||
autoscalingNodeGroups []string
|
||||
cloudServiceAccountURI string
|
||||
}
|
||||
|
||||
type activationResult struct {
|
||||
clientVpnIP string
|
||||
coordinatorPubKey string
|
||||
coordinatorPubIP string
|
||||
kubeconfig string
|
||||
ownerID string
|
||||
clusterID string
|
||||
}
|
||||
|
||||
func (res activationResult) writeOutput(w io.Writer, fileHandler file.Handler, config *config.Config) error {
|
||||
fmt.Fprintln(w, "Your Constellation was successfully initialized.")
|
||||
fmt.Fprintf(w, "Your WireGuard IP is %s\n", res.clientVpnIP)
|
||||
fmt.Fprintf(w, "The Coordinator's public IP is %s\n", res.coordinatorPubIP)
|
||||
fmt.Fprintf(w, "The Coordinator's public key is %s\n", res.coordinatorPubKey)
|
||||
fmt.Fprintf(w, "The Constellation's owner identifier is %s\n", res.ownerID)
|
||||
fmt.Fprintf(w, "The Constellation's unique identifier is %s\n", res.clusterID)
|
||||
if err := fileHandler.Write(*config.AdminConfPath, []byte(res.kubeconfig), false); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "Your Constellation Kubernetes configuration was successfully written to %s\n", *config.AdminConfPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// evalFlagArgs gets the flag values and does preprocessing of these values like
|
||||
// reading the content from file path flags and deriving other values from flag combinations.
|
||||
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler, config *config.Config) (flagArgs, error) {
|
||||
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
}
|
||||
userPublicKeyPath, err := cmd.Flags().GetString("publickey")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
}
|
||||
userPrivKey, userPubKey, err := readVpnKey(fileHandler, userPrivKeyPath, userPublicKeyPath)
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
}
|
||||
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
}
|
||||
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath, config)
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
}
|
||||
autoscale, err := cmd.Flags().GetBool("autoscale")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
}
|
||||
|
||||
return flagArgs{
|
||||
userPrivKey: userPrivKey,
|
||||
userPubKey: userPubKey,
|
||||
autoconfigureWG: userPrivKeyPath != "",
|
||||
autoscale: autoscale,
|
||||
masterSecret: masterSecret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// flagArgs are the resulting values of flag preprocessing.
|
||||
type flagArgs struct {
|
||||
userPrivKey []byte
|
||||
userPubKey []byte
|
||||
masterSecret []byte
|
||||
autoconfigureWG bool
|
||||
autoscale bool
|
||||
}
|
||||
|
||||
func readVpnKey(fileHandler file.Handler, privKeyPath, publicKeyPath string) (privKey, pubKey []byte, err error) {
|
||||
if privKeyPath != "" {
|
||||
privKey, err = fileHandler.Read(privKeyPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
privKeyParsed, err := wgtypes.ParseKey(string(privKey))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pubKey = []byte(privKeyParsed.PublicKey().String())
|
||||
} else if publicKeyPath != "" {
|
||||
pubKey, err = fileHandler.Read(publicKeyPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := checkBase64WGKey(pubKey); err != nil {
|
||||
return nil, nil, fmt.Errorf("wireguard public key is invalid: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, nil, errors.New("neither path to public nor private key provided")
|
||||
}
|
||||
return privKey, pubKey, nil
|
||||
}
|
||||
|
||||
func configureVpn(vpnCl vpnConfigurer, clientVpnIp, coordinatorPubKey, coordinatorPublicIp string, privKey []byte) error {
|
||||
err := vpnCl.Configure(clientVpnIp, coordinatorPubKey, coordinatorPublicIp, string(privKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not configure WireGuard automatically: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ipsToEndpoints(ips []string, port string) []string {
|
||||
var endpoints []string
|
||||
for _, ip := range ips {
|
||||
endpoints = append(endpoints, net.JoinHostPort(ip, port))
|
||||
}
|
||||
return endpoints
|
||||
}
|
||||
|
||||
func checkBase64WGKey(b []byte) error {
|
||||
keyStr, err := base64.StdEncoding.DecodeString(string(b))
|
||||
if err != nil {
|
||||
return errors.New("key can't be decoded")
|
||||
}
|
||||
if length := len(keyStr); length != wireguardKeyLength {
|
||||
return fmt.Errorf("key has invalid length %d", length)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
|
||||
func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string, config *config.Config) ([]byte, error) {
|
||||
if filename != "" {
|
||||
// Try to read the base64 secret from file
|
||||
encodedSecret, err := fileHandler.Read(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(decoded) < masterSecretLengthMin {
|
||||
return nil, errors.New("provided master secret is smaller than the required minimum of 16 Bytes")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// No file given, generate a new secret, and save it to disk
|
||||
masterSecret, err := util.GenerateRandomBytes(masterSecretLengthDefault)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := fileHandler.Write(*config.MasterSecretPath, []byte(base64.StdEncoding.EncodeToString(masterSecret)), false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprintf(w, "Your Constellation master secret was successfully written to ./%s\n", *config.MasterSecretPath)
|
||||
return masterSecret, nil
|
||||
}
|
||||
|
||||
func getScalingGroupsFromConfig(stat state.ConstellationState, config *config.Config) (coordinators, nodes ScalingGroup, err error) {
|
||||
switch {
|
||||
case len(stat.EC2Instances) != 0:
|
||||
return getAWSInstances(stat)
|
||||
case len(stat.GCPCoordinators) != 0:
|
||||
return getGCPInstances(stat, config)
|
||||
case len(stat.AzureCoordinators) != 0:
|
||||
return getAzureInstances(stat)
|
||||
default:
|
||||
return ScalingGroup{}, ScalingGroup{}, errors.New("no instances to init")
|
||||
}
|
||||
}
|
||||
|
||||
func getAWSInstances(stat state.ConstellationState) (coordinators, nodes ScalingGroup, err error) {
|
||||
coordinatorID, coordinator, err := stat.EC2Instances.GetOne()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// GroupID of coordinators is empty, since they currently do not scale.
|
||||
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
|
||||
|
||||
nodeMap := stat.EC2Instances.GetOthers(coordinatorID)
|
||||
if len(nodeMap) == 0 {
|
||||
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
|
||||
}
|
||||
|
||||
var nodeInstances Instances
|
||||
for _, node := range nodeMap {
|
||||
nodeInstances = append(nodeInstances, Instance(node))
|
||||
}
|
||||
|
||||
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
|
||||
// TODO: GroupID of nodes is empty, since they currently do not scale.
|
||||
nodes = ScalingGroup{Instances: nodeInstances, GroupID: ""}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getGCPInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes ScalingGroup, err error) {
|
||||
_, coordinator, err := stat.GCPCoordinators.GetOne()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// GroupID of coordinators is empty, since they currently do not scale.
|
||||
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
|
||||
|
||||
nodeMap := stat.GCPNodes
|
||||
if len(nodeMap) == 0 {
|
||||
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
|
||||
}
|
||||
|
||||
var nodeInstances Instances
|
||||
for _, node := range nodeMap {
|
||||
nodeInstances = append(nodeInstances, Instance(node))
|
||||
}
|
||||
|
||||
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
|
||||
nodes = ScalingGroup{
|
||||
Instances: nodeInstances,
|
||||
GroupID: gcp.AutoscalingNodeGroup(stat.GCPProject, stat.GCPZone, stat.GCPNodeInstanceGroup, *config.AutoscalingNodeGroupsMin, *config.AutoscalingNodeGroupsMax),
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getAzureInstances(stat state.ConstellationState) (coordinators, nodes ScalingGroup, err error) {
|
||||
_, coordinator, err := stat.AzureCoordinators.GetOne()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// GroupID of coordinators is empty, since they currently do not scale.
|
||||
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
|
||||
|
||||
nodeMap := stat.AzureNodes
|
||||
if len(nodeMap) == 0 {
|
||||
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
|
||||
}
|
||||
|
||||
var nodeInstances Instances
|
||||
for _, node := range nodeMap {
|
||||
nodeInstances = append(nodeInstances, Instance(node))
|
||||
}
|
||||
|
||||
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
|
||||
nodes = ScalingGroup{
|
||||
Instances: nodeInstances,
|
||||
GroupID: "",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// initCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
if len(args) != 0 {
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
return []string{}, cobra.ShellCompDirectiveDefault
|
||||
}
|
||||
|
||||
//
|
||||
// TODO: Code below is target of multicloud refactoring.
|
||||
//
|
||||
|
||||
// Instance is a cloud instance.
|
||||
type Instance struct {
|
||||
PublicIP string
|
||||
PrivateIP string
|
||||
}
|
||||
|
||||
type Instances []Instance
|
||||
|
||||
type ScalingGroup struct {
|
||||
Instances
|
||||
GroupID string
|
||||
}
|
||||
|
||||
// PublicIPs returns the public IPs of all the instances.
|
||||
func (i Instances) PublicIPs() []string {
|
||||
var ips []string
|
||||
for _, instance := range i {
|
||||
ips = append(ips, instance.PublicIP)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// PrivateIPs returns the private IPs of all the instances of the Constellation.
|
||||
func (i Instances) PrivateIPs() []string {
|
||||
var ips []string
|
||||
for _, instance := range i {
|
||||
ips = append(ips, instance.PrivateIP)
|
||||
}
|
||||
return ips
|
||||
}
|
686
cli/cmd/init_test.go
Normal file
686
cli/cmd/init_test.go
Normal file
@ -0,0 +1,686 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitArgumentValidation(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
assert.NoError(cmd.ValidateArgs(nil))
|
||||
assert.Error(cmd.ValidateArgs([]string{"something"}))
|
||||
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
|
||||
}
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
config := config.Default()
|
||||
testEc2State := state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-2": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
testGcpState := state.ConstellationState{
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
GCPCoordinators: gcp.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
CloudProvider: "Azure",
|
||||
AzureNodes: azure.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
testActivationResps := []fakeActivationRespMessage{
|
||||
{log: "testlog1"},
|
||||
{log: "testlog2"},
|
||||
{
|
||||
kubeconfig: "kubeconfig",
|
||||
clientVpnIp: "vpnIp",
|
||||
coordinatorVpnKey: "coordKey",
|
||||
ownerID: "ownerID",
|
||||
clusterID: "clusterID",
|
||||
},
|
||||
{log: "testlog3"},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client protoClient
|
||||
serviceAccountCreator stubServiceAccountCreator
|
||||
waiter statusWaiter
|
||||
pubKey string
|
||||
errExpected bool
|
||||
}{
|
||||
"initialize some ec2 instances": {
|
||||
existingState: testEc2State,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances": {
|
||||
existingState: testGcpState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some azure instances": {
|
||||
existingState: testAzureState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"no state exists": {
|
||||
existingState: state.ConstellationState{},
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"no instances to pick one": {
|
||||
existingState: state.ConstellationState{
|
||||
EC2Instances: ec2.Instances{},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
},
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"only one instance": {
|
||||
existingState: state.ConstellationState{
|
||||
EC2Instances: ec2.Instances{"id-1": {}},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
},
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"public key to short": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
|
||||
errExpected: true,
|
||||
},
|
||||
"public key to long": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
|
||||
errExpected: true,
|
||||
},
|
||||
"public key not base64": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: "this is not base64 encoded",
|
||||
errExpected: true,
|
||||
},
|
||||
"fail Connect": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{connectErr: someErr},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail Activate": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{activateErr: someErr},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient WriteLogStream": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getKubeconfig": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getCoordinatorVpnKey": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getClientVpnIp": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getOwnerID": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getClusterID": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail to wait for required status": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{waitForAllErr: someErr},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"fail to create service account": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
serviceAccountCreator: stubServiceAccountCreator{
|
||||
createErr: someErr,
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
|
||||
|
||||
// Write key file to filesystem and set path in flag.
|
||||
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600))
|
||||
require.NoError(cmd.Flags().Set("publickey", "pubKPath"))
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter)
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
require.NoError(err)
|
||||
assert.Contains(out.String(), "vpnIp")
|
||||
assert.Contains(out.String(), "coordKey")
|
||||
assert.Contains(out.String(), "ownerID")
|
||||
assert.Contains(out.String(), "clusterID")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureVPN(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
|
||||
ip := "192.0.2.1"
|
||||
someErr := errors.New("failed")
|
||||
|
||||
configurer := stubVPNConfigurer{}
|
||||
assert.NoError(configureVpn(&configurer, ip, string(key), ip, key))
|
||||
assert.True(configurer.configured)
|
||||
|
||||
configurer = stubVPNConfigurer{configureErr: someErr}
|
||||
assert.Error(configureVpn(&configurer, ip, string(key), ip, key))
|
||||
}
|
||||
|
||||
func TestWriteOutput(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
result := activationResult{
|
||||
clientVpnIP: "foo-qq",
|
||||
coordinatorPubKey: "bar-qq",
|
||||
coordinatorPubIP: "baz-qq",
|
||||
kubeconfig: "foo-bar-baz-qq",
|
||||
}
|
||||
var out bytes.Buffer
|
||||
testFs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(testFs)
|
||||
config := config.Default()
|
||||
|
||||
err := result.writeOutput(&out, fileHandler, config)
|
||||
assert.NoError(err)
|
||||
assert.Contains(out.String(), result.clientVpnIP)
|
||||
assert.Contains(out.String(), result.coordinatorPubIP)
|
||||
assert.Contains(out.String(), result.coordinatorPubKey)
|
||||
|
||||
afs := afero.Afero{Fs: testFs}
|
||||
adminConf, err := afs.ReadFile(*config.AdminConfPath)
|
||||
assert.NoError(err)
|
||||
assert.Equal(result.kubeconfig, string(adminConf))
|
||||
}
|
||||
|
||||
func TestIpsToEndpoints(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ips := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
|
||||
port := "8080"
|
||||
endpoints := ipsToEndpoints(ips, port)
|
||||
assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints)
|
||||
}
|
||||
|
||||
func TestCheckBase64WGKEy(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
|
||||
assert.NoError(checkBase64WGKey(key))
|
||||
key = []byte(base64.StdEncoding.EncodeToString([]byte("shortKey")))
|
||||
assert.Error(checkBase64WGKey(key))
|
||||
key = []byte(base64.StdEncoding.EncodeToString([]byte("looooooooooongKeyWithMoreThan32Bytes")))
|
||||
assert.Error(checkBase64WGKey(key))
|
||||
key = []byte("noBase 64")
|
||||
assert.Error(checkBase64WGKey(key))
|
||||
}
|
||||
|
||||
func TestInitCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
resultExpected []string
|
||||
shellCDExpected cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "hello",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveDefault,
|
||||
},
|
||||
"secnod arg": {
|
||||
args: []string{"23"},
|
||||
toComplete: "/test/h",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
"third arg": {
|
||||
args: []string{"./file", "sth"},
|
||||
toComplete: "./file",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := initCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.resultExpected, result)
|
||||
assert.Equal(tc.shellCDExpected, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadVpnKey(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
testKeyA := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
|
||||
testKeyB := []byte(base64.StdEncoding.EncodeToString([]byte("anotherWireGuardKeyForTheTesting")))
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(fileHandler.Write("testKeyA", testKeyA, false))
|
||||
require.NoError(fileHandler.Write("testKeyB", testKeyB, false))
|
||||
|
||||
// provide privK
|
||||
privK, _, err := readVpnKey(fileHandler, "testKeyA", "")
|
||||
assert.NoError(err)
|
||||
assert.Equal(testKeyA, privK)
|
||||
|
||||
// provide pubK
|
||||
_, pubK, err := readVpnKey(fileHandler, "", "testKeyA")
|
||||
assert.NoError(err)
|
||||
assert.Equal(testKeyA, pubK)
|
||||
|
||||
// provide both, privK should be used, pubK ignored
|
||||
privK, pubK, err = readVpnKey(fileHandler, "testKeyA", "testKeyB")
|
||||
assert.NoError(err)
|
||||
assert.Equal(testKeyA, privK)
|
||||
assert.NotEqual(testKeyB, pubK)
|
||||
|
||||
// no path provided
|
||||
_, _, err = readVpnKey(fileHandler, "", "")
|
||||
assert.Error(err)
|
||||
}
|
||||
|
||||
func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
filename string
|
||||
filecontent string
|
||||
createFile bool
|
||||
fs func() afero.Fs
|
||||
errExpected bool
|
||||
}{
|
||||
"file with secret exists": {
|
||||
filename: "someSecret",
|
||||
filecontent: base64.StdEncoding.EncodeToString([]byte("ConstellationSecret")),
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
errExpected: false,
|
||||
},
|
||||
"no file given": {
|
||||
filename: "",
|
||||
filecontent: "",
|
||||
fs: afero.NewMemMapFs,
|
||||
errExpected: false,
|
||||
},
|
||||
"file does not exist": {
|
||||
filename: "nonExistingSecret",
|
||||
filecontent: "",
|
||||
createFile: false,
|
||||
fs: afero.NewMemMapFs,
|
||||
errExpected: true,
|
||||
},
|
||||
"file is empty": {
|
||||
filename: "emptySecret",
|
||||
filecontent: "",
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
errExpected: true,
|
||||
},
|
||||
"secret too short": {
|
||||
filename: "shortSecret",
|
||||
filecontent: base64.StdEncoding.EncodeToString([]byte("short")),
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
errExpected: true,
|
||||
},
|
||||
"secret not encoded": {
|
||||
filename: "unencodedSecret",
|
||||
filecontent: "Constellation",
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
errExpected: true,
|
||||
},
|
||||
"file not writeable": {
|
||||
filename: "",
|
||||
filecontent: "",
|
||||
createFile: false,
|
||||
fs: func() afero.Fs { return afero.NewReadOnlyFs(afero.NewMemMapFs()) },
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(tc.fs())
|
||||
config := config.Default()
|
||||
|
||||
if tc.createFile {
|
||||
require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), false))
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
secret, err := readOrGeneratedMasterSecret(&out, fileHandler, tc.filename, config)
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
|
||||
if tc.filename == "" {
|
||||
require.Contains(out.String(), *config.MasterSecretPath)
|
||||
filename := strings.Split(out.String(), "./")
|
||||
tc.filename = strings.Trim(filename[1], "\n")
|
||||
}
|
||||
|
||||
content, err := fileHandler.Read(tc.filename)
|
||||
require.NoError(err)
|
||||
assert.Equal(content, []byte(base64.StdEncoding.EncodeToString(secret)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoscaleFlag(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
config := config.Default()
|
||||
testEc2State := state.ConstellationState{
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-2": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
testGcpState := state.ConstellationState{
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
GCPCoordinators: gcp.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
AzureNodes: azure.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
testActivationResps := []fakeActivationRespMessage{
|
||||
{log: "testlog1"},
|
||||
{log: "testlog2"},
|
||||
{
|
||||
kubeconfig: "kubeconfig",
|
||||
clientVpnIp: "vpnIp",
|
||||
coordinatorVpnKey: "coordKey",
|
||||
ownerID: "ownerID",
|
||||
clusterID: "clusterID",
|
||||
},
|
||||
{log: "testlog3"},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
autoscaleFlag bool
|
||||
existingState state.ConstellationState
|
||||
client *stubProtoClient
|
||||
serviceAccountCreator stubServiceAccountCreator
|
||||
waiter statusWaiter
|
||||
pubKey string
|
||||
}{
|
||||
"initialize some ec2 instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some azure instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testAzureState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some ec2 instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
"initialize some azure instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testAzureState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
pubKey: testKey,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
|
||||
|
||||
// Write key file to filesystem and set path in flag.
|
||||
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600))
|
||||
require.NoError(cmd.Flags().Set("publickey", "pubKPath"))
|
||||
|
||||
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter))
|
||||
if tc.autoscaleFlag {
|
||||
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
|
||||
} else {
|
||||
assert.Len(tc.client.activateAutoscalingNodeGroups, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
13
cli/cmd/protoclient.go
Normal file
13
cli/cmd/protoclient.go
Normal file
@ -0,0 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/proto"
|
||||
)
|
||||
|
||||
type protoClient interface {
|
||||
Connect(ip string, port string) error
|
||||
Close() error
|
||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
||||
}
|
188
cli/cmd/protoclient_test.go
Normal file
188
cli/cmd/protoclient_test.go
Normal file
@ -0,0 +1,188 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/proto"
|
||||
)
|
||||
|
||||
type stubProtoClient struct {
|
||||
conn bool
|
||||
respClient proto.ActivationResponseClient
|
||||
connectErr error
|
||||
closeErr error
|
||||
activateErr error
|
||||
|
||||
activateUserPublicKey []byte
|
||||
activateMasterSecret []byte
|
||||
activateEndpoints []string
|
||||
activateAutoscalingNodeGroups []string
|
||||
cloudServiceAccountURI string
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Connect(ip string, port string) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Close() error {
|
||||
c.conn = false
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints []string, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) {
|
||||
c.activateUserPublicKey = userPublicKey
|
||||
c.activateMasterSecret = masterSecret
|
||||
c.activateEndpoints = endpoints
|
||||
c.activateAutoscalingNodeGroups = autoscalingNodeGroups
|
||||
c.cloudServiceAccountURI = cloudServiceAccountURI
|
||||
|
||||
return c.respClient, c.activateErr
|
||||
}
|
||||
|
||||
type stubActivationRespClient struct {
|
||||
nextLogErr *error
|
||||
getKubeconfigErr error
|
||||
getCoordinatorVpnKeyErr error
|
||||
getClientVpnIpErr error
|
||||
getOwnerIDErr error
|
||||
getClusterIDErr error
|
||||
writeLogStreamErr error
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) NextLog() (string, error) {
|
||||
if s.nextLogErr == nil {
|
||||
return "", io.EOF
|
||||
}
|
||||
return "", *s.nextLogErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) WriteLogStream(io.Writer) error {
|
||||
return s.writeLogStreamErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetKubeconfig() (string, error) {
|
||||
return "", s.getKubeconfigErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetCoordinatorVpnKey() (string, error) {
|
||||
return "", s.getCoordinatorVpnKeyErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetClientVpnIp() (string, error) {
|
||||
return "", s.getClientVpnIpErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetOwnerID() (string, error) {
|
||||
return "", s.getOwnerIDErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetClusterID() (string, error) {
|
||||
return "", s.getClusterIDErr
|
||||
}
|
||||
|
||||
type fakeProtoClient struct {
|
||||
conn bool
|
||||
respClient proto.ActivationResponseClient
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Connect(ip string, port string) error {
|
||||
c.conn = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Close() error {
|
||||
c.conn = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints []string, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) {
|
||||
if !c.conn {
|
||||
return nil, errors.New("client is not connected")
|
||||
}
|
||||
return c.respClient, nil
|
||||
}
|
||||
|
||||
type fakeActivationRespClient struct {
|
||||
responses []fakeActivationRespMessage
|
||||
kubeconfig string
|
||||
coordinatorVpnKey string
|
||||
clientVpnIp string
|
||||
ownerID string
|
||||
clusterID string
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) NextLog() (string, error) {
|
||||
for len(c.responses) > 0 {
|
||||
resp := c.responses[0]
|
||||
c.responses = c.responses[1:]
|
||||
if len(resp.log) > 0 {
|
||||
return resp.log, nil
|
||||
}
|
||||
c.kubeconfig = resp.kubeconfig
|
||||
c.coordinatorVpnKey = resp.coordinatorVpnKey
|
||||
c.clientVpnIp = resp.clientVpnIp
|
||||
c.ownerID = resp.ownerID
|
||||
c.clusterID = resp.clusterID
|
||||
}
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) WriteLogStream(w io.Writer) error {
|
||||
log, err := c.NextLog()
|
||||
for err == nil {
|
||||
fmt.Fprint(w, log)
|
||||
log, err = c.NextLog()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetKubeconfig() (string, error) {
|
||||
if c.kubeconfig == "" {
|
||||
return "", errors.New("kubeconfig is empty")
|
||||
}
|
||||
return c.kubeconfig, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetCoordinatorVpnKey() (string, error) {
|
||||
if c.coordinatorVpnKey == "" {
|
||||
return "", errors.New("coordinator public VPN key is empty")
|
||||
}
|
||||
return c.coordinatorVpnKey, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetClientVpnIp() (string, error) {
|
||||
if c.clientVpnIp == "" {
|
||||
return "", errors.New("client VPN IP is empty")
|
||||
}
|
||||
return c.clientVpnIp, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetOwnerID() (string, error) {
|
||||
if c.ownerID == "" {
|
||||
return "", errors.New("init secret is empty")
|
||||
}
|
||||
return c.ownerID, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetClusterID() (string, error) {
|
||||
if c.clusterID == "" {
|
||||
return "", errors.New("cluster identifier is empty")
|
||||
}
|
||||
return c.clusterID, nil
|
||||
}
|
||||
|
||||
type fakeActivationRespMessage struct {
|
||||
log string
|
||||
kubeconfig string
|
||||
coordinatorVpnKey string
|
||||
clientVpnIp string
|
||||
ownerID string
|
||||
clusterID string
|
||||
}
|
60
cli/cmd/rollback.go
Normal file
60
cli/cmd/rollback.go
Normal file
@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
// rollbacker does a rollback.
|
||||
type rollbacker interface {
|
||||
rollback(ctx context.Context) error
|
||||
}
|
||||
|
||||
// rollbackOnError calls rollback on the rollbacker if the handed error is not nil,
|
||||
// and writes logs to the writer w.
|
||||
func rollbackOnError(ctx context.Context, w io.Writer, onErr *error, roll rollbacker) {
|
||||
if *onErr == nil {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "An error occurred: %s\n", *onErr)
|
||||
fmt.Fprintln(w, "Attempting to roll back.")
|
||||
if err := roll.rollback(ctx); err != nil {
|
||||
*onErr = multierr.Append(*onErr, fmt.Errorf("on rollback: %w", err)) // TODO: print the error, or retrun it?
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w, "Rollback succeeded.")
|
||||
}
|
||||
|
||||
type rollbackerGCP struct {
|
||||
client gcpclient
|
||||
}
|
||||
|
||||
func (r *rollbackerGCP) rollback(ctx context.Context) error {
|
||||
var err error
|
||||
err = multierr.Append(err, r.client.TerminateInstances(ctx))
|
||||
err = multierr.Append(err, r.client.TerminateFirewall(ctx))
|
||||
err = multierr.Append(err, r.client.TerminateVPCs(ctx))
|
||||
return err
|
||||
}
|
||||
|
||||
type rollbackerAzure struct {
|
||||
client azureclient
|
||||
}
|
||||
|
||||
func (r *rollbackerAzure) rollback(ctx context.Context) error {
|
||||
return r.client.TerminateResourceGroup(ctx)
|
||||
}
|
||||
|
||||
type rollbackerAWS struct {
|
||||
client ec2client
|
||||
}
|
||||
|
||||
func (r *rollbackerAWS) rollback(ctx context.Context) error {
|
||||
var err error
|
||||
err = multierr.Append(err, r.client.TerminateInstances(ctx))
|
||||
err = multierr.Append(err, r.client.DeleteSecurityGroup(ctx))
|
||||
return err
|
||||
}
|
64
cli/cmd/root.go
Normal file
64
cli/cmd/root.go
Normal file
@ -0,0 +1,64 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "constellation",
|
||||
Short: "Set up your Constellation cluster.",
|
||||
Long: "Set up your Constellation cluster.",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
// Execute starts the CLI.
|
||||
func Execute() error {
|
||||
ctx, cancel := signalContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
return rootCmd.ExecuteContext(ctx)
|
||||
}
|
||||
|
||||
// signalContext returns a context that is canceled on the handed signal.
|
||||
// The signal isn't watched after its first occurrence. Call the cancel
|
||||
// function to ensure the internal goroutine is stopped and the signal isn't
|
||||
// watched any longer.
|
||||
func signalContext(ctx context.Context, sig os.Signal) (context.Context, context.CancelFunc) {
|
||||
sigCtx, stop := signal.NotifyContext(ctx, sig)
|
||||
done := make(chan struct{}, 1)
|
||||
stopDone := make(chan struct{}, 1)
|
||||
|
||||
go func() {
|
||||
defer func() { stopDone <- struct{}{} }()
|
||||
defer stop()
|
||||
select {
|
||||
case <-sigCtx.Done():
|
||||
fmt.Println(" Signal caught. Press ctrl+c again to terminate the program immediately.")
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
cancelFunc := func() {
|
||||
done <- struct{}{}
|
||||
<-stopDone
|
||||
}
|
||||
|
||||
return sigCtx, cancelFunc
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Set output of cmd.Print to stdout.
|
||||
rootCmd.SetOut(os.Stdout)
|
||||
// Disable --no-description flag for completion command.
|
||||
rootCmd.CompletionOptions.DisableNoDescFlag = true
|
||||
rootCmd.PersistentFlags().String("dev-config", "", "Set this flag to create the Constellation using settings from a development config.")
|
||||
rootCmd.AddCommand(newVersionCmd())
|
||||
rootCmd.AddCommand(newCreateCmd())
|
||||
rootCmd.AddCommand(newInitCmd())
|
||||
rootCmd.AddCommand(newTerminateCmd())
|
||||
rootCmd.AddCommand(newVerifyCmd())
|
||||
}
|
95
cli/cmd/serviceaccountcreator.go
Normal file
95
cli/cmd/serviceaccountcreator.go
Normal file
@ -0,0 +1,95 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
azurecl "github.com/edgelesssys/constellation/cli/azure/client"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
ec2cl "github.com/edgelesssys/constellation/cli/ec2/client"
|
||||
gcpcl "github.com/edgelesssys/constellation/cli/gcp/client"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type serviceAccountCreator interface {
|
||||
createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error)
|
||||
}
|
||||
|
||||
type serviceAccountClient struct{}
|
||||
|
||||
// createServiceAccount creates a new cloud provider service account with access to the created resources.
|
||||
func (c serviceAccountClient) createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
||||
switch stat.CloudProvider {
|
||||
case cloudprovider.AWS.String():
|
||||
// TODO: implement
|
||||
ec2client, err := ec2cl.NewFromDefault(ctx)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, err
|
||||
}
|
||||
return c.createServiceAccountEC2(ctx, ec2client, stat, config)
|
||||
case cloudprovider.GCP.String():
|
||||
gcpclient, err := gcpcl.NewFromDefault(ctx)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, err
|
||||
}
|
||||
serviceAccount, stat, err := c.createServiceAccountGCP(ctx, gcpclient, stat, config)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, err
|
||||
}
|
||||
return serviceAccount, stat, gcpclient.Close()
|
||||
case cloudprovider.Azure.String():
|
||||
azureclient, err := azurecl.NewFromDefault(stat.AzureSubscription, stat.AzureTenant)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, err
|
||||
}
|
||||
return c.createServiceAccountAzure(ctx, azureclient, stat)
|
||||
}
|
||||
|
||||
return "", state.ConstellationState{}, fmt.Errorf("unknown cloud provider %v", stat.CloudProvider)
|
||||
}
|
||||
|
||||
func (c serviceAccountClient) createServiceAccountAzure(ctx context.Context, cl azureclient, stat state.ConstellationState) (string, state.ConstellationState, error) {
|
||||
if err := cl.SetState(stat); err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
|
||||
}
|
||||
serviceAccount, err := cl.CreateServicePrincipal(ctx)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to create service account: %w", err)
|
||||
}
|
||||
|
||||
stat, err = cl.GetState()
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to get state after creating service account: %w", err)
|
||||
}
|
||||
return serviceAccount, stat, nil
|
||||
}
|
||||
|
||||
func (c serviceAccountClient) createServiceAccountGCP(ctx context.Context, cl gcpclient, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
||||
if err := cl.SetState(stat); err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
|
||||
}
|
||||
|
||||
input := gcpcl.ServiceAccountInput{
|
||||
Roles: *config.Provider.GCP.ServiceAccountRoles,
|
||||
}
|
||||
serviceAccount, err := cl.CreateServiceAccount(ctx, input)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to create service account: %w", err)
|
||||
}
|
||||
|
||||
stat, err = cl.GetState()
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to get state after creating service account: %w", err)
|
||||
}
|
||||
return serviceAccount, stat, nil
|
||||
}
|
||||
|
||||
//nolint:unparam
|
||||
func (c serviceAccountClient) createServiceAccountEC2(ctx context.Context, cl ec2client, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
||||
// TODO: implement
|
||||
if err := cl.SetState(stat); err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
|
||||
}
|
||||
return "", stat, nil
|
||||
}
|
136
cli/cmd/serviceaccountcreator_test.go
Normal file
136
cli/cmd/serviceaccountcreator_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateServiceAccountAzure(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client azureclient
|
||||
errExpected bool
|
||||
}{
|
||||
"create service account works": {
|
||||
existingState: testState,
|
||||
client: &fakeAzureClient{},
|
||||
},
|
||||
"fail setState": {
|
||||
existingState: testState,
|
||||
client: &stubAzureClient{setStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail create": {
|
||||
existingState: testState,
|
||||
client: &stubAzureClient{createServicePrincipalErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
client := serviceAccountClient{}
|
||||
serviceAccount, _, err := client.createServiceAccountAzure(context.Background(), tc.client, tc.existingState)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.NotNil(serviceAccount)
|
||||
stat, err := tc.client.GetState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
|
||||
}, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateServiceAccountGCP(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
GCPProject: "project",
|
||||
GCPNodes: gcp.Instances{},
|
||||
GCPCoordinators: gcp.Instances{},
|
||||
GCPNodeInstanceGroup: "nodes-group",
|
||||
GCPCoordinatorInstanceGroup: "coordinator-group",
|
||||
GCPNodeInstanceTemplate: "template",
|
||||
GCPCoordinatorInstanceTemplate: "template",
|
||||
GCPNetwork: "network",
|
||||
GCPFirewalls: []string{},
|
||||
}
|
||||
config := config.Default()
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client gcpclient
|
||||
errExpected bool
|
||||
}{
|
||||
"create service account works": {
|
||||
existingState: testState,
|
||||
client: &fakeGcpClient{},
|
||||
},
|
||||
"fail setState": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{setStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail create": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{createServiceAccountErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
client := serviceAccountClient{}
|
||||
serviceAccount, _, err := client.createServiceAccountGCP(context.Background(), tc.client, tc.existingState, config)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.NotNil(serviceAccount)
|
||||
stat, err := tc.client.GetState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(state.ConstellationState{
|
||||
CloudProvider: cloudprovider.GCP.String(),
|
||||
GCPProject: "project",
|
||||
GCPNodes: gcp.Instances{},
|
||||
GCPCoordinators: gcp.Instances{},
|
||||
GCPNodeInstanceGroup: "nodes-group",
|
||||
GCPCoordinatorInstanceGroup: "coordinator-group",
|
||||
GCPNodeInstanceTemplate: "template",
|
||||
GCPCoordinatorInstanceTemplate: "template",
|
||||
GCPNetwork: "network",
|
||||
GCPFirewalls: []string{},
|
||||
GCPServiceAccount: "service-account@project.iam.gserviceaccount.com",
|
||||
}, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubServiceAccountCreator struct {
|
||||
cloudServiceAccountURI string
|
||||
createErr error
|
||||
}
|
||||
|
||||
func (c *stubServiceAccountCreator) createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
||||
return c.cloudServiceAccountURI, stat, c.createErr
|
||||
}
|
11
cli/cmd/statuswaiter.go
Normal file
11
cli/cmd/statuswaiter.go
Normal file
@ -0,0 +1,11 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type statusWaiter interface {
|
||||
WaitForAll(ctx context.Context, status state.State, endpoints []string) error
|
||||
}
|
15
cli/cmd/statuswaiter_test.go
Normal file
15
cli/cmd/statuswaiter_test.go
Normal file
@ -0,0 +1,15 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type stubStatusWaiter struct {
|
||||
waitForAllErr error
|
||||
}
|
||||
|
||||
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
|
||||
return w.waitForAllErr
|
||||
}
|
131
cli/cmd/terminate.go
Normal file
131
cli/cmd/terminate.go
Normal file
@ -0,0 +1,131 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
azure "github.com/edgelesssys/constellation/cli/azure/client"
|
||||
ec2 "github.com/edgelesssys/constellation/cli/ec2/client"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
gcp "github.com/edgelesssys/constellation/cli/gcp/client"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
func newTerminateCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "terminate",
|
||||
Short: "Terminate an existing Constellation.",
|
||||
Long: "Terminate an existing Constellation. The Constellation can't be started again, and all persistent storage will be lost.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runTerminate,
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runTerminate runs the terminate command.
|
||||
func runTerminate(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return terminate(cmd, fileHandler, config)
|
||||
}
|
||||
|
||||
func terminate(cmd *cobra.Command, fileHandler file.Handler, config *config.Config) error {
|
||||
var stat state.ConstellationState
|
||||
if err := fileHandler.ReadJSON(*config.StatePath, &stat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(stat.EC2Instances) != 0 || stat.EC2SecurityGroup != "" {
|
||||
ec2client, err := ec2.NewFromDefault(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := terminateEC2(cmd, ec2client, stat); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO: improve check, also look for other resources that might need to be terminated
|
||||
if len(stat.GCPNodes) != 0 {
|
||||
gcpclient, err := gcp.NewFromDefault(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := terminateGCP(cmd, gcpclient, stat); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(stat.AzureResourceGroup) != 0 {
|
||||
azureclient, err := azure.NewFromDefault(stat.AzureSubscription, stat.AzureTenant)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := terminateAzure(cmd, azureclient, stat); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Your Constellation was terminated successfully.")
|
||||
|
||||
if err := fileHandler.Remove(*config.StatePath); err != nil {
|
||||
return fmt.Errorf("failed to remove file '%s', please remove manually", *config.StatePath)
|
||||
}
|
||||
|
||||
if err := fileHandler.Remove(*config.AdminConfPath); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("failed to remove file '%s', please remove manually", *config.AdminConfPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func terminateAzure(cmd *cobra.Command, cl azureclient, stat state.ConstellationState) error {
|
||||
if err := cl.SetState(stat); err != nil {
|
||||
return fmt.Errorf("failed to terminate the Constellation: %w", err)
|
||||
}
|
||||
|
||||
if err := cl.TerminateServicePrincipal(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
return cl.TerminateResourceGroup(cmd.Context())
|
||||
}
|
||||
|
||||
func terminateGCP(cmd *cobra.Command, cl gcpclient, stat state.ConstellationState) error {
|
||||
if err := cl.SetState(stat); err != nil {
|
||||
return fmt.Errorf("failed to terminate the Constellation: %w", err)
|
||||
}
|
||||
|
||||
if err := cl.TerminateInstances(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.TerminateFirewall(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cl.TerminateVPCs(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
return cl.TerminateServiceAccount(cmd.Context())
|
||||
}
|
||||
|
||||
// terminateEC2 and remove the existing Constellation form the state file.
|
||||
func terminateEC2(cmd *cobra.Command, cl ec2client, stat state.ConstellationState) error {
|
||||
if err := cl.SetState(stat); err != nil {
|
||||
return fmt.Errorf("failed to terminate the Constellation: %w", err)
|
||||
}
|
||||
|
||||
if err := cl.TerminateInstances(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("failed to terminate the Constellation: %w", err)
|
||||
}
|
||||
|
||||
return cl.DeleteSecurityGroup(cmd.Context())
|
||||
}
|
288
cli/cmd/terminate_test.go
Normal file
288
cli/cmd/terminate_test.go
Normal file
@ -0,0 +1,288 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTerminateCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"no args": {[]string{}, false},
|
||||
"some args": {[]string{"hello", "test"}, true},
|
||||
"some other args": {[]string{"12", "2"}, true},
|
||||
}
|
||||
|
||||
cmd := newTerminateCmd()
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateEC2(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.AWS.String(),
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-3": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client ec2client
|
||||
errExpected bool
|
||||
}{
|
||||
"terminate existing instances": {
|
||||
existingState: testState,
|
||||
client: &fakeEc2Client{},
|
||||
errExpected: false,
|
||||
},
|
||||
"state without instances": {
|
||||
existingState: state.ConstellationState{
|
||||
CloudProvider: cloudprovider.AWS.String(),
|
||||
EC2Instances: ec2.Instances{},
|
||||
},
|
||||
client: &fakeEc2Client{},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail TerminateInstances": {
|
||||
existingState: testState,
|
||||
client: &stubEc2Client{terminateInstancesErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail DeleteSecurityGroup": {
|
||||
existingState: testState,
|
||||
client: &stubEc2Client{deleteSecurityGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newTerminateCmd()
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
|
||||
err := terminateEC2(cmd, tc.client, tc.existingState)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateGCP(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
GCPCoordinators: gcp.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
GCPNodeInstanceGroup: "nodes-group",
|
||||
GCPCoordinatorInstanceGroup: "coordinator-group",
|
||||
GCPNodeInstanceTemplate: "template",
|
||||
GCPCoordinatorInstanceTemplate: "template",
|
||||
GCPNetwork: "network",
|
||||
GCPFirewalls: []string{"coordinator", "wireguard", "ssh"},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client gcpclient
|
||||
errExpected bool
|
||||
}{
|
||||
"terminate existing instances": {
|
||||
existingState: testState,
|
||||
client: &fakeGcpClient{},
|
||||
},
|
||||
"state without instances": {
|
||||
existingState: state.ConstellationState{EC2Instances: ec2.Instances{}},
|
||||
client: &fakeGcpClient{},
|
||||
},
|
||||
"state not found": {
|
||||
existingState: testState,
|
||||
client: &fakeGcpClient{},
|
||||
},
|
||||
"fail setState": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{setStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail terminateFirewall": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{terminateFirewallErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail terminateVPC": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{terminateVPCsErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail terminateInstances": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{terminateInstancesErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail terminateServiceAccount": {
|
||||
existingState: testState,
|
||||
client: &stubGcpClient{terminateServiceAccountErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newTerminateCmd()
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
|
||||
err := terminateGCP(cmd, tc.client, tc.existingState)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
stat, err := tc.client.GetState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(state.ConstellationState{
|
||||
CloudProvider: cloudprovider.GCP.String(),
|
||||
}, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateAzure(t *testing.T) {
|
||||
testState := state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
AzureNodes: azure.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client azureclient
|
||||
errExpected bool
|
||||
}{
|
||||
"terminate existing instances": {
|
||||
existingState: testState,
|
||||
client: &fakeAzureClient{},
|
||||
},
|
||||
"state resource group": {
|
||||
existingState: state.ConstellationState{AzureResourceGroup: ""},
|
||||
client: &fakeAzureClient{},
|
||||
},
|
||||
"state not found": {
|
||||
existingState: testState,
|
||||
client: &fakeAzureClient{},
|
||||
},
|
||||
"fail setState": {
|
||||
existingState: testState,
|
||||
client: &stubAzureClient{setStateErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail resource group termination": {
|
||||
existingState: testState,
|
||||
client: &stubAzureClient{terminateResourceGroupErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"fail service principal termination": {
|
||||
existingState: testState,
|
||||
client: &stubAzureClient{terminateServicePrincipalErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newTerminateCmd()
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
|
||||
err := terminateAzure(cmd, tc.client, tc.existingState)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
stat, err := tc.client.GetState()
|
||||
assert.NoError(err)
|
||||
assert.Equal(state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
}, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
80
cli/cmd/userinteraction.go
Normal file
80
cli/cmd/userinteraction.go
Normal file
@ -0,0 +1,80 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidInput is an error where user entered invalid input.
|
||||
ErrInvalidInput = errors.New("user made invalid input")
|
||||
warningStr = "Warning: not verifying the Constellation's %s measurements\n"
|
||||
)
|
||||
|
||||
// askToConfirm asks user to confirm an action.
|
||||
// The user will be asked the handed question and can answer with
|
||||
// yes or no.
|
||||
func askToConfirm(cmd *cobra.Command, question string) (bool, error) {
|
||||
reader := bufio.NewReader(cmd.InOrStdin())
|
||||
cmd.Printf("%s [y/n]: ", question)
|
||||
for i := 0; i < 3; i++ {
|
||||
resp, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp = strings.ToLower(strings.TrimSpace(resp))
|
||||
if resp == "n" || resp == "no" {
|
||||
return false, nil
|
||||
}
|
||||
if resp == "y" || resp == "yes" {
|
||||
return true, nil
|
||||
}
|
||||
cmd.Printf("Type 'y' or 'yes' to confirm, or abort action with 'n' or 'no': ")
|
||||
}
|
||||
return false, ErrInvalidInput
|
||||
}
|
||||
|
||||
// warnAboutPCRs displays warnings if specifc PCR values are not verified.
|
||||
//
|
||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||
func warnAboutPCRs(cmd *cobra.Command, pcrs map[uint32][]byte, checkInit bool) error {
|
||||
for k, v := range pcrs {
|
||||
if len(v) != 32 {
|
||||
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
|
||||
}
|
||||
}
|
||||
|
||||
if pcrs[0] == nil || pcrs[1] == nil {
|
||||
cmd.PrintErrf(warningStr, "BIOS")
|
||||
}
|
||||
|
||||
if pcrs[2] == nil || pcrs[3] == nil {
|
||||
cmd.PrintErrf(warningStr, "OPROM")
|
||||
}
|
||||
|
||||
if pcrs[4] == nil || pcrs[5] == nil {
|
||||
cmd.PrintErrf(warningStr, "MBR")
|
||||
}
|
||||
|
||||
// GRUB measures kernel command line and initrd into pcrs 8 and 9
|
||||
if pcrs[8] == nil {
|
||||
cmd.PrintErrf(warningStr, "kernel command line")
|
||||
}
|
||||
if pcrs[9] == nil {
|
||||
cmd.PrintErrf(warningStr, "initrd")
|
||||
}
|
||||
|
||||
// Only warn about initialization PCRs if necessary
|
||||
if checkInit {
|
||||
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
||||
cmd.PrintErrf(warningStr, "initialization status")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
249
cli/cmd/userinteraction_test.go
Normal file
249
cli/cmd/userinteraction_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAskToConfirm(t *testing.T) {
|
||||
// errAborted is an error where the user aborted the action.
|
||||
errAborted := errors.New("user aborted")
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ok, err := askToConfirm(cmd, "777")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return errAborted
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
input string
|
||||
expectedErr error
|
||||
}{
|
||||
"user confirms": {"y\n", nil},
|
||||
"user confirms long": {"yes\n", nil},
|
||||
"user disagrees": {"n\n", errAborted},
|
||||
"user disagrees long": {"no\n", errAborted},
|
||||
"user is first unsure, but agrees": {"what?\ny\n", nil},
|
||||
"user is first unsure, but disagrees": {"wait.\nn\n", errAborted},
|
||||
"repeated invalid input": {"h\nb\nq\n", ErrInvalidInput},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
out := bytes.NewBufferString("")
|
||||
cmd.SetOut(out)
|
||||
errOut := bytes.NewBufferString("")
|
||||
cmd.SetErr(errOut)
|
||||
in := bytes.NewBufferString(tc.input)
|
||||
cmd.SetIn(in)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.ErrorIs(err, tc.expectedErr)
|
||||
|
||||
output, err := io.ReadAll(out)
|
||||
assert.NoError(err)
|
||||
assert.Contains(string(output), "777")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarnAboutPCRs(t *testing.T) {
|
||||
zero := []byte("00000000000000000000000000000000")
|
||||
|
||||
testCases := map[string]struct {
|
||||
pcrs map[uint32][]byte
|
||||
dontWarnInit bool
|
||||
expectedWarnings []string
|
||||
errExpected bool
|
||||
}{
|
||||
"no warnings": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
6: zero,
|
||||
7: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
10: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
},
|
||||
"no warnings for missing non critical values": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
},
|
||||
"warn for BIOS": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
expectedWarnings: []string{"BIOS"},
|
||||
},
|
||||
"warn for OPROM": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
expectedWarnings: []string{"OPROM"},
|
||||
},
|
||||
"warn for MBR": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
expectedWarnings: []string{"MBR"},
|
||||
},
|
||||
"warn for kernel": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
expectedWarnings: []string{"kernel"},
|
||||
},
|
||||
"warn for initrd": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
expectedWarnings: []string{"initrd"},
|
||||
},
|
||||
"warn for initialization": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
},
|
||||
dontWarnInit: false,
|
||||
expectedWarnings: []string{"initialization"},
|
||||
},
|
||||
"don't warn for initialization": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
},
|
||||
dontWarnInit: true,
|
||||
},
|
||||
"multi warning": {
|
||||
pcrs: map[uint32][]byte{},
|
||||
expectedWarnings: []string{
|
||||
"BIOS",
|
||||
"OPROM",
|
||||
"MBR",
|
||||
"initialization",
|
||||
"initrd",
|
||||
"kernel",
|
||||
},
|
||||
},
|
||||
"bad config": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: []byte("000"),
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
|
||||
err := warnAboutPCRs(cmd, tc.pcrs, !tc.dontWarnInit)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
if len(tc.expectedWarnings) == 0 {
|
||||
assert.Empty(errOut.String())
|
||||
} else {
|
||||
for _, warning := range tc.expectedWarnings {
|
||||
assert.Contains(errOut.String(), warning)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
70
cli/cmd/validargs.go
Normal file
70
cli/cmd/validargs.go
Normal file
@ -0,0 +1,70 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// isIntArg checks if argument at position arg is an integer.
|
||||
func isIntArg(arg int) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
if _, err := strconv.Atoi(args[arg]); err != nil {
|
||||
return fmt.Errorf("argument %d must be an integer", arg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// isIntGreaterArg checks if argument at position arg is and integer and greater i.
|
||||
func isIntGreaterArg(arg int, i int) cobra.PositionalArgs {
|
||||
return cobra.MatchAll(isIntArg(arg), func(cmd *cobra.Command, args []string) error {
|
||||
if v, _ := strconv.Atoi(args[arg]); v <= i {
|
||||
return fmt.Errorf("argument %d must be greater %d, but it's %d", arg, i, v)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// isIntGreaterZeroArg checks if argument at position arg is a positive non zero integer.
|
||||
func isIntGreaterZeroArg(arg int) cobra.PositionalArgs {
|
||||
return cobra.MatchAll(isIntGreaterArg(arg, 0))
|
||||
}
|
||||
|
||||
// isEC2InstanceType checks if argument at position arg is a key in m.
|
||||
// The argument will always be converted to lower case letters.
|
||||
func isEC2InstanceType(arg int) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
if _, ok := ec2.InstanceTypes[strings.ToLower(args[arg])]; !ok {
|
||||
return fmt.Errorf("'%s' isn't an AWS EC2 instance type", args[arg])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func isGCPInstanceType(arg int) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
for _, instanceType := range gcp.InstanceTypes {
|
||||
if args[arg] == instanceType {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("argument %s isn't a valid GCP instance type", args[arg])
|
||||
}
|
||||
}
|
||||
|
||||
func isAzureInstanceType(arg int) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
for _, instanceType := range azure.InstanceTypes {
|
||||
if args[arg] == instanceType {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("argument %s isn't a valid Azure instance type", args[arg])
|
||||
}
|
||||
}
|
197
cli/cmd/validargs_test.go
Normal file
197
cli/cmd/validargs_test.go
Normal file
@ -0,0 +1,197 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsIntArg(t *testing.T) {
|
||||
testCmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: isIntArg(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"valid int 1": {[]string{"1"}, false},
|
||||
"valid int 2": {[]string{"42"}, false},
|
||||
"valid int 3": {[]string{"987987498"}, false},
|
||||
"valid int and other args": {[]string{"3", "hello"}, false},
|
||||
"valid int and other args 2": {[]string{"3", "4"}, false},
|
||||
"invalid 1": {[]string{"hello world"}, true},
|
||||
"invalid 2": {[]string{"98798d749f8"}, true},
|
||||
"invalid 3": {[]string{"three"}, true},
|
||||
"invalid 4": {[]string{"0.3"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsIntGreaterArg(t *testing.T) {
|
||||
testCmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: isIntGreaterArg(0, 12),
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"valid int 1": {[]string{"13"}, false},
|
||||
"valid int 2": {[]string{"42"}, false},
|
||||
"valid int 3": {[]string{"987987498"}, false},
|
||||
"invalid int 1": {[]string{"1"}, true},
|
||||
"invalid int and other args": {[]string{"3", "hello"}, true},
|
||||
"invalid int and other args 2": {[]string{"-14", "4"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsIntGreaterZeroArg(t *testing.T) {
|
||||
testCmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: isIntGreaterZeroArg(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"valid int 1": {[]string{"13"}, false},
|
||||
"valid int 2": {[]string{"42"}, false},
|
||||
"valid int 3": {[]string{"987987498"}, false},
|
||||
"invalid": {[]string{"0"}, true},
|
||||
"invalid int 1": {[]string{"-42", "hello"}, true},
|
||||
"invalid int and other args": {[]string{"-9487239847", "4"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEC2InstanceType(t *testing.T) {
|
||||
testCmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: isEC2InstanceType(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"is instance type 1": {[]string{"4xl"}, false},
|
||||
"is instance type 2": {[]string{"12xlarge", "something else"}, false},
|
||||
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
|
||||
"isn't instance type 2": {[]string{"Hello World!"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGCPInstanceType(t *testing.T) {
|
||||
testCmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: isGCPInstanceType(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"is instance type 1": {[]string{"n2d-standard-4"}, false},
|
||||
"is instance type 2": {[]string{"n2d-standard-16", "something else"}, false},
|
||||
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
|
||||
"isn't instance type 2": {[]string{"Hello World!"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAzureInstanceType(t *testing.T) {
|
||||
testCmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: isAzureInstanceType(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
expectErr bool
|
||||
}{
|
||||
"is instance type 1": {[]string{"Standard_DC2as_v5"}, false},
|
||||
"is instance type 2": {[]string{"Standard_DC8as_v5", "something else"}, false},
|
||||
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
|
||||
"isn't instance type 2": {[]string{"Hello World!"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
if tc.expectErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
141
cli/cmd/verify.go
Normal file
141
cli/cmd/verify.go
Normal file
@ -0,0 +1,141 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/status"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
rpcStatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func newVerifyCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "verify azure|gcp",
|
||||
Short: "Verify the confidential properties of your Constellation.",
|
||||
Long: "Verify the confidential properties of your Constellation.",
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().String("owner-id", "", "verify the Constellation using the owner identity derived from the master secret.")
|
||||
cmd.PersistentFlags().String("unique-id", "", "verify the Constellation using the unique cluster identity.")
|
||||
|
||||
cmd.AddCommand(newVerifyGCPCmd())
|
||||
cmd.AddCommand(newVerifyAzureCmd())
|
||||
cmd.AddCommand(newVerifyGCPNonCVMCmd())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runVerify(cmd *cobra.Command, args []string, pcrs map[uint32][]byte, validator atls.Validator) error {
|
||||
if err := warnAboutPCRs(cmd, pcrs, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
verifier := verifier{
|
||||
newConn: newVerifiedConn,
|
||||
newClient: pubproto.NewAPIClient,
|
||||
}
|
||||
return verify(cmd.Context(), cmd.OutOrStdout(), net.JoinHostPort(args[0], args[1]), []atls.Validator{validator}, verifier)
|
||||
}
|
||||
|
||||
func verify(ctx context.Context, w io.Writer, target string, validators []atls.Validator, verifier verifier) error {
|
||||
conn, err := verifier.newConn(ctx, target, validators)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := verifier.newClient(conn)
|
||||
|
||||
if _, err := client.GetState(ctx, &pubproto.GetStateRequest{}); err != nil {
|
||||
if err, ok := rpcStatus.FromError(err); ok {
|
||||
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
|
||||
}
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(w, "OK")
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareValidator parses parameters and updates the PCR map.
|
||||
func prepareValidator(cmd *cobra.Command, pcrs map[uint32][]byte) error {
|
||||
ownerID, err := cmd.Flags().GetString("owner-id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clusterID, err := cmd.Flags().GetString("unique-id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ownerID == "" && clusterID == "" {
|
||||
return errors.New("neither owner identity nor unique identity provided to verify the Constellation")
|
||||
}
|
||||
|
||||
return updatePCRMap(pcrs, ownerID, clusterID)
|
||||
}
|
||||
|
||||
func updatePCRMap(pcrs map[uint32][]byte, ownerID, clusterID string) error {
|
||||
if err := addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
|
||||
return err
|
||||
}
|
||||
return addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexClusterID), clusterID)
|
||||
}
|
||||
|
||||
// addOrSkipPCR adds a new entry to the map, or removes the key if the input is an empty string.
|
||||
//
|
||||
// When adding, the input is first decoded from base64.
|
||||
// We then calculate the expected PCR by hashing the input using SHA256,
|
||||
// appending expected PCR for initialization, and then hashing once more.
|
||||
func addOrSkipPCR(toAdd map[uint32][]byte, pcrIndex uint32, encoded string) error {
|
||||
if encoded == "" {
|
||||
delete(toAdd, pcrIndex)
|
||||
return nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err)
|
||||
}
|
||||
// new_pcr_value := hash(old_pcr_value || data_to_extend)
|
||||
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
|
||||
hashedInput := sha256.Sum256(decoded)
|
||||
expectedPcr := sha256.Sum256(append(toAdd[pcrIndex], hashedInput[:]...))
|
||||
toAdd[pcrIndex] = expectedPcr[:]
|
||||
return nil
|
||||
}
|
||||
|
||||
type verifier struct {
|
||||
newConn func(context.Context, string, []atls.Validator) (status.ClientConn, error)
|
||||
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
|
||||
}
|
||||
|
||||
// newVerifiedConn creates a grpc over aTLS connection to the target, using the provided PCR values to verify the server.
|
||||
func newVerifiedConn(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpc.DialContext(
|
||||
ctx, target, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||
)
|
||||
}
|
||||
|
||||
// verifyCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0, 1:
|
||||
return []string{}, cobra.ShellCompDirectiveNoFileComp
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
51
cli/cmd/verify_azure.go
Normal file
51
cli/cmd/verify_azure.go
Normal file
@ -0,0 +1,51 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newVerifyAzureCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "azure IP PORT",
|
||||
Short: "Verify the confidential properties of your Constellation on Azure.",
|
||||
Long: "Verify the confidential properties of your Constellation on Azure.",
|
||||
Args: cobra.ExactArgs(2),
|
||||
ValidArgsFunction: verifyCompletion,
|
||||
RunE: runVerifyAzure,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runVerifyAzure(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validators, err := getAzureValidator(cmd, *config.Provider.GCP.PCRs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
|
||||
}
|
||||
|
||||
// getAzureValidator returns an Azure validator.
|
||||
func getAzureValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
|
||||
if err := prepareValidator(cmd, pcrs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return azure.NewValidator(pcrs), nil
|
||||
}
|
66
cli/cmd/verify_azure_test.go
Normal file
66
cli/cmd/verify_azure_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAzureValidator(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ownerID string
|
||||
clusterID string
|
||||
errExpected bool
|
||||
}{
|
||||
"no input": {
|
||||
ownerID: "",
|
||||
clusterID: "",
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded secret ID": {
|
||||
ownerID: "owner-id",
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded cluster ID": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: "unique-id",
|
||||
errExpected: true,
|
||||
},
|
||||
"correct input": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newVerifyAzureCmd()
|
||||
cmd.Flags().String("owner-id", "", "")
|
||||
cmd.Flags().String("unique-id", "", "")
|
||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
|
||||
_, err := getAzureValidator(cmd, map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||
})
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
51
cli/cmd/verify_gcp.go
Normal file
51
cli/cmd/verify_gcp.go
Normal file
@ -0,0 +1,51 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newVerifyGCPCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "gcp IP PORT",
|
||||
Short: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
|
||||
Long: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
|
||||
Args: cobra.ExactArgs(2),
|
||||
ValidArgsFunction: verifyCompletion,
|
||||
RunE: runVerifyGCP,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runVerifyGCP(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validators, err := getGCPValidator(cmd, *config.Provider.GCP.PCRs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
|
||||
}
|
||||
|
||||
// getValidators returns a GCP validator.
|
||||
func getGCPValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
|
||||
if err := prepareValidator(cmd, pcrs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gcp.NewValidator(pcrs), nil
|
||||
}
|
40
cli/cmd/verify_gcp_noncvm.go
Normal file
40
cli/cmd/verify_gcp_noncvm.go
Normal file
@ -0,0 +1,40 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TODO: Remove this command once we no longer use non cvms.
|
||||
func newVerifyGCPNonCVMCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "gcp-non-cvm IP PORT",
|
||||
Short: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
|
||||
Long: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
|
||||
Args: cobra.ExactArgs(2),
|
||||
ValidArgsFunction: verifyCompletion,
|
||||
RunE: runVerifyGCPNonCVM,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runVerifyGCPNonCVM(cmd *cobra.Command, args []string) error {
|
||||
pcrs := map[uint32][]byte{}
|
||||
validator, err := getGCPNonCVMValidator(cmd, pcrs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return runVerify(cmd, args, pcrs, validator)
|
||||
}
|
||||
|
||||
// getGCPNonCVMValidator returns a GCP validator for regular shielded VMs.
|
||||
func getGCPNonCVMValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
|
||||
if err := prepareValidator(cmd, pcrs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gcp.NewNonCVMValidator(pcrs), nil
|
||||
}
|
63
cli/cmd/verify_gcp_noncvm_test.go
Normal file
63
cli/cmd/verify_gcp_noncvm_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetGCPNonCVMValidator(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ownerID string
|
||||
clusterID string
|
||||
errExpected bool
|
||||
}{
|
||||
"no input": {
|
||||
ownerID: "",
|
||||
clusterID: "",
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded secret ID": {
|
||||
ownerID: "owner-id",
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded cluster ID": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: "unique-id",
|
||||
errExpected: true,
|
||||
},
|
||||
"correct input": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newVerifyGCPNonCVMCmd()
|
||||
cmd.Flags().String("owner-id", "", "")
|
||||
cmd.Flags().String("unique-id", "", "")
|
||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
|
||||
_, err := getGCPNonCVMValidator(cmd, map[uint32][]byte{})
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
66
cli/cmd/verify_gcp_test.go
Normal file
66
cli/cmd/verify_gcp_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetGCPValidator(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ownerID string
|
||||
clusterID string
|
||||
errExpected bool
|
||||
}{
|
||||
"no input": {
|
||||
ownerID: "",
|
||||
clusterID: "",
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded secret ID": {
|
||||
ownerID: "owner-id",
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded cluster ID": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: "unique-id",
|
||||
errExpected: true,
|
||||
},
|
||||
"correct input": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newVerifyGCPCmd()
|
||||
cmd.Flags().String("owner-id", "", "")
|
||||
cmd.Flags().String("unique-id", "", "")
|
||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
|
||||
_, err := getGCPValidator(cmd, map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||
})
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
325
cli/cmd/verify_test.go
Normal file
325
cli/cmd/verify_test.go
Normal file
@ -0,0 +1,325 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/status"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
rpcStatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
connErr error
|
||||
checkErr error
|
||||
state state.State
|
||||
errExpected bool
|
||||
}{
|
||||
"connection error": {
|
||||
connErr: errors.New("connection error"),
|
||||
checkErr: nil,
|
||||
state: 0,
|
||||
errExpected: true,
|
||||
},
|
||||
"check error": {
|
||||
connErr: nil,
|
||||
checkErr: errors.New("check error"),
|
||||
state: 0,
|
||||
errExpected: true,
|
||||
},
|
||||
"check error, rpc status": {
|
||||
connErr: nil,
|
||||
checkErr: rpcStatus.Error(codes.Unavailable, "check error"),
|
||||
state: 0,
|
||||
errExpected: true,
|
||||
},
|
||||
"verify on worker node": {
|
||||
connErr: nil,
|
||||
checkErr: nil,
|
||||
state: state.IsNode,
|
||||
errExpected: false,
|
||||
},
|
||||
"verify on master node": {
|
||||
connErr: nil,
|
||||
checkErr: nil,
|
||||
state: state.ActivatingNodes,
|
||||
errExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
var out bytes.Buffer
|
||||
|
||||
verifier := verifier{
|
||||
newConn: stubNewConnFunc(tc.connErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{
|
||||
state: tc.state,
|
||||
checkErr: tc.checkErr,
|
||||
}),
|
||||
}
|
||||
|
||||
pcrs := map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||
}
|
||||
err := verify(ctx, &out, "", []atls.Validator{gcp.NewValidator(pcrs)}, verifier)
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Contains(out.String(), "OK")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func stubNewConnFunc(errStub error) func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
|
||||
return func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
|
||||
return &stubClientConn{}, errStub
|
||||
}
|
||||
}
|
||||
|
||||
type stubClientConn struct{}
|
||||
|
||||
func (c *stubClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *stubClientConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func stubNewClientFunc(stubClient pubproto.APIClient) func(cc grpc.ClientConnInterface) pubproto.APIClient {
|
||||
return func(cc grpc.ClientConnInterface) pubproto.APIClient {
|
||||
return stubClient
|
||||
}
|
||||
}
|
||||
|
||||
type stubPeerStatusClient struct {
|
||||
state state.State
|
||||
checkErr error
|
||||
pubproto.APIClient
|
||||
}
|
||||
|
||||
func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetStateRequest, opts ...grpc.CallOption) (*pubproto.GetStateResponse, error) {
|
||||
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
|
||||
return resp, c.checkErr
|
||||
}
|
||||
|
||||
func TestPrepareValidator(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ownerID string
|
||||
clusterID string
|
||||
errExpected bool
|
||||
}{
|
||||
"no input": {
|
||||
ownerID: "",
|
||||
clusterID: "",
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded secret ID": {
|
||||
ownerID: "owner-id",
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded cluster ID": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: "unique-id",
|
||||
errExpected: true,
|
||||
},
|
||||
"correct input": {
|
||||
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
|
||||
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
|
||||
errExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newVerifyCmd()
|
||||
cmd.Flags().String("owner-id", "", "")
|
||||
cmd.Flags().String("unique-id", "", "")
|
||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
|
||||
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
|
||||
pcrs := map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||
}
|
||||
|
||||
err := prepareValidator(cmd, pcrs)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
if tc.clusterID != "" {
|
||||
assert.Len(pcrs[uint32(vtpm.PCRIndexClusterID)], 32)
|
||||
} else {
|
||||
assert.Nil(pcrs[uint32(vtpm.PCRIndexClusterID)])
|
||||
}
|
||||
if tc.ownerID != "" {
|
||||
assert.Len(pcrs[uint32(vtpm.PCRIndexOwnerID)], 32)
|
||||
} else {
|
||||
assert.Nil(pcrs[uint32(vtpm.PCRIndexOwnerID)])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddOrSkipPcr(t *testing.T) {
|
||||
emptyMap := map[uint32][]byte{}
|
||||
defaultMap := map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
pcrMap map[uint32][]byte
|
||||
pcrIndex uint32
|
||||
encoded string
|
||||
expectedEntries int
|
||||
errExpected bool
|
||||
}{
|
||||
"empty input, empty map": {
|
||||
pcrMap: emptyMap,
|
||||
pcrIndex: 10,
|
||||
encoded: "",
|
||||
expectedEntries: 0,
|
||||
errExpected: false,
|
||||
},
|
||||
"empty input, default map": {
|
||||
pcrMap: defaultMap,
|
||||
pcrIndex: 10,
|
||||
encoded: "",
|
||||
expectedEntries: len(defaultMap),
|
||||
errExpected: false,
|
||||
},
|
||||
"correct input, empty map": {
|
||||
pcrMap: emptyMap,
|
||||
pcrIndex: 10,
|
||||
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
|
||||
expectedEntries: 1,
|
||||
errExpected: false,
|
||||
},
|
||||
"correct input, default map": {
|
||||
pcrMap: defaultMap,
|
||||
pcrIndex: 10,
|
||||
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
|
||||
expectedEntries: len(defaultMap) + 1,
|
||||
errExpected: false,
|
||||
},
|
||||
"unencoded input, empty map": {
|
||||
pcrMap: emptyMap,
|
||||
pcrIndex: 10,
|
||||
encoded: "Constellation",
|
||||
expectedEntries: 0,
|
||||
errExpected: true,
|
||||
},
|
||||
"unencoded input, default map": {
|
||||
pcrMap: defaultMap,
|
||||
pcrIndex: 10,
|
||||
encoded: "Constellation",
|
||||
expectedEntries: len(defaultMap),
|
||||
errExpected: true,
|
||||
},
|
||||
"empty input at occupied index": {
|
||||
pcrMap: defaultMap,
|
||||
pcrIndex: 0,
|
||||
encoded: "",
|
||||
expectedEntries: len(defaultMap) - 1,
|
||||
errExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
res := make(map[uint32][]byte)
|
||||
for k, v := range tc.pcrMap {
|
||||
res[k] = v
|
||||
}
|
||||
|
||||
err := addOrSkipPCR(res, tc.pcrIndex, tc.encoded)
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
assert.Len(res, tc.expectedEntries)
|
||||
for _, v := range res {
|
||||
assert.Len(v, 32)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
resultExpected []string
|
||||
shellCDExpected cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "192.0.2.1",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"second arg": {
|
||||
args: []string{"192.0.2.1"},
|
||||
toComplete: "443",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"third arg": {
|
||||
args: []string{"192.0.2.1", "443"},
|
||||
toComplete: "./file",
|
||||
resultExpected: []string{},
|
||||
shellCDExpected: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := verifyCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.resultExpected, result)
|
||||
assert.Equal(tc.shellCDExpected, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
19
cli/cmd/version.go
Normal file
19
cli/cmd/version.go
Normal file
@ -0,0 +1,19 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newVersionCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Display version of this CLI",
|
||||
Long: `Display version of this CLI`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.Printf("CLI Version: v%s \n", config.Version)
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
25
cli/cmd/version_test.go
Normal file
25
cli/cmd/version_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVersionCmd(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newVersionCmd()
|
||||
b := bytes.NewBufferString("")
|
||||
cmd.SetOut(b)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(err)
|
||||
|
||||
s, err := io.ReadAll(b)
|
||||
assert.NoError(err)
|
||||
assert.Contains(string(s), config.Version)
|
||||
}
|
5
cli/cmd/vpnconfigurer.go
Normal file
5
cli/cmd/vpnconfigurer.go
Normal file
@ -0,0 +1,5 @@
|
||||
package cmd
|
||||
|
||||
type vpnConfigurer interface {
|
||||
Configure(clientVpnIp string, coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string) error
|
||||
}
|
17
cli/cmd/vpnconfigurer_test.go
Normal file
17
cli/cmd/vpnconfigurer_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package cmd
|
||||
|
||||
type stubVPNConfigurer struct {
|
||||
configured bool
|
||||
configureErr error
|
||||
}
|
||||
|
||||
func (c *stubVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error {
|
||||
c.configured = true
|
||||
return c.configureErr
|
||||
}
|
||||
|
||||
type dummyVPNConfigurer struct{}
|
||||
|
||||
func (c *dummyVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error {
|
||||
panic("dummy doesn't implement this function")
|
||||
}
|
42
cli/ec2/client/api.go
Normal file
42
cli/ec2/client/api.go
Normal file
@ -0,0 +1,42 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
)
|
||||
|
||||
// api collects used functions of AWS' ec2.Client as interfaces to enable testing.
|
||||
type api interface {
|
||||
ec2.DescribeInstancesAPIClient
|
||||
|
||||
// Instances
|
||||
RunInstances(ctx context.Context,
|
||||
params *ec2.RunInstancesInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
|
||||
|
||||
TerminateInstances(ctx context.Context,
|
||||
params *ec2.TerminateInstancesInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
|
||||
|
||||
CreateTags(ctx context.Context,
|
||||
params *ec2.CreateTagsInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
|
||||
|
||||
// SecurityGroup
|
||||
CreateSecurityGroup(ctx context.Context,
|
||||
params *ec2.CreateSecurityGroupInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error)
|
||||
|
||||
DeleteSecurityGroup(ctx context.Context,
|
||||
params *ec2.DeleteSecurityGroupInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error)
|
||||
|
||||
AuthorizeSecurityGroupIngress(ctx context.Context,
|
||||
params *ec2.AuthorizeSecurityGroupIngressInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error)
|
||||
|
||||
AuthorizeSecurityGroupEgress(ctx context.Context,
|
||||
params *ec2.AuthorizeSecurityGroupEgressInput,
|
||||
optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupEgressOutput, error)
|
||||
}
|
137
cli/ec2/client/api_test.go
Normal file
137
cli/ec2/client/api_test.go
Normal file
@ -0,0 +1,137 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
// stubAPI is a stub ec2 api for testing.
|
||||
type stubAPI struct {
|
||||
instances []types.Instance
|
||||
securityGroup types.SecurityGroup
|
||||
|
||||
describeInstancesErr error
|
||||
runInstancesErr error
|
||||
runInstancesDryRunErr *error
|
||||
terminateInstancesErr error
|
||||
terminateInstancesDryRunErr *error
|
||||
createTagsErr error
|
||||
createSecurityGroupErr error
|
||||
createSecurityGroupDryRunErr *error
|
||||
deleteSecurityGroupErr error
|
||||
deleteSecurityGroupDryRunErr *error
|
||||
authorizeSecurityGroupIngressErr error
|
||||
authorizeSecurityGroupIngressDryRunErr *error
|
||||
authorizeSecurityGroupEgressErr error
|
||||
authorizeSecurityGroupEgressDryRunErr *error
|
||||
}
|
||||
|
||||
func (a stubAPI) DescribeInstances(ctx context.Context,
|
||||
params *ec2.DescribeInstancesInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.DescribeInstancesOutput, error) {
|
||||
return &ec2.DescribeInstancesOutput{
|
||||
Reservations: []types.Reservation{
|
||||
{Instances: a.instances},
|
||||
},
|
||||
}, a.describeInstancesErr
|
||||
}
|
||||
|
||||
func (a stubAPI) RunInstances(ctx context.Context,
|
||||
params *ec2.RunInstancesInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.RunInstancesOutput, error) {
|
||||
if err := getDryRunErr(params.DryRun, a.runInstancesDryRunErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ec2.RunInstancesOutput{Instances: a.instances}, a.runInstancesErr
|
||||
}
|
||||
|
||||
func (a stubAPI) CreateTags(ctx context.Context,
|
||||
params *ec2.CreateTagsInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.CreateTagsOutput, error) {
|
||||
return nil, a.createTagsErr
|
||||
}
|
||||
|
||||
func (a stubAPI) TerminateInstances(ctx context.Context,
|
||||
params *ec2.TerminateInstancesInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.TerminateInstancesOutput, error) {
|
||||
if err := getDryRunErr(params.DryRun, a.terminateInstancesDryRunErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, a.terminateInstancesErr
|
||||
}
|
||||
|
||||
func (a stubAPI) CreateSecurityGroup(ctx context.Context,
|
||||
params *ec2.CreateSecurityGroupInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.CreateSecurityGroupOutput, error) {
|
||||
if err := getDryRunErr(params.DryRun, a.createSecurityGroupDryRunErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ec2.CreateSecurityGroupOutput{
|
||||
GroupId: a.securityGroup.GroupId,
|
||||
}, a.createSecurityGroupErr
|
||||
}
|
||||
|
||||
func (a stubAPI) DeleteSecurityGroup(ctx context.Context,
|
||||
params *ec2.DeleteSecurityGroupInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.DeleteSecurityGroupOutput, error) {
|
||||
if err := getDryRunErr(params.DryRun, a.deleteSecurityGroupDryRunErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, a.deleteSecurityGroupErr
|
||||
}
|
||||
|
||||
func (a stubAPI) AuthorizeSecurityGroupIngress(ctx context.Context,
|
||||
params *ec2.AuthorizeSecurityGroupIngressInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.AuthorizeSecurityGroupIngressOutput, error) {
|
||||
if err := getDryRunErr(params.DryRun, a.authorizeSecurityGroupIngressDryRunErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, a.authorizeSecurityGroupIngressErr
|
||||
}
|
||||
|
||||
func (a stubAPI) AuthorizeSecurityGroupEgress(ctx context.Context,
|
||||
params *ec2.AuthorizeSecurityGroupEgressInput,
|
||||
optFns ...func(*ec2.Options),
|
||||
) (*ec2.AuthorizeSecurityGroupEgressOutput, error) {
|
||||
if err := getDryRunErr(params.DryRun, a.authorizeSecurityGroupEgressDryRunErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, a.authorizeSecurityGroupEgressErr
|
||||
}
|
||||
|
||||
func getDryRunErr(dryRun *bool, stubErr *error) error {
|
||||
if dryRun == nil || !*dryRun {
|
||||
return nil
|
||||
}
|
||||
if stubErr != nil {
|
||||
return *stubErr
|
||||
}
|
||||
return &smithy.GenericAPIError{Code: "DryRunOperation"}
|
||||
}
|
||||
|
||||
var stateRunning = types.InstanceState{
|
||||
Code: aws.Int32(int32(16)),
|
||||
Name: types.InstanceStateNameRunning,
|
||||
}
|
||||
|
||||
var stateTerminated = types.InstanceState{
|
||||
Code: aws.Int32(48),
|
||||
Name: types.InstanceStateNameTerminated,
|
||||
}
|
71
cli/ec2/client/client.go
Normal file
71
cli/ec2/client/client.go
Normal file
@ -0,0 +1,71 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
// Client for the AWS EC2 API.
|
||||
type Client struct {
|
||||
api api
|
||||
instances ec2.Instances
|
||||
securityGroup string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func newClient(api api) (*Client, error) {
|
||||
return &Client{
|
||||
api: api,
|
||||
instances: make(map[string]ec2.Instance),
|
||||
timeout: 2 * time.Minute,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewFromDefault creates a Client from the default config.
|
||||
func NewFromDefault(ctx context.Context) (*Client, error) {
|
||||
cfg, err := awsconfig.LoadDefaultConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newClient(awsec2.NewFromConfig(cfg))
|
||||
}
|
||||
|
||||
// GetState returns the current configuration of the Constellation,
|
||||
// which can be stored and used through later CLI commands.
|
||||
func (c *Client) GetState() (state.ConstellationState, error) {
|
||||
if len(c.instances) == 0 {
|
||||
return state.ConstellationState{}, errors.New("client has no instances")
|
||||
}
|
||||
if c.securityGroup == "" {
|
||||
return state.ConstellationState{}, errors.New("client has no security group")
|
||||
}
|
||||
return state.ConstellationState{
|
||||
CloudProvider: cloudprovider.AWS.String(),
|
||||
EC2Instances: c.instances,
|
||||
EC2SecurityGroup: c.securityGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetState sets a Client to an existing configuration.
|
||||
func (c *Client) SetState(stat state.ConstellationState) error {
|
||||
if stat.CloudProvider != cloudprovider.AWS.String() {
|
||||
return errors.New("state is not aws state")
|
||||
}
|
||||
if len(stat.EC2Instances) == 0 {
|
||||
return errors.New("state has no instances")
|
||||
}
|
||||
if stat.EC2SecurityGroup == "" {
|
||||
return errors.New("state has no security group")
|
||||
}
|
||||
c.instances = stat.EC2Instances
|
||||
c.securityGroup = stat.EC2SecurityGroup
|
||||
return nil
|
||||
}
|
120
cli/ec2/client/client_test.go
Normal file
120
cli/ec2/client/client_test.go
Normal file
@ -0,0 +1,120 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetState(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
client Client
|
||||
wantState state.ConstellationState
|
||||
wantErr bool
|
||||
}{
|
||||
"successful get": {
|
||||
client: Client{
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
securityGroup: "sg",
|
||||
},
|
||||
wantState: state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
EC2SecurityGroup: "sg",
|
||||
},
|
||||
},
|
||||
"client without security group": {
|
||||
client: Client{
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"client without instances": {
|
||||
client: Client{
|
||||
securityGroup: "sg",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
stat, err := tc.client.GetState()
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantState, stat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetState(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
state state.ConstellationState
|
||||
wantInstances ec2.Instances
|
||||
wantSecurityGroup string
|
||||
wantErr bool
|
||||
}{
|
||||
"successful set": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2SecurityGroup: "sg-test",
|
||||
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
},
|
||||
wantInstances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
wantSecurityGroup: "sg-test",
|
||||
},
|
||||
"state without cloudprovider": {
|
||||
state: state.ConstellationState{
|
||||
EC2SecurityGroup: "sg-test",
|
||||
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"state with incorrect cloudprovider": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: "incorrect",
|
||||
EC2SecurityGroup: "sg-test",
|
||||
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"state without instances": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2SecurityGroup: "sg-test",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"state without security group": {
|
||||
state: state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &Client{}
|
||||
|
||||
err := client.SetState(tc.state)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantInstances, client.instances)
|
||||
assert.Equal(tc.wantSecurityGroup, client.securityGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
199
cli/ec2/client/instances.go
Normal file
199
cli/ec2/client/instances.go
Normal file
@ -0,0 +1,199 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
)
|
||||
|
||||
// CreateInstances creates the instances defined in input.
|
||||
//
|
||||
// An existing security group is needed to create instances.
|
||||
func (c *Client) CreateInstances(ctx context.Context, input CreateInput) error {
|
||||
if c.securityGroup == "" {
|
||||
return errors.New("no security group set")
|
||||
}
|
||||
input.securityGroupIds = []string{c.securityGroup}
|
||||
|
||||
if err := c.createDryRun(ctx, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.api.RunInstances(ctx, input.AWS())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create instances: %w", err)
|
||||
}
|
||||
|
||||
for _, instance := range resp.Instances {
|
||||
id := instance.InstanceId
|
||||
if id == nil {
|
||||
return errors.New("instanceId is nil pointer")
|
||||
}
|
||||
c.instances[*id] = ec2.Instance{}
|
||||
}
|
||||
|
||||
if err := c.waitStateRunning(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.tagInstances(ctx, input.Tags); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.getInstanceIPs(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TerminateInstances terminates all instances of a Client.
|
||||
func (c *Client) TerminateInstances(ctx context.Context) error {
|
||||
if len(c.instances) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
input := &awsec2.TerminateInstancesInput{
|
||||
InstanceIds: c.instances.IDs(),
|
||||
}
|
||||
if err := c.terminateDryRun(ctx, *input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := c.api.TerminateInstances(ctx, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.waitStateTerminated(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
c.instances = ec2.Instances{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitStateRunning waits until all the client's instances reached the running state.
|
||||
//
|
||||
// A set of instances is also considered to be running if at least one of the
|
||||
// instances' state is 'running' and the other instances have a nil state.
|
||||
func (c *Client) waitStateRunning(ctx context.Context) error {
|
||||
if len(c.instances) == 0 {
|
||||
return errors.New("client has no instances")
|
||||
}
|
||||
describeInput := &awsec2.DescribeInstancesInput{
|
||||
InstanceIds: c.instances.IDs(),
|
||||
}
|
||||
waiter := awsec2.NewInstanceRunningWaiter(c.api)
|
||||
return waiter.Wait(ctx, describeInput, c.timeout)
|
||||
}
|
||||
|
||||
// waitStateTerminated waits until all the client's instances reached the terminated state.
|
||||
//
|
||||
// A set of instances is also considered to be terminated if at least one of the
|
||||
// instances' state is 'terminated' and the other instances have a nil state.
|
||||
func (c *Client) waitStateTerminated(ctx context.Context) error {
|
||||
if len(c.instances) == 0 {
|
||||
return errors.New("client has no instances")
|
||||
}
|
||||
|
||||
describeInput := &awsec2.DescribeInstancesInput{
|
||||
InstanceIds: c.instances.IDs(),
|
||||
}
|
||||
waiter := awsec2.NewInstanceTerminatedWaiter(c.api)
|
||||
return waiter.Wait(ctx, describeInput, c.timeout)
|
||||
}
|
||||
|
||||
// tagInstances tags all instances of a client with a given set of tags.
|
||||
func (c *Client) tagInstances(ctx context.Context, tags ec2.Tags) error {
|
||||
if len(c.instances) == 0 {
|
||||
return errors.New("client has no instances")
|
||||
}
|
||||
|
||||
tagInput := &awsec2.CreateTagsInput{
|
||||
Resources: c.instances.IDs(),
|
||||
Tags: tags.AWS(),
|
||||
}
|
||||
if _, err := c.api.CreateTags(ctx, tagInput); err != nil {
|
||||
return fmt.Errorf("failed to tag instances: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDryRun checks if user has the privilege to create the instances
|
||||
// which were defined in input.
|
||||
func (c *Client) createDryRun(ctx context.Context, input CreateInput) error {
|
||||
runInput := input.AWS()
|
||||
runInput.DryRun = aws.Bool(true)
|
||||
_, err := c.api.RunInstances(ctx, runInput)
|
||||
return checkDryRunError(err)
|
||||
}
|
||||
|
||||
// terminateDryRun checks if user has the privilege to terminate the instances
|
||||
// which were defined in input.
|
||||
func (c *Client) terminateDryRun(ctx context.Context, input awsec2.TerminateInstancesInput) error {
|
||||
input.DryRun = aws.Bool(true)
|
||||
_, err := c.api.TerminateInstances(ctx, &input)
|
||||
return checkDryRunError(err)
|
||||
}
|
||||
|
||||
// getInstanceIPs queries the private and public IP addresses
|
||||
// and adds the information to each instance.
|
||||
//
|
||||
// The instances must be in 'running' state.
|
||||
func (c *Client) getInstanceIPs(ctx context.Context) error {
|
||||
describeInput := &awsec2.DescribeInstancesInput{
|
||||
InstanceIds: c.instances.IDs(),
|
||||
}
|
||||
paginator := awsec2.NewDescribeInstancesPaginator(c.api, describeInput)
|
||||
for paginator.HasMorePages() {
|
||||
output, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, reservation := range output.Reservations {
|
||||
for _, instanceDescription := range reservation.Instances {
|
||||
if instanceDescription.InstanceId == nil {
|
||||
return errors.New("instanceId is nil pointer")
|
||||
}
|
||||
if instanceDescription.PublicIpAddress == nil {
|
||||
return errors.New("publicIpAddress is nil pointer")
|
||||
}
|
||||
if instanceDescription.PrivateIpAddress == nil {
|
||||
return errors.New("privateIpAddress is nil pointer")
|
||||
}
|
||||
instance, ok := c.instances[*instanceDescription.InstanceId]
|
||||
if !ok {
|
||||
return errors.New("got an instance description to an unknown instanceId")
|
||||
}
|
||||
instance.PublicIP = *instanceDescription.PublicIpAddress
|
||||
instance.PrivateIP = *instanceDescription.PrivateIpAddress
|
||||
c.instances[*instanceDescription.InstanceId] = instance
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateInput defines the propertis of the instances to create.
|
||||
type CreateInput struct {
|
||||
ImageId string
|
||||
InstanceType string
|
||||
Count int
|
||||
Tags ec2.Tags
|
||||
securityGroupIds []string
|
||||
}
|
||||
|
||||
// AWS creates a AWS ec2.RunInstancesInput from an CreateInput.
|
||||
func (ci *CreateInput) AWS() *awsec2.RunInstancesInput {
|
||||
return &awsec2.RunInstancesInput{
|
||||
ImageId: aws.String(ci.ImageId),
|
||||
InstanceType: ec2.InstanceTypes[ci.InstanceType],
|
||||
MaxCount: aws.Int32(int32(ci.Count)),
|
||||
MinCount: aws.Int32(int32(ci.Count)),
|
||||
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
|
||||
SecurityGroupIds: ci.securityGroupIds,
|
||||
}
|
||||
}
|
493
cli/ec2/client/instances_test.go
Normal file
493
cli/ec2/client/instances_test.go
Normal file
@ -0,0 +1,493 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateInstances(t *testing.T) {
|
||||
testInstances := []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
PublicIpAddress: aws.String("192.0.2.1"),
|
||||
PrivateIpAddress: aws.String("192.0.2.2"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-2"),
|
||||
PublicIpAddress: aws.String("192.0.2.3"),
|
||||
PrivateIpAddress: aws.String("192.0.2.4"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-3"),
|
||||
PublicIpAddress: aws.String("192.0.2.5"),
|
||||
PrivateIpAddress: aws.String("192.0.2.6"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
var noErr error
|
||||
|
||||
testCases := map[string]struct {
|
||||
api stubAPI
|
||||
instances ec2.Instances
|
||||
securityGroup string
|
||||
errExpected bool
|
||||
wantInstances ec2.Instances
|
||||
}{
|
||||
"create": {
|
||||
api: stubAPI{instances: testInstances},
|
||||
securityGroup: "sg-test",
|
||||
wantInstances: ec2.Instances{
|
||||
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
|
||||
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
|
||||
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
|
||||
},
|
||||
},
|
||||
"client already has instances": {
|
||||
api: stubAPI{instances: testInstances},
|
||||
instances: ec2.Instances{"id-4": {}, "id-5": {}},
|
||||
securityGroup: "sg-test",
|
||||
wantInstances: ec2.Instances{
|
||||
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
|
||||
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
|
||||
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
|
||||
"id-4": {},
|
||||
"id-5": {},
|
||||
},
|
||||
},
|
||||
"client already has same instance id": {
|
||||
api: stubAPI{instances: testInstances},
|
||||
instances: ec2.Instances{"id-1": {}, "id-4": {}, "id-5": {}},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: false,
|
||||
wantInstances: ec2.Instances{
|
||||
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
|
||||
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
|
||||
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
|
||||
"id-4": {},
|
||||
"id-5": {},
|
||||
},
|
||||
},
|
||||
"client has no security group": {
|
||||
api: stubAPI{},
|
||||
errExpected: true,
|
||||
},
|
||||
"run API error": {
|
||||
api: stubAPI{runInstancesErr: someErr},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: true,
|
||||
},
|
||||
"runDryRun API error": {
|
||||
api: stubAPI{runInstancesDryRunErr: &someErr},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: true,
|
||||
},
|
||||
"runDryRun missing expected API error": {
|
||||
api: stubAPI{runInstancesDryRunErr: &noErr},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &Client{
|
||||
api: tc.api,
|
||||
instances: tc.instances,
|
||||
timeout: time.Millisecond,
|
||||
securityGroup: tc.securityGroup,
|
||||
}
|
||||
if client.instances == nil {
|
||||
client.instances = make(map[string]ec2.Instance)
|
||||
}
|
||||
input := CreateInput{
|
||||
ImageId: "test-image",
|
||||
InstanceType: "",
|
||||
Count: 13,
|
||||
}
|
||||
|
||||
err := client.CreateInstances(context.Background(), input)
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.ElementsMatch(tc.wantInstances.IDs(), client.instances.IDs())
|
||||
assert.ElementsMatch(tc.wantInstances.PublicIPs(), client.instances.PublicIPs())
|
||||
assert.ElementsMatch(tc.wantInstances.PrivateIPs(), client.instances.PrivateIPs())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateInstances(t *testing.T) {
|
||||
testAWSInstances := []types.Instance{
|
||||
{InstanceId: aws.String("id-1"), State: &stateTerminated},
|
||||
{InstanceId: aws.String("id-2"), State: &stateTerminated},
|
||||
{InstanceId: aws.String("id-3"), State: &stateTerminated},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
var noErr error
|
||||
|
||||
testCases := map[string]struct {
|
||||
api stubAPI
|
||||
instances ec2.Instances
|
||||
errExpected bool
|
||||
}{
|
||||
"client with instances": {
|
||||
api: stubAPI{instances: testAWSInstances},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: false,
|
||||
},
|
||||
"client no instances set": {
|
||||
api: stubAPI{},
|
||||
},
|
||||
"terminate API error": {
|
||||
api: stubAPI{terminateInstancesErr: someErr},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"terminateDryRun API error": {
|
||||
api: stubAPI{terminateInstancesDryRunErr: &someErr},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"terminateDryRun miss expected API error": {
|
||||
api: stubAPI{terminateInstancesDryRunErr: &noErr},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &Client{
|
||||
api: tc.api,
|
||||
instances: tc.instances,
|
||||
timeout: time.Millisecond,
|
||||
}
|
||||
if client.instances == nil {
|
||||
client.instances = make(map[string]ec2.Instance)
|
||||
}
|
||||
|
||||
err := client.TerminateInstances(context.Background())
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Empty(client.instances)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitStateRunning(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
api api
|
||||
instances ec2.Instances
|
||||
errExpected bool
|
||||
}{
|
||||
"instances are running": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-2"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-3"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: false,
|
||||
},
|
||||
"one instance running, rest nil": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{InstanceId: aws.String("id-2")},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: false,
|
||||
},
|
||||
"one instance terminated, rest nil": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
{InstanceId: aws.String("id-2")},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"instances with different state": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-2"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"all instances have nil state": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{InstanceId: aws.String("id-1")},
|
||||
{InstanceId: aws.String("id-2")},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"client has no instances": {
|
||||
api: &stubAPI{},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
client := &Client{
|
||||
api: tc.api,
|
||||
instances: tc.instances,
|
||||
timeout: time.Millisecond,
|
||||
}
|
||||
if client.instances == nil {
|
||||
client.instances = make(map[string]ec2.Instance)
|
||||
}
|
||||
|
||||
err := client.waitStateRunning(context.Background())
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitStateTerminated(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
api api
|
||||
instances ec2.Instances
|
||||
errExpected bool
|
||||
}{
|
||||
"instances are terminated": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-2"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-3"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: false,
|
||||
},
|
||||
"one instance terminated, rest nil": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
{InstanceId: aws.String("id-2")},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: false,
|
||||
},
|
||||
"one instance running, rest nil": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{InstanceId: aws.String("id-2")},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"instances with different state": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{
|
||||
InstanceId: aws.String("id-1"),
|
||||
State: &stateTerminated,
|
||||
},
|
||||
{
|
||||
InstanceId: aws.String("id-2"),
|
||||
State: &stateRunning,
|
||||
},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"all instances have nil state": {
|
||||
api: stubAPI{instances: []types.Instance{
|
||||
{InstanceId: aws.String("id-1")},
|
||||
{InstanceId: aws.String("id-2")},
|
||||
{InstanceId: aws.String("id-3")},
|
||||
}},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
"client has no instances": {
|
||||
api: &stubAPI{},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &Client{
|
||||
api: tc.api,
|
||||
instances: tc.instances,
|
||||
timeout: time.Millisecond,
|
||||
}
|
||||
if client.instances == nil {
|
||||
client.instances = make(map[string]ec2.Instance)
|
||||
}
|
||||
|
||||
err := client.waitStateTerminated(context.Background())
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagInstances(t *testing.T) {
|
||||
testTags := ec2.Tags{
|
||||
{Key: "Name", Value: "Test"},
|
||||
{Key: "Foo", Value: "Bar"},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
api stubAPI
|
||||
instances ec2.Instances
|
||||
errExpected bool
|
||||
}{
|
||||
"tag": {
|
||||
api: stubAPI{},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: false,
|
||||
},
|
||||
"client without instances": {
|
||||
api: stubAPI{createTagsErr: errors.New("failed")},
|
||||
errExpected: true,
|
||||
},
|
||||
"tag API error": {
|
||||
api: stubAPI{createTagsErr: errors.New("failed")},
|
||||
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &Client{
|
||||
api: tc.api,
|
||||
instances: tc.instances,
|
||||
timeout: time.Millisecond,
|
||||
}
|
||||
if client.instances == nil {
|
||||
client.instances = make(map[string]ec2.Instance)
|
||||
}
|
||||
|
||||
err := client.tagInstances(context.Background(), testTags)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEc2RunInstanceInput(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
testCases := []struct {
|
||||
in CreateInput
|
||||
outExpected awsec2.RunInstancesInput
|
||||
}{
|
||||
{
|
||||
in: CreateInput{
|
||||
ImageId: "test-image",
|
||||
InstanceType: "4xlarge",
|
||||
Count: 13,
|
||||
securityGroupIds: []string{"test-sec-group"},
|
||||
},
|
||||
outExpected: awsec2.RunInstancesInput{
|
||||
ImageId: aws.String("test-image"),
|
||||
InstanceType: types.InstanceTypeC5a4xlarge,
|
||||
MinCount: aws.Int32(int32(13)),
|
||||
MaxCount: aws.Int32(int32(13)),
|
||||
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
|
||||
SecurityGroupIds: []string{"test-sec-group"},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: CreateInput{
|
||||
ImageId: "test-image-2",
|
||||
InstanceType: "12xlarge",
|
||||
Count: 2,
|
||||
securityGroupIds: []string{"test-sec-group-2"},
|
||||
},
|
||||
outExpected: awsec2.RunInstancesInput{
|
||||
ImageId: aws.String("test-image-2"),
|
||||
InstanceType: types.InstanceTypeC5a12xlarge,
|
||||
MinCount: aws.Int32(int32(2)),
|
||||
MaxCount: aws.Int32(int32(2)),
|
||||
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
|
||||
SecurityGroupIds: []string{"test-sec-group-2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
out := tc.in.AWS()
|
||||
assert.Equal(tc.outExpected, *out)
|
||||
}
|
||||
}
|
136
cli/ec2/client/securitygroups.go
Normal file
136
cli/ec2/client/securitygroups.go
Normal file
@ -0,0 +1,136 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CreateSecurityGroup creates a AWS security group with the handed properties.
|
||||
func (c *Client) CreateSecurityGroup(ctx context.Context, input SecurityGroupInput) error {
|
||||
if c.securityGroup != "" {
|
||||
return errors.New("client already has a security group")
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
createInput := &awsec2.CreateSecurityGroupInput{
|
||||
Description: aws.String("Security group of Constellation. This group was generated through the Constellation CLI."),
|
||||
GroupName: aws.String("Constellation-" + id.String()),
|
||||
DryRun: aws.Bool(true),
|
||||
}
|
||||
|
||||
// DryRun
|
||||
_, err := c.api.CreateSecurityGroup(ctx, createInput)
|
||||
if err := checkDryRunError(err); err != nil {
|
||||
return err
|
||||
}
|
||||
createInput.DryRun = aws.Bool(false)
|
||||
|
||||
// Create
|
||||
out, err := c.api.CreateSecurityGroup(ctx, createInput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if out.GroupId == nil {
|
||||
return errors.New("security group creation didn't return an id")
|
||||
}
|
||||
c.securityGroup = *out.GroupId
|
||||
|
||||
// Authorize.
|
||||
return c.authorizeSecurityGroup(ctx, input)
|
||||
}
|
||||
|
||||
// DeleteSecurityGroup deletes the security group of the client.
|
||||
// The deletion is idempotent, no error is returned if the client has
|
||||
// an empty securityGroupID.
|
||||
func (c *Client) DeleteSecurityGroup(ctx context.Context) error {
|
||||
if c.securityGroup == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
input := &awsec2.DeleteSecurityGroupInput{
|
||||
GroupId: aws.String(c.securityGroup),
|
||||
DryRun: aws.Bool(true),
|
||||
}
|
||||
|
||||
// DryRun
|
||||
_, err := c.api.DeleteSecurityGroup(ctx, input)
|
||||
if err := checkDryRunError(err); err != nil {
|
||||
return err
|
||||
}
|
||||
input.DryRun = aws.Bool(false)
|
||||
|
||||
// Delete
|
||||
if _, err := c.api.DeleteSecurityGroup(ctx, input); err != nil {
|
||||
return err
|
||||
}
|
||||
c.securityGroup = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) authorizeSecurityGroup(ctx context.Context, input SecurityGroupInput) error {
|
||||
if c.securityGroup == "" {
|
||||
return errors.New("client hasn't got a security group id")
|
||||
}
|
||||
|
||||
if err := c.authorizeSecurityGroupIngress(ctx, input.Inbound); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.authorizeSecurityGroupEgress(ctx, input.Outbound)
|
||||
}
|
||||
|
||||
func (c *Client) authorizeSecurityGroupIngress(ctx context.Context, perms cloudtypes.Firewall) error {
|
||||
if len(perms) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
authInput := &awsec2.AuthorizeSecurityGroupIngressInput{
|
||||
GroupId: aws.String(c.securityGroup),
|
||||
IpPermissions: perms.AWS(),
|
||||
DryRun: aws.Bool(true),
|
||||
}
|
||||
|
||||
// DryRun
|
||||
_, err := c.api.AuthorizeSecurityGroupIngress(ctx, authInput)
|
||||
if err := checkDryRunError(err); err != nil {
|
||||
return err
|
||||
}
|
||||
authInput.DryRun = aws.Bool(false)
|
||||
|
||||
// Authorize
|
||||
_, err = c.api.AuthorizeSecurityGroupIngress(ctx, authInput)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) authorizeSecurityGroupEgress(ctx context.Context, perms cloudtypes.Firewall) error {
|
||||
if len(perms) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
authInput := &awsec2.AuthorizeSecurityGroupEgressInput{
|
||||
GroupId: aws.String(c.securityGroup),
|
||||
IpPermissions: perms.AWS(),
|
||||
DryRun: aws.Bool(true),
|
||||
}
|
||||
|
||||
// DryRun
|
||||
_, err := c.api.AuthorizeSecurityGroupEgress(ctx, authInput)
|
||||
if err := checkDryRunError(err); err != nil {
|
||||
return err
|
||||
}
|
||||
authInput.DryRun = aws.Bool(false)
|
||||
|
||||
// Authorize
|
||||
_, err = c.api.AuthorizeSecurityGroupEgress(ctx, authInput)
|
||||
return err
|
||||
}
|
||||
|
||||
type SecurityGroupInput struct {
|
||||
Inbound cloudtypes.Firewall
|
||||
Outbound cloudtypes.Firewall
|
||||
}
|
269
cli/ec2/client/securitygroups_test.go
Normal file
269
cli/ec2/client/securitygroups_test.go
Normal file
@ -0,0 +1,269 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateSecurityGroup(t *testing.T) {
|
||||
testInput := SecurityGroupInput{
|
||||
Inbound: cloudtypes.Firewall{
|
||||
{
|
||||
Description: "perm1",
|
||||
Protocol: "TCP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 22,
|
||||
},
|
||||
{
|
||||
Description: "perm2",
|
||||
Protocol: "UDP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4433,
|
||||
},
|
||||
},
|
||||
Outbound: cloudtypes.Firewall{
|
||||
{
|
||||
Description: "perm3",
|
||||
Protocol: "TCP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4040,
|
||||
},
|
||||
},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
var noErr error
|
||||
|
||||
testCases := map[string]struct {
|
||||
api stubAPI
|
||||
securityGroup string
|
||||
input SecurityGroupInput
|
||||
errExpected bool
|
||||
securityGroupExpected string
|
||||
}{
|
||||
"create security group": {
|
||||
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")}},
|
||||
input: testInput,
|
||||
securityGroupExpected: "sg-test",
|
||||
},
|
||||
"create security group without permissions": {
|
||||
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")}},
|
||||
input: SecurityGroupInput{},
|
||||
securityGroupExpected: "sg-test",
|
||||
},
|
||||
"client already has security group": {
|
||||
api: stubAPI{},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"create returns nil security group ID": {
|
||||
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: nil}},
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"create API error": {
|
||||
api: stubAPI{createSecurityGroupErr: someErr},
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"create DryRun API error": {
|
||||
api: stubAPI{createSecurityGroupDryRunErr: &someErr},
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"create DryRun missing expected error": {
|
||||
api: stubAPI{createSecurityGroupDryRunErr: &noErr},
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorize error": {
|
||||
api: stubAPI{
|
||||
securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")},
|
||||
authorizeSecurityGroupIngressErr: someErr,
|
||||
},
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
client, err := newClient(tc.api)
|
||||
require.NoError(err)
|
||||
client.securityGroup = tc.securityGroup
|
||||
|
||||
err = client.CreateSecurityGroup(context.Background(), tc.input)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.securityGroupExpected, client.securityGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSecurityGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
var noErr error
|
||||
|
||||
testCases := map[string]struct {
|
||||
api stubAPI
|
||||
securityGroup string
|
||||
errExpected bool
|
||||
}{
|
||||
"delete security group": {
|
||||
api: stubAPI{},
|
||||
securityGroup: "sg-test",
|
||||
},
|
||||
"client without security group": {
|
||||
api: stubAPI{},
|
||||
},
|
||||
"delete API error": {
|
||||
api: stubAPI{deleteSecurityGroupErr: someErr},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: true,
|
||||
},
|
||||
"delete DryRun API error": {
|
||||
api: stubAPI{deleteSecurityGroupDryRunErr: &someErr},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: true,
|
||||
},
|
||||
"delete DryRun missing expected error": {
|
||||
api: stubAPI{deleteSecurityGroupDryRunErr: &noErr},
|
||||
securityGroup: "sg-test",
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
client, err := newClient(tc.api)
|
||||
require.NoError(err)
|
||||
client.securityGroup = tc.securityGroup
|
||||
|
||||
err = client.DeleteSecurityGroup(context.Background())
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Empty(client.securityGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeSecurityGroup(t *testing.T) {
|
||||
testInput := SecurityGroupInput{
|
||||
Inbound: cloudtypes.Firewall{
|
||||
{
|
||||
Description: "perm1",
|
||||
Protocol: "TCP",
|
||||
IPRange: " 192.0.2.0/24",
|
||||
Port: 22,
|
||||
},
|
||||
{
|
||||
Description: "perm2",
|
||||
Protocol: "UDP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4433,
|
||||
},
|
||||
},
|
||||
Outbound: cloudtypes.Firewall{
|
||||
{
|
||||
Description: "perm3",
|
||||
Protocol: "TCP",
|
||||
IPRange: "192.0.2.0/24",
|
||||
Port: 4040,
|
||||
},
|
||||
},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
var noErr error
|
||||
|
||||
testCases := map[string]struct {
|
||||
api stubAPI
|
||||
securityGroup string
|
||||
input SecurityGroupInput
|
||||
errExpected bool
|
||||
}{
|
||||
"authorize": {
|
||||
api: stubAPI{},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: false,
|
||||
},
|
||||
"client without security group": {
|
||||
api: stubAPI{},
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorizeIngress API error": {
|
||||
api: stubAPI{authorizeSecurityGroupIngressErr: someErr},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorizeIngress DryRun API error": {
|
||||
api: stubAPI{authorizeSecurityGroupIngressDryRunErr: &someErr},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorizeIngress DryRun missing expected error": {
|
||||
api: stubAPI{authorizeSecurityGroupIngressDryRunErr: &noErr},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorizeEgress API error": {
|
||||
api: stubAPI{authorizeSecurityGroupEgressErr: someErr},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorizeEgress DryRun API error": {
|
||||
api: stubAPI{authorizeSecurityGroupEgressDryRunErr: &someErr},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
"authorizeEgress DryRun missing expected error": {
|
||||
api: stubAPI{authorizeSecurityGroupEgressDryRunErr: &noErr},
|
||||
securityGroup: "sg-test",
|
||||
input: testInput,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
client, err := newClient(tc.api)
|
||||
require.NoError(err)
|
||||
client.securityGroup = tc.securityGroup
|
||||
|
||||
err = client.authorizeSecurityGroup(context.Background(), tc.input)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
21
cli/ec2/client/util.go
Normal file
21
cli/ec2/client/util.go
Normal file
@ -0,0 +1,21 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
// checkDryRunError error checks if an error is a DryRun error.
|
||||
// If the error is nil, an error is returned, since a DryRun error
|
||||
// is the expected result of a DryRun operation.
|
||||
func checkDryRunError(err error) error {
|
||||
var apiErr smithy.APIError
|
||||
if errors.As(err, &apiErr) && apiErr.ErrorCode() == "DryRunOperation" {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("expected APIError: DryRunOperation, but got no error at all")
|
||||
}
|
22
cli/ec2/client/util_test.go
Normal file
22
cli/ec2/client/util_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCheckDryRunError(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
someErr := errors.New("failed")
|
||||
assert.ErrorIs(checkDryRunError(someErr), someErr)
|
||||
|
||||
dryRunErr := smithy.GenericAPIError{Code: "DryRunOperation"}
|
||||
assert.NoError(checkDryRunError(&dryRunErr))
|
||||
|
||||
var nilErr error
|
||||
assert.Error(checkDryRunError(nilErr))
|
||||
}
|
58
cli/ec2/instances.go
Normal file
58
cli/ec2/instances.go
Normal file
@ -0,0 +1,58 @@
|
||||
package ec2
|
||||
|
||||
import "errors"
|
||||
|
||||
// Instance is an ec2 instance.
|
||||
type Instance struct {
|
||||
PublicIP string
|
||||
PrivateIP string
|
||||
}
|
||||
|
||||
// Instances is a map of ec2 Instances. The ID of an instance is used as key.
|
||||
type Instances map[string]Instance
|
||||
|
||||
// IDs returns the IDs of all instances of the Constellation.
|
||||
func (i Instances) IDs() []string {
|
||||
var ids []string
|
||||
for id := range i {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// PublicIPs returns the public IPs of all the instances of the Constellation.
|
||||
func (i Instances) PublicIPs() []string {
|
||||
var ips []string
|
||||
for _, instance := range i {
|
||||
ips = append(ips, instance.PublicIP)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// PrivateIPs returns the private IPs of all the instances of the Constellation.
|
||||
func (i Instances) PrivateIPs() []string {
|
||||
var ips []string
|
||||
for _, instance := range i {
|
||||
ips = append(ips, instance.PrivateIP)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// GetOne return anyone instance out of the instances and its ID.
|
||||
func (i Instances) GetOne() (string, Instance, error) {
|
||||
for id, instance := range i {
|
||||
return id, instance, nil
|
||||
}
|
||||
return "", Instance{}, errors.New("map is empty")
|
||||
}
|
||||
|
||||
// GetOthers returns all instances but the one with the handed ID.
|
||||
func (i Instances) GetOthers(id string) Instances {
|
||||
others := make(Instances)
|
||||
for key, instance := range i {
|
||||
if key != id {
|
||||
others[key] = instance
|
||||
}
|
||||
}
|
||||
return others
|
||||
}
|
71
cli/ec2/instances_test.go
Normal file
71
cli/ec2/instances_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package ec2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIDs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
|
||||
assert.ElementsMatch(expectedIDs, testState.IDs())
|
||||
}
|
||||
|
||||
func TestPublicIPs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
|
||||
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
|
||||
}
|
||||
|
||||
func TestPrivateIPs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
|
||||
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
|
||||
}
|
||||
|
||||
func TestGetOne(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testState := testInstances()
|
||||
id, instance, err := testState.GetOne()
|
||||
assert.NoError(err)
|
||||
assert.Contains(testState, id)
|
||||
assert.Equal(testState[id], instance)
|
||||
}
|
||||
|
||||
func TestGetOthers(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testCases := testInstances().IDs()
|
||||
|
||||
for _, id := range testCases {
|
||||
others := testInstances().GetOthers(id)
|
||||
assert.NotContains(others, id)
|
||||
expectedInstances := testInstances()
|
||||
delete(expectedInstances, id)
|
||||
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
|
||||
}
|
||||
}
|
||||
|
||||
func testInstances() Instances {
|
||||
return Instances{
|
||||
"id-9": {
|
||||
PublicIP: "192.0.2.1",
|
||||
PrivateIP: "192.0.2.2",
|
||||
},
|
||||
"id-10": {
|
||||
PublicIP: "192.0.2.3",
|
||||
PrivateIP: "192.0.2.4",
|
||||
},
|
||||
"id-11": {
|
||||
PublicIP: "192.0.2.5",
|
||||
PrivateIP: "192.0.2.6",
|
||||
},
|
||||
"id-12": {
|
||||
PublicIP: "192.0.2.7",
|
||||
PrivateIP: "192.0.2.8",
|
||||
},
|
||||
}
|
||||
}
|
18
cli/ec2/instancetypes.go
Normal file
18
cli/ec2/instancetypes.go
Normal file
@ -0,0 +1,18 @@
|
||||
package ec2
|
||||
|
||||
import "github.com/aws/aws-sdk-go-v2/service/ec2/types"
|
||||
|
||||
// InstanceTypes defines possible values for the SIZE positional argument.
|
||||
var InstanceTypes = map[string]types.InstanceType{
|
||||
"4xlarge": types.InstanceTypeC5a4xlarge,
|
||||
"8xlarge": types.InstanceTypeC5a8xlarge,
|
||||
"12xlarge": types.InstanceTypeC5a12xlarge,
|
||||
"16xlarge": types.InstanceTypeC5a16xlarge,
|
||||
"24xlarge": types.InstanceTypeC5a24xlarge,
|
||||
// shorthands
|
||||
"4xl": types.InstanceTypeC5a4xlarge,
|
||||
"8xl": types.InstanceTypeC5a8xlarge,
|
||||
"12xl": types.InstanceTypeC5a12xlarge,
|
||||
"16xl": types.InstanceTypeC5a16xlarge,
|
||||
"24xl": types.InstanceTypeC5a24xlarge,
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user