Co-authored-by: Malte Poll <mp@edgeless.systems>
Co-authored-by: katexochen <katexochen@users.noreply.github.com>
Co-authored-by: Daniel Weiße <dw@edgeless.systems>
Co-authored-by: Thomas Tendyck <tt@edgeless.systems>
Co-authored-by: Benedict Schlueter <bs@edgeless.systems>
Co-authored-by: leongross <leon.gross@rub.de>
Co-authored-by: Moritz Eckert <m1gh7ym0@gmail.com>
This commit is contained in:
Leonard Cohnen 2022-03-22 16:03:15 +01:00
commit 2d8fcd9bf4
362 changed files with 50980 additions and 0 deletions

31
.dockerignore Normal file
View File

@ -0,0 +1,31 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
#ignore build files
/build
admin.conf
coordinatorConfig.json
coordinator-*
/debugd
/images
# Dockerfiles
Dockerfile
Dockerfile.*
# GitHub actions
.github
# VS Code configuration folder
.vscode

55
.github/workflows/build-ami.yml vendored Normal file
View File

@ -0,0 +1,55 @@
name: Build the AMI Template
on:
workflow_dispatch:
workflow_call:
secrets:
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
AWS_DEFAULT_REGION:
required: true
BUCKET_NAME:
required: true
jobs:
build-enclave:
name: "Build the AMI"
runs-on: ubuntu-latest
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
working-directory: images/aws/ec2
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
- name: Install AWS CLI
id: prepare
run: sudo apt-get update && sudo apt-get -y install awscli
- name: Download eif
id: download_eif
run: aws s3 cp s3://${{ secrets.BUCKET_NAME }}/eif/ ${{ github.workspace }}/${{ env.working-directory }}/ --recursive --quiet
- name: Download gvproxy
id: download_gvproxy
run: aws s3 cp s3://${{ secrets.BUCKET_NAME }}/gvproxy/gvproxy ${{ github.workspace }}/${{ env.working-directory }}/ --quiet
- name: Install build dependencies
run: sudo apt-get -y install packer
- name: Init packer
run: packer init .
working-directory: ${{ env.working-directory }}
- name: Validate packer
run: packer validate -syntax-only .
working-directory: ${{ env.working-directory }}
- name: Build packer
run: packer build -color=false .
working-directory: ${{ env.working-directory }}

107
.github/workflows/build-coordinator.yml vendored Normal file
View File

@ -0,0 +1,107 @@
name: Build and Upload the Coordinator
on:
workflow_dispatch:
push:
branches:
- main
jobs:
build-coordinator:
name: "Build the Coordinator"
runs-on: ubuntu-latest
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
outputs:
coordinator-name: ${{ steps.copy.outputs.coordinator-name }}
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Cache Docker layers
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-
- name: Install Dependencies
id: prepare
run: sudo apt-get update && sudo apt-get -y install awscli
- name: Build the Coordinator
uses: docker/build-push-action@v2
with:
context: .
file: Dockerfile.build
outputs: .
push: false
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
# This ugly bit is necessary if you don't want your cache to grow forever
# till it hits GitHub's limit of 5GB.
# Temp fix
# https://github.com/docker/build-push-action/issues/252
# https://github.com/moby/buildkit/issues/1896
- name: Move cache
run: |
rm -rf /tmp/.buildx-cache
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
- name: Copy Coordinator to S3 if not exists
id: copy
run: >
aws s3api head-object --bucket ${{ secrets.PUBLIC_BUCKET_NAME }} --key coordinator/$(ls | grep "coordinator-")
|| (
echo "::set-output name=coordinator-name::$(ls | grep "coordinator-")"
&& aws s3 cp ${{ github.workspace }}/ s3://${{ secrets.PUBLIC_BUCKET_NAME }}/coordinator/ --exclude "*" --include "coordinator-*" --include "constellation" --recursive --quiet)
shell: bash {0}
call-coreos:
needs: build-coordinator
if: startsWith(needs.build-coordinator.outputs.coordinator-name, 'coordinator-')
uses: ./.github/workflows/build-coreos.yml
with:
coordinator-name: ${{ needs.build-coordinator.outputs.coordinator-name }}
secrets:
CI_GITHUB_REPOSITORY: ${{ secrets.CI_GITHUB_REPOSITORY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
BUCKET_NAME: ${{ secrets.BUCKET_NAME }}
PUBLIC_BUCKET_NAME: ${{ secrets.PUBLIC_BUCKET_NAME }}
SSH_PUB_KEY: ${{ secrets.SSH_PUB_KEY }}
SSH_PUB_KEY_PATH: ${{ secrets.SSH_PUB_KEY_PATH }}
AZURE_CREDENTIALS: ${{ secrets.AZURE_CREDENTIALS }}
call-aws-enclave:
needs: build-coordinator
if: startsWith(needs.build-coordinator.outputs.coordinator-name, 'coordinator-')
uses: ./.github/workflows/build-enclave.yml
with:
coordinator-name: ${{ needs.build-coordinator.outputs.coordinator-name }}
secrets:
CI_GITHUB_REPOSITORY: ${{ secrets.CI_GITHUB_REPOSITORY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
BUCKET_NAME: ${{ secrets.BUCKET_NAME }}
PUBLIC_BUCKET_NAME: ${{ secrets.PUBLIC_BUCKET_NAME }}
SSH_PUB_KEY: ${{ secrets.SSH_PUB_KEY }}
SSH_PUB_KEY_PATH: ${{ secrets.SSH_PUB_KEY_PATH }}
call-aws-ami:
needs: call-aws-enclave
uses: ./.github/workflows/build-ami.yml
secrets:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
BUCKET_NAME: ${{ secrets.BUCKET_NAME }}

View File

@ -0,0 +1,79 @@
name: Build and Upload CoreOS debug image
env:
REGISTRY: ghcr.io
on:
workflow_dispatch:
jobs:
build-enclave:
name: "Build CoreOS debug image using customized COSA"
runs-on: [self-hosted, linux, nested-virt]
permissions:
contents: read
packages: read
defaults:
run:
shell: bash
env:
working-directory: ${{ github.workspace }}/images/fcos
SHELL: /bin/bash
GOPATH: /home/github-actions-runner-user/go
GOCACHE: /home/github-actions-runner-user/.cache/go-build
GOMODCACHE: /home/github-actions-runner-user/.cache/go-mod
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
with:
submodules: recursive
token: ${{ secrets.CI_GITHUB_REPOSITORY }}
- name: Log in to the Container registry
id: docker-login
uses: docker/login-action@v1
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: "Install azure CLI"
run: |
# use pip since azure cli repository is not working as expected
# https://github.com/Azure/azure-cli/issues/21532
# curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
sudo apt-get update
sudo apt-get install -y python3 python3-pip
sudo pip install azure-cli
wget -q https://aka.ms/downloadazcopy-v10-linux -O azcopy.tar.gz
tar --strip-components 1 -xf azcopy.tar.gz
rm azcopy.tar.gz
echo "$(pwd)" >> $GITHUB_PATH
- uses: azure/login@v1
with:
creds: ${{ secrets.AZURE_CREDENTIALS }}
- name: Setup Go environment
uses: actions/setup-go@v2.2.0
with:
go-version: "1.18"
- name: "Compile debugd"
run: GOCACHE=/home/github-actions-runner-user/.cache/go-build GOPATH=/home/github-actions-runner-user/go GOPRIVATE=github.com/edgelesssys GOMODCACHE=/home/github-actions-runner-user/.cache/go-mod go build -o constellation-debugd debugd.go
working-directory: ${{ github.workspace }}/debugd/debugd/cmd/debugd
- name: "Store GH token to be mounted by cosa"
run: echo "machine github.com login api password ${{ secrets.CI_GITHUB_REPOSITORY }}" > /tmp/.netrc
- name: "Set image timestamp"
run: |
TIMESTAMP=$(date +%s)
echo "TIMESTAMP=${TIMESTAMP}" >> $GITHUB_ENV
echo "IMAGE_TIMESTAMP=constellation-coreos-debugd-${TIMESTAMP}" >> $GITHUB_ENV
echo "IMAGE_VERSION=0.0.${TIMESTAMP}" >> $GITHUB_ENV
- name: "Build and Upload"
run: >
make -j$(nproc) CONTAINER_ENGINE=docker NETRC=/tmp/.netrc GCP_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}" AZURE_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}"
AZURE_IMAGE_DEFINITION="constellation-coreos-debugd" AZURE_IMAGE_VERSION="${{env.IMAGE_VERSION }}" DOWNLOAD_COORDINATOR=n COORDINATOR_BINARY="${{ github.workspace }}/debugd/debugd/cmd/debugd/constellation-debugd"
image-gcp image-azure upload-gcp upload-azure
working-directory: ${{ env.working-directory }}

99
.github/workflows/build-coreos.yml vendored Normal file
View File

@ -0,0 +1,99 @@
name: Build and Upload CoreOS
env:
REGISTRY: ghcr.io
on:
workflow_dispatch:
inputs:
coordinator-name:
description: Coordinator name
required: true
type: string
workflow_call:
inputs:
coordinator-name:
required: true
type: string
secrets:
CI_GITHUB_REPOSITORY:
required: true
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
AWS_DEFAULT_REGION:
required: true
BUCKET_NAME:
required: true
PUBLIC_BUCKET_NAME:
required: true
SSH_PUB_KEY:
required: true
SSH_PUB_KEY_PATH:
required: true
AZURE_CREDENTIALS:
required: true
jobs:
build-enclave:
name: "Build CoreOS using customized COSA"
runs-on: [self-hosted, linux, nested-virt]
permissions:
contents: read
packages: read
defaults:
run:
shell: bash
env:
working-directory: ${{ github.workspace }}/images/fcos
SHELL: /bin/bash
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
with:
submodules: recursive
token: ${{ secrets.CI_GITHUB_REPOSITORY }}
- name: Log in to the Container registry
id: docker-login
uses: docker/login-action@v1
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: "Install azure CLI"
run: |
# use pip since azure cli repository is not working as expected
# https://github.com/Azure/azure-cli/issues/21532
# curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
sudo apt-get update
sudo apt-get install -y python3 python3-pip
sudo pip install azure-cli
wget -q https://aka.ms/downloadazcopy-v10-linux -O azcopy.tar.gz
tar --strip-components 1 -xf azcopy.tar.gz
rm azcopy.tar.gz
echo "$(pwd)" >> $GITHUB_PATH
- uses: azure/login@v1
with:
creds: ${{ secrets.AZURE_CREDENTIALS }}
- name: "Store GH token to be mounted by cosa"
run: echo "machine github.com login api password ${{ secrets.CI_GITHUB_REPOSITORY }}" > /tmp/.netrc
- name: "Set image timestamp"
run: |
TIMESTAMP=$(date +%s)
echo "TIMESTAMP=${TIMESTAMP}" >> $GITHUB_ENV
echo "IMAGE_TIMESTAMP=constellation-coreos-${TIMESTAMP}" >> $GITHUB_ENV
echo "IMAGE_VERSION=0.0.${TIMESTAMP}" >> $GITHUB_ENV
- name: "Build and Upload"
run: >
make -j$(nproc) CONTAINER_ENGINE=docker NETRC=/tmp/.netrc GCP_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}" AZURE_IMAGE_NAME="${{ env.IMAGE_TIMESTAMP }}"
AZURE_IMAGE_DEFINITION="constellation-coreos" AZURE_IMAGE_VERSION="${{env.IMAGE_VERSION }}" COORDINATOR_URL="https://${{ secrets.PUBLIC_BUCKET_NAME }}.s3.us-east-2.amazonaws.com/coordinator/${{ inputs.coordinator-name }}"
image-gcp image-azure upload-gcp upload-azure
working-directory: ${{ env.working-directory }}

76
.github/workflows/build-enclave.yml vendored Normal file
View File

@ -0,0 +1,76 @@
name: Build and Upload the Enclave Image File
on:
workflow_dispatch:
inputs:
coordinator-name:
description: Coordinator name
required: true
type: string
workflow_call:
inputs:
coordinator-name:
required: true
type: string
secrets:
CI_GITHUB_REPOSITORY:
required: true
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
AWS_DEFAULT_REGION:
required: true
BUCKET_NAME:
required: true
PUBLIC_BUCKET_NAME:
required: true
SSH_PUB_KEY:
required: true
SSH_PUB_KEY_PATH:
required: true
jobs:
build-enclave:
name: "Build the Enclave"
runs-on: ubuntu-latest
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
with:
submodules: recursive
token: ${{ secrets.CI_GITHUB_REPOSITORY }}
- name: Install AWS CLI
id: prepare
run: sudo apt-get update && sudo apt-get -y install awscli
- name: Download bzImage, init and nsm.ko to AWS S3 Bucket
id: download-artifacts
run: aws s3 cp s3://${{ secrets.BUCKET_NAME }}/blobs/ ${{ github.workspace }}/images/aws/enclave/userland/dependencies/blobs/ --recursive
- name: Download Coordinator
id: download-coordinator
run: aws s3 cp s3://${{ secrets.PUBLIC_BUCKET_NAME }}/coordinator/${{ inputs.coordinator-name }} ${{ github.workspace }}/images/aws/enclave/userland/build/coordinator
- name: Write ssh public key to file
run: echo $SSH_PUB_KEY >> ${{ env.SSH_PUB_KEY_PATH }} && chmod 644 ${{ env.SSH_PUB_KEY_PATH }}
env:
SSH_PUB_KEY: ${{ secrets.SSH_PUB_KEY }}
SSH_PUB_KEY_PATH: ~/authorized_keys
- name: Build the eif file
run: make -j$(nproc) SSH_DIR=~/ -C ${{ github.workspace }}/images/aws/enclave/
- name: Upload eif file to AWS S3 Bucket
id: upload
run: aws s3 cp ${{ github.workspace }}/images/aws/enclave/userland/build/ s3://${{ secrets.BUCKET_NAME }}/eif/ --recursive --exclude "*" --include "*.eif" --quiet

36
.github/workflows/build-kernel.yml vendored Normal file
View File

@ -0,0 +1,36 @@
name: Build the Kernel
on:
push:
branches:
- main
paths:
- 'kernel/**'
workflow_dispatch:
jobs:
compile-and-upload-kernel:
name: "Compile and upload the Kernel"
runs-on: ubuntu-latest
steps:
- name: Install build dependencies
id: install
run: sudo apt-get update && sudo apt-get install -y git build-essential fakeroot libncurses5-dev libssl-dev ccache bison flex libelf-dev dwarves
- name: Checkout
id: checkout
uses: actions/checkout@v2
- name: Compile using make
id: compile
run: make -C ${{ github.workspace }}/images/aws/kernel/
- name: Install AWS CLI
id: prepare
run: sudo apt-get -y install awscli
- name: Upload bzImage, init and nsm.ko to AWS S3 Bucket
id: upload
run: aws s3 cp ${{ github.workspace }}/images/aws/kernel/build/blobs/ s3://${{ secrets.BUCKET_NAME }}/blobs/ --recursive --quiet
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}

View File

@ -0,0 +1,47 @@
name: Patch gvisor-tap-vsock and Upload to S3
on:
workflow_dispatch:
inputs:
version:
description: "gvisor version"
required: true
default: 0.3.0
jobs:
build:
name: "Build"
runs-on: ubuntu-latest
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
working-directory: ec2
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
- name: Prepare Download
id: prepare
run: sudo apt-get update && sudo apt-get -y install wget tar make
- name: Download and unpack sources
id: unpack
run: wget -c https://github.com/containers/gvisor-tap-vsock/archive/refs/tags/v${{ github.event.inputs.version }}.tar.gz -O - | tar xz
working-directory: ${{ github.workspace }}
- name: Install go
uses: actions/setup-go@v2
with:
go-version: go1.17.6
- name: Patch source code
run: patch --ignore-whitespace ${{ github.workspace }}/gvisor-tap-vsock-${{ github.event.inputs.version }}/pkg/services/forwarder/tcp.go < ${{ github.workspace }}/images/aws/ec2/patches/remove_link_local.patch
working-directory: ${{ env.working-directory }}
- name: Build gvisor
id: build
run: make -C ${{ github.workspace }}/gvisor-tap-vsock-${{ github.event.inputs.version }}/
- name: Upload gvproxy
id: upload_gvproxy
run: aws s3 cp ${{ github.workspace }}/gvisor-tap-vsock-${{ github.event.inputs.version }}/bin/gvproxy s3://${{ secrets.BUCKET_NAME }}/gvproxy/gvproxy --quiet

View File

@ -0,0 +1,22 @@
name: Etcd Integration Test
on:
workflow_dispatch:
push:
branches:
- main
pull_request:
jobs:
integration-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Setup Go environment
uses: actions/setup-go@v2.1.4
with:
go-version: "1.18"
- name: Test Constellation etcd integration
run: go test -v --race -cover -count=3 -tags integration
working-directory: coordinator/store

23
.github/workflows/test-integration.yml vendored Normal file
View File

@ -0,0 +1,23 @@
name: Integration Test
on:
workflow_dispatch:
push:
branches:
- main
pull_request:
jobs:
integration-test:
runs-on: ubuntu-latest
env:
GOPRIVATE: github.com/edgelesssys/*
steps:
- uses: actions/checkout@v2
- name: Setup Go environment
uses: actions/setup-go@v2.1.4
with:
go-version: "1.18"
- name: Run Integration Test
run: DEBUG=true go test -v -tags integration ./test/

23
.github/workflows/test-lint.yml vendored Normal file
View File

@ -0,0 +1,23 @@
name: Golangci-lint
on:
pull_request:
permissions:
contents: read
# Allow read access to pull request. Use with `only-new-issues` option.
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
env:
GOPRIVATE: github.com/edgelesssys/*
steps:
- uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
only-new-issues: true

18
.github/workflows/test-shellcheck.yml vendored Normal file
View File

@ -0,0 +1,18 @@
name: Shellcheck
on:
push:
branches:
- main
pull_request:
jobs:
shellcheck:
name: Shellcheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Run ShellCheck
uses: ludeeus/action-shellcheck@master
with:
severity: error
ignore_names: merge_config.sh

27
.github/workflows/test-unittest.yml vendored Normal file
View File

@ -0,0 +1,27 @@
name: Unit Tests
on:
workflow_dispatch:
push:
branches:
- main
pull_request:
jobs:
test:
runs-on: ubuntu-latest
env:
GOPRIVATE: github.com/edgelesssys/*
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.18
- name: Install Dependencies
run: sudo apt-get update && sudo apt-get install -y libcryptsetup-dev
- name: Test
run: go test -race -count=3 ./...

38
.gitignore vendored Normal file
View File

@ -0,0 +1,38 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
build
admin.conf
coordinatorConfig.json
coordinator-*
# VS Code configuration folder
.vscode
# Debug and testing files
debug/
# Images
images/aws/kernel/build/*
images/aws/kernel/sed*
images/aws/enclave/userland/build/*
images/aws/enclave/userland/privatekey
images/aws/enclave/userland/publickey
images/aws/enclave/.build-*
images/*.ign
images/fcos/build/*
images/fcos/dependencies/coordinator
images/fcos/images/*
images/fcos/cosa.lock

49
.golangci.yml Normal file
View File

@ -0,0 +1,49 @@
run:
timeout: 5m
output:
format: tab
sort-results: true
build-tags:
- integration
- aws
- gcp
linters:
enable:
# Default linters
- deadcode
- errcheck
- gosimple
- govet
- ineffassign
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
# Additional linters
- bodyclose
- errname
- exportloopref
- ifshort
- godot
- gofmt
- gofumpt
- misspell
- noctx
- tenv
- unconvert
- unparam
issues:
max-issues-per-linter: 0
max-same-issues: 20
linters-settings:
errcheck:
# List of functions to exclude from checking, where each entry is a single function to exclude.
# See https://github.com/kisielk/errcheck#excluding-functions for details.
exclude-functions:
- (*go.uber.org/zap.Logger).Sync
- (*google.golang.org/grpc.Server).Serve

1025
3rdparty/aws-nitro-enclaves-ffi/Cargo.lock generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,17 @@
[package]
name = "aws-nitro-enclaves-ffi"
version = "0.1.0"
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
nsm-lib = { git = "https://github.com/aws/aws-nitro-enclaves-nsm-api", rev = "4f468c467583bbd55429935c4f09448dd43f48a0" }
aws-nitro-enclaves-attestation-ffi = { git = "https://github.com/ppmag/aws-nitro-enclaves-attestation", rev = "83ca87233298c302973a5bdbbb394c36cd7eb6e6" }
[lib]
name = "nitro"
crate-type = ["staticlib"]
[profile.release]
lto = true

View File

@ -0,0 +1,2 @@
pub use nitroattest::*;
pub use nsm::*;

67
CMakeLists.txt Normal file
View File

@ -0,0 +1,67 @@
cmake_minimum_required(VERSION 3.11)
project(coordinator LANGUAGES C VERSION 0.1.0)
enable_testing()
option(COORDINATOR_STATIC_MUSL "use musl and compile coordinator statically")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Debug)
endif()
if(CMAKE_BUILD_TYPE STREQUAL Debug)
set(CARGOTARGET debug)
else()
set(CARGOTARGET release)
set(CARGOFLAGS --release)
endif()
if(COORDINATOR_STATIC_MUSL)
set(RUST_STATICLIB_LDFLAGS -static ${RUST_STATICLIB_LDFLAGS})
set(RUSTTARGETTRIPLE x86_64-unknown-linux-musl)
set(CARGOFLAGS ${CARGOFLAGS} "--target=${RUSTTARGETTRIPLE}")
set(CARGOTARGET ${RUSTTARGETTRIPLE}/${CARGOTARGET})
else()
set(RUST_STATICLIB_LDFLAGS -ldl -lm -lrt ${RUST_STATICLIB_LDFLAGS})
endif()
set(NITRO_CFLAGS '-I${CMAKE_BINARY_DIR}/nitro/${CARGOTARGET} -I${CMAKE_BINARY_DIR}/nitro/${CARGOTARGET}/headers')
set(NITRO_LDFLAGS '${CMAKE_BINARY_DIR}/nitro/${CARGOTARGET}/libnitro.a ${RUST_STATICLIB_LDFLAGS}')
#
# coordinator
#
add_custom_target(nitro
CARGO_TARGET_DIR=${CMAKE_BINARY_DIR}/nitro cargo build ${CARGOFLAGS}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/3rdparty/aws-nitro-enclaves-ffi)
add_custom_target(coordinator ALL
${CMAKE_COMMAND} -E env CGO_CFLAGS=${NITRO_CFLAGS}
${CMAKE_COMMAND} -E env CGO_LDFLAGS=${NITRO_LDFLAGS}
go build -o ${CMAKE_BINARY_DIR} -tags=aws,gcp -buildvcs=false -ldflags "-buildid='' -X main.version=${PROJECT_VERSION}"
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/coordinator/cmd/coordinator)
add_dependencies(coordinator nitro)
#
# cli
#
add_custom_target(cli ALL
${CMAKE_COMMAND} -E env CGO_CFLAGS=${NITRO_CFLAGS}
${CMAKE_COMMAND} -E env CGO_LDFLAGS=${NITRO_LDFLAGS}
go build -o ${CMAKE_BINARY_DIR}/constellation -buildvcs=false -tags=aws,gcp -ldflags "-buildid='' -X github.com/edgelesssys/constellation/cli/defaults.Version=${PROJECT_VERSION}"
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/cli)
add_dependencies(cli nitro)
#
# testing / debugging
#
add_custom_target(debug_coordinator ALL
go build -o ${CMAKE_BINARY_DIR}/debug_coordinator -buildvcs=false -ldflags "-buildid='' -X main.version=${PROJECT_VERSION}"
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/coordinator/cmd/coordinator)
add_test(NAME unittest COMMAND go test -race -count=3 ./... WORKING_DIRECTORY ${CMAKE_SOURCE_DIR})
add_test(NAME integrationtest COMMAND go test -v -tags integration ./test/ WORKING_DIRECTORY ${CMAKE_SOURCE_DIR})
add_test(NAME etcd-unittest COMMAND go test -v --race -cover -count=3 -tags integration WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/coordinator/store/)

64
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,64 @@
## Testing
Run all unit tests locally with
```sh
cd build
cmake ..
ctest
```
### E2E Test
Requirement: Kernel WireGuard, Docker
```sh
docker build -f Dockerfile.e2e -t constellation-e2e .
```
For the AWS test run
```sh
docker run -it --cap-add=NET_ADMIN --env GITHUB_TOKEN="$(cat ~/.netrc)" --env BRANCH="main" --env aws_access_key_id=XXX --env aws_secret_access_key=XXX constellation-e2e /initiateAWS.sh
```
For the gcp test run
```sh
docker run -it --cap-add=NET_ADMIN --env GITHUB_TOKEN="$(cat ~/.netrc)" --env BRANCH="main" --env GCLOUD_CREDENTIALS="$(cat ./constellation-keyfile.json)" constellation-e2e /initiategcloud.sh
```
## Linting
This projects uses [golangci-lint](https://golangci-lint.run/) for linting.
You can [install golangci-lint](https://golangci-lint.run/usage/install/#linux-and-windows) locally,
but there is also a CI action to ensure compliance.
To locally run all configured linters, execute
```
golangci-lint run ./...
```
It is also recommended to use golangci-lint (and [gofumpt](https://github.com/mvdan/gofumpt) as formatter) in your IDE, by adding the recommended VS Code Settings or by [configuring it yourself](https://golangci-lint.run/usage/integrations/#editor-integration)
## Recommended VS Code Settings
The following can be added to your personal `settings.json`, but it is recommended to add it to
the `<REPOSITORY>/.vscode/settings.json` repo, so the settings will only affect this repository.
```jsonc
// Use gofumpt as formatter.
"gopls": {
"formatting.gofumpt": true,
},
// Use golangci-lint as linter. Make sure you've installed it.
"go.lintTool":"golangci-lint",
"go.lintFlags": ["--fast"],
// You can easily show Go test coverage by running a package test.
"go.coverageOptions": "showUncoveredCodeOnly",
// Executing unit tests with race detection.
// You can add preferences like "-v" or "-count=1"
"go.testFlags": ["-race"],
// Enable language features for files with build tags.
// Attention! This leads to integration test being executed when
// running a package test within a package containing integration
// tests.
"go.buildTags": "integration",
```

37
Dockerfile.build Normal file
View File

@ -0,0 +1,37 @@
FROM ubuntu@sha256:7cc0576c7c0ec2384de5cbf245f41567e922aab1b075f3e8ad565f508032df17 as build
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get update && apt-get install cmake iproute2 iputils-ping wget curl git jq libssl-dev musl-tools=1.1.24-1 -y
# Install Go
ARG GO_VER=1.18
RUN wget https://go.dev/dl/go${GO_VER}.linux-amd64.tar.gz
RUN tar -C /usr/local -xzf go${GO_VER}.linux-amd64.tar.gz && rm go${GO_VER}.linux-amd64.tar.gz
ENV PATH ${PATH}:/usr/local/go/bin
# Install Rust
ARG RUST_VER=1.58.0
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
ENV PATH /root/.cargo/bin:${PATH}
RUN rustup install ${RUST_VER}
RUN rustup override set ${RUST_VER}
RUN rustup target add x86_64-unknown-linux-musl
# Download go dependencies
WORKDIR /constellation/
COPY go.mod ./
COPY go.sum ./
RUN go mod download all
# Copy Repo
COPY . /constellation
# Build
RUN mkdir -p /constellation/build
WORKDIR /constellation/build
RUN cmake -DCMAKE_BUILD_TYPE=Release -DCOORDINATOR_STATIC_MUSL=ON .. && make coordinator
RUN mv coordinator coordinator-$(sha512sum coordinator | cut -d " " -f 1)
FROM scratch AS export
COPY --from=build /constellation/build/coordinator-* /

37
Dockerfile.e2e Normal file
View File

@ -0,0 +1,37 @@
FROM ubuntu:20.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get update && apt-get install cmake iproute2 iputils-ping wget curl git libssl-dev -y
# Install kubectl
RUN curl -fsSLo /usr/local/bin/kubectl https://dl.k8s.io/release/v1.23.0/bin/linux/amd64/kubectl && chmod +x /usr/local/bin/kubectl
# Install Go
RUN wget https://go.dev/dl/go1.18.linux-amd64.tar.gz
RUN tar -C /usr/local -xzf go1.18.linux-amd64.tar.gz && rm go1.18.linux-amd64.tar.gz
ENV PATH ${PATH}:/usr/local/go/bin
# Install wireguard-tools
RUN git clone -b v1.0.20210914 --depth=1 https://git.zx2c4.com/wireguard-tools && make -C wireguard-tools/src -j`nproc` && make -C wireguard-tools/src install
# Install Rust
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
ENV PATH /root/.cargo/bin:${PATH}
# Setup CLI
RUN wg genkey | (umask 0077 && tee /privatekey) | wg pubkey > /publickey
RUN mkdir -p /root/.config/constellation && touch /root/.config/constellation/config.json
# Setup AWS config
RUN mkdir -p /root/.aws && echo "[default]\nregion = us-east-2" > /root/.aws/config && echo "[default]" >> /root/.aws/credentials
# Setup gcloud config
RUN mkdir -p /root/.config/gcloud
# Use local constellation state
# COPY . /constellation
# WORKDIR /constellation
# RUN mkdir build && cd build && cmake .. && make -j`nproc` cli
COPY ./test/ /
RUN chmod +x /initiateAWS.sh && chmod +x /initiategcloud.sh

327
README.md Normal file
View File

@ -0,0 +1,327 @@
# constellation-coordinator
## Prerequisites
* Go 1.18
### Ubuntu 20.04
```sh
sudo apt install build-essential cmake libssl-dev
curl https://sh.rustup.rs -sSf | sh
```
### Amazon Linux
```sh
sudo yum install cmake3 gcc make
curl https://sh.rustup.rs -sSf | sh
```
## Build
```sh
mkdir build
cd build
cmake ..
make -j`nproc`
```
## CMake build options:
### Release build
This options leaves out debug symbols and turns on more compiler optimizations.
```sh
cmake -DCMAKE_BUILD_TYPE=Release ..
```
### Static build (coordinator as static binary, no dependencies on libc or other libraries)
Install the musl-toolchain
Ubuntu / Debian:
```sh
sudo apt install -y musl-tools
rustup target add x86_64-unknown-linux-musl
```
From source (Amazon-Linux):
```sh
wget https://musl.libc.org/releases/musl-1.2.2.tar.gz
tar xfz musl-1.2.2.tar.gz
cd musl-1.2.2
./configure
make -j `nproc`
sudo make install
rustup target add x86_64-unknown-linux-musl
```
Add `musl-gcc` to your PATH:
```sh
export PATH=$PATH:/usr/loca/musl/bin/
```
Compile the coordinator
```sh
cmake -DCOORDINATOR_STATIC_MUSL=ON ..
```
## Cloud credentials
Using the CLI or debug-CLI requires the user to make authorized API calls to the AWS or GCP API.
### Google Cloud Platform (GCP)
If you are running from within a Google VM, and the VM is allowed to access the necessary APIs, no further configuration is needed.
Otherwise you have a couple options:
1. Use the `gcloud` CLI tool
```shell
gcloud auth application-default login
```
This will ask you to log into your Google account, and then create your credentials.
The Constellation CLI will automatically load these credentials when needed.
2. Set up a service account and pass the credentials manually
Follow [Google's guide](https://cloud.google.com/docs/authentication/production#manually) for setting up your credentials.
### Amazon Web Services (AWS)
To use the CLI with an Constellation cluster on AWS configure the following files:
```bash
$ cat ~/.aws/credentials
[default]
aws_access_key_id = XXXXX
aws_secret_access_key = XXXXX
```
```bash
$ cat ~/.aws/config
[default]
region = us-east-2
```
### Azure
To use the CLI with an Constellation cluster on Azure execute:
```bash
az login
```
### Deploying a locally compiled coordinator binary
By default, `constellation create ...` will spawn cloud provider instances with a pre-baked coordinator binary.
For testing, you can use the constellation debug daemon (debugd) to upload your local coordinator binary to running instances and to obtain SSH access.
See this introduction on how to install and setup `cdbg`: https://github.com/edgelesssys/constellation/debugd/#readme
# constellation-debugd
## Prerequisites
* Go 1.18
## Build
```
git clone https://github.com/edgelesssys/constellation/debugd
cd constellation-debugd
go build -o constellation-debugd debugd/cmd/debugd.go
go build -o constellation-cdbg cdbg/cdbg.go
```
## Install cdbg
```
go install github.com/edgelesssys/constellation/debugd/cdbg@latest
```
## Usage
With `cdbg` installed in your path:
1. Run `constellation --dev-config /path/to/dev-config create […]` while specifying a cloud-provider image with the debugd already included. See [Configuration](#configuration) for a dev-config with a custom image and firewall rules to allow incoming connection on the debugd default port 4000.
2. Run `cdbg deploy --dev-config /path/to/dev-config`
3. Run `constellation init […]` as usual
### GCP image
For GCP, run the following command to get a list of all constellation images, sorted by their creation date:
```
gcloud compute images list --filter="name~'constellation-.+'" --sort-by=~creationTimestamp
```
Choose the newest debugd image with the naming scheme `constellation-coreos-debugd-<timestamp>`.
### Azure Image
For Azure, run the following command to get a list of all constellation debugd images, sorted by their creation date:
```
az sig image-version list --resource-group constellation-images --gallery-name Constellation --gallery-image-definition constellation-coreos-debugd --query "sort_by([], &publishingProfile.publishedDate)[].id" -o table
```
Choose the newest debugd image and copy the full URI.
## Configuration
You should first locate the newest debugd image for your cloud provider ([GCP](#gcp-image), [Azure](#azure-image)).
This tool uses the dev-config file from `constellation-coordinator` and extends it with more fields.
See this example on what the possible settings are and how to setup the constellation cli to use a cloud-provider image and firewall rules with support for debugd:
```json
{
"cdbg":{
"authorized_keys":[
{
"user":"my-username",
"pubkey":"ssh-rsa AAAAB…LJuM="
}
],
"coordinator_path":"/path/to/coordinator",
"systemd_units":[
{
"name":"some-custom.service",
"contents":"[Unit]\nDescription=…"
}
]
},
"provider": {
"gcpconfig": {
"image": "constellation-coreos-debugd-TIMESTAMP",
"firewallinput": {
"Ingress": [
{
"Name": "coordinator",
"Description": "Coordinator default port",
"Protocol": "tcp",
"Port": 9000
},
{
"Name": "wireguard",
"Description": "WireGuard default port",
"Protocol": "udp",
"Port": 51820
},
{
"Name": "ssh",
"Description": "SSH",
"Protocol": "tcp",
"Port": 22
},
{
"Name": "debugd",
"Description": "debugd default port",
"Protocol": "tcp",
"Port": 4000
}
]
}
},
"azureconfig": {
"image": "/subscriptions/0d202bbb-4fa7-4af8-8125-58c269a05435/resourceGroups/CONSTELLATION-IMAGES/providers/Microsoft.Compute/galleries/Constellation/images/constellation-coreos-debugd/versions/0.0.TIMESTAMP",
"networksecuritygroupinput": {
"Ingress": [
{
"Name": "coordinator",
"Description": "Coordinator default port",
"Protocol": "tcp",
"IPRange": "0.0.0.0/0",
"Port": 9000
},
{
"Name": "wireguard",
"Description": "WireGuard default port",
"Protocol": "udp",
"IPRange": "0.0.0.0/0",
"Port": 51820
},
{
"Name": "ssh",
"Description": "SSH",
"Protocol": "tcp",
"IPRange": "0.0.0.0/0",
"Port": 22
},
{
"Name": "debugd",
"Description": "debugd default port",
"Protocol": "tcp",
"IPRange": "0.0.0.0/0",
"Port": 4000
}
]
}
}
}
}
```
# constellation-kms-client
This library provides an interface for the key management services used with constellation.
It's intendet for the Constellation CSI Plugins and the CLI.
## KMS
The Cloud KMS is where we store our key encryption key (KEK).
It should be initiated by the CLI and provided with a key release policy.
The CSP Plugin can request to encrypt data encryption keys (DEK) with the DEK to safely store them on persistent memory.
The [kms](pkg/kms) package interacts with the Cloud KMS APIs.
Currently planned are KMS are:
* AWS KMS
* GCP CKM
* Azure Key Vault
## Storage
Storage is where the CSI Plugin stores the encrypted DEKs.
Currently planned are:
* AWS S3, SSP
* GCP GCS
* Azure Blob
# constellation-images
# constellation-mount-utils
Wrapper for https://github.com/kubernetes/mount-utils
## Dependencies
This package uses the C library [`libcryptsetup`](https://gitlab.com/cryptsetup/cryptsetup/) for device mapping.
To install the required dependencies on Ubuntu run:
```shell
sudo apt install libcryptsetup-dev
```
To install or upgrade `go.mod` dependencies from private repositories run:
```
GOPRIVATE=github.com/edgelesssys/constellation-coordinator go get github.com/edgelesssys/constellation-coordinator
GOPRIVATE=github.com/edgelesssys/constellation-kms-client go get github.com/edgelesssys/constellation-kms-client
```
## Testing
A small test programm is available in `test/main.go`.
To build the programm run:
```shell
go build -o test/crypt ./test/
```
Create a new crypt device for `/dev/sdX` and map it to `/dev/mapper/volume01`:
```shell
sudo test/crypt -source /dev/sdX -target volume01 -v 4
```
You can now interact with the mapped volume as if it was an unformatted device:
```shell
sudo mkfs.ext4 /dev/mapper/volume01
sudo mount /dev/mapper/volume01 /mnt/volume01
```
Close the mapped volume:
```shell
sudo umount /mnt/volume01
sudo test/crypt -c -target volume01 -v 4
```

28
cli/README.md Normal file
View File

@ -0,0 +1,28 @@
# CLI to spawn a confidential kubernetes cluster
## Usage
0. (optional) replace the responsible in `cli/cmd/defaults.go` with yourself.
1. Build the CLI and authenticate with <AWS/Azure/GCP> according to the [README.md](https://github.com/edgelesssys/constellation-coordinator/blob/main/README.md#cloud-credentials).
2. Execute `constellation create <aws/azure/gcp> 2 <4xlarge|n2d-standard-2>`.
3. Execute `wg genkey | tee privatekey | wg pubkey > publickey` to generate a WireGuard keypair.
4. Execute `constellation init --publickey publickey`. Since the CLI waits for all nodes to be ready, this step can take up to 5 minutes.
5. Use the output from `constellation init` and the wireguard template below to create `/etc/wireguard/wg0.conf`, then execute `wg-quick up wg0`.
6. Execute `export KUBECONFIG=<path/to/admin.conf>`.
7. Use `kubectl get nodes` to inspect your cluster.
8. Execute `constellation terminate` to terminate your Constellation.
```bash
[Interface]
Address = <address from the init output>
PrivateKey = <your base64 encoded private key>
ListenPort = 51820
[Peer]
PublicKey = <public key from the init output>
AllowedIPs = 10.118.0.1/32 # IP set on the peer's wg interface
Endpoint = <public IPv4 address from the activated coordinator>:51820 # address where the peer listens on
PersistentKeepalive = 10
```
Note: Skip the manual configuration of WireGuard by executing Step 2 as root. Then, replace steps 4 and 5 with `sudo constellation init --privatekey <path/to/your/privatekey>`. This will automatically configure a new WireGuard interface named wg0 with the coordinator as peer.

View File

@ -0,0 +1,203 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net/url"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/google/uuid"
)
const (
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
virtualMachineContributorRoleDefinitionID = "9980e02c-c2be-4d73-94e8-173b1dc7cf3c" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#virtual-machine-contributor
)
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
createAppRes, err := c.createADApplication(ctx)
if err != nil {
return "", err
}
c.adAppObjectID = createAppRes.ObjectID
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
if err != nil {
return "", err
}
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
return "", err
}
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
if err != nil {
return "", err
}
return ApplicationCredentials{
ClientID: createAppRes.AppID,
ClientSecret: clientSecret,
}.ConvertToCloudServiceAccountURI(), nil
}
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
return err
}
c.adAppObjectID = ""
return nil
}
// createADApplication creates a new azure AD app.
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
createParameters := graphrbac.ApplicationCreateParameters{
AvailableToOtherTenants: to.BoolPtr(false),
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
}
app, err := c.applicationsAPI.Create(ctx, createParameters)
if err != nil {
return createADApplicationOutput{}, err
}
if app.AppID == nil || app.ObjectID == nil {
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
}
return createADApplicationOutput{
AppID: *app.AppID,
ObjectID: *app.ObjectID,
}, nil
}
// createAppServicePrincipal creates a new service principal for an azure AD app.
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
createParameters := graphrbac.ServicePrincipalCreateParameters{
AppID: &appID,
AccountEnabled: to.BoolPtr(true),
}
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
if err != nil {
return "", err
}
if servicePrincipal.ObjectID == nil {
return "", errors.New("creating AD service principal did not result in a valid object id")
}
return *servicePrincipal.ObjectID, nil
}
// updateAppCredentials sets app client-secret for authentication.
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
keyID := uuid.New().String()
clientSecret, err := generateClientSecret()
if err != nil {
return "", fmt.Errorf("generating client secret failed: %w", err)
}
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
Value: &[]graphrbac.PasswordCredential{
{
StartDate: &date.Time{Time: time.Now()},
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
Value: to.StringPtr(clientSecret),
KeyID: to.StringPtr(keyID),
},
},
}
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
if err != nil {
return "", err
}
return clientSecret, nil
}
// assignResourceGroupRole assigns the service principal a role at resource group scope.
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
if err != nil || resourceGroup.ID == nil {
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
}
roleAssignmentID := uuid.New().String()
createParameters := authorization.RoleAssignmentCreateParameters{
Properties: &authorization.RoleAssignmentProperties{
PrincipalID: to.StringPtr(principalID),
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
},
}
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
// proper fix: use API version 2018-09-01-preview or later
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
var detailedErr autorest.DetailedError
var ok bool
if detailedErr, ok = err.(autorest.DetailedError); !ok {
return err
}
var requestErr *azure.RequestError
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
return err
}
if requestErr.ServiceError.Code != "PrincipalNotFound" {
return err
}
time.Sleep(c.adReplicationLagCheckInterval)
}
return err
}
// ApplicationCredentials is a set of Azure AD application credentials.
// It is the equivalent of a service account key in other cloud providers.
type ApplicationCredentials struct {
ClientID string
ClientSecret string
}
// ConvertToCloudServiceAccountURI converts the ApplicationCredentials into a cloud service account URI.
func (c ApplicationCredentials) ConvertToCloudServiceAccountURI() string {
query := url.Values{}
query.Add("client_id", c.ClientID)
query.Add("client_secret", c.ClientSecret)
uri := url.URL{
Scheme: "serviceaccount",
Host: "azure",
RawQuery: query.Encode(),
}
return uri.String()
}
type createADApplicationOutput struct {
AppID string
ObjectID string
}
func generateClientSecret() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 64
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
return string(pw), nil
}

View File

@ -0,0 +1,380 @@
package client
import (
"context"
"errors"
"net/url"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestCreateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
servicePrincipalsAPI servicePrincipalsAPI
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
errExpected bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
errExpected: true,
},
"failed service principal create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
errExpected: true,
},
"failed role assignment": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
"failed update creds": {
applicationsAPI: stubApplicationsAPI{
updateCredentialsErr: someErr,
},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
applicationsAPI: tc.applicationsAPI,
servicePrincipalsAPI: tc.servicePrincipalsAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
_, err := client.CreateServicePrincipal(ctx)
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestTerminateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
appObjectID string
applicationsAPI applicationsAPI
errExpected bool
}{
"successful terminate": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{},
},
"nothing to terminate": {
applicationsAPI: stubApplicationsAPI{},
},
"failed delete": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{
deleteErr: someErr,
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
adAppObjectID: tc.appObjectID,
applicationsAPI: tc.applicationsAPI,
}
err := client.TerminateServicePrincipal(ctx)
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateADApplication(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
errExpected bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
errExpected: true,
},
"app create returns invalid appid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
},
},
errExpected: true,
},
"app create returns invalid objectid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
applicationsAPI: tc.applicationsAPI,
}
appCredentials, err := client.createADApplication(ctx)
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(appCredentials)
})
}
}
func TestCreateAppServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
servicePrincipalsAPI servicePrincipalsAPI
errExpected bool
}{
"successful create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{},
},
"failed service principal create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
errExpected: true,
},
"service principal create returns invalid objectid": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createServicePrincipal: &graphrbac.ServicePrincipal{},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
servicePrincipalsAPI: tc.servicePrincipalsAPI,
}
_, err := client.createAppServicePrincipal(ctx, "app-id")
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestAssignOwnerOfResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
errExpected bool
}{
"successful assign": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"failed role assignment": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
"failed resource group get": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getErr: someErr,
},
errExpected: true,
},
"resource group get returns invalid id": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{},
},
errExpected: true,
},
"create returns PrincipalNotFound the first time": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "PrincipalNotFound",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"create does not return request error": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{autorest.DetailedError{Original: someErr}},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
"create service error code is unknown": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "some-unknown-error-code",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestConvertToCloudServiceAccountURI(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
key := ApplicationCredentials{
ClientID: "client-id",
ClientSecret: "client-secret",
}
cloudServiceAccountURI := key.ConvertToCloudServiceAccountURI()
uri, err := url.Parse(cloudServiceAccountURI)
require.NoError(err)
query := uri.Query()
assert.Equal("serviceaccount", uri.Scheme)
assert.Equal("azure", uri.Host)
assert.Equal(url.Values{
"client_id": []string{"client-id"},
"client_secret": []string{"client-secret"},
}, query)
}

129
cli/azure/client/api.go Normal file
View File

@ -0,0 +1,129 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type virtualNetworksCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error)
}
type networksAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error)
}
type networkSecurityGroupsCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error)
}
type networkSecurityGroupsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error)
}
type virtualMachineScaleSetsCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error)
}
type scaleSetsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error)
}
type publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager interface {
NextPage(ctx context.Context) bool
PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse
}
// TODO: deprecate as soon as scale sets are available.
type publicIPAddressesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error)
}
type publicIPAddressesAPI interface {
ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager
// TODO: deprecate as soon as scale sets are available.
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error)
// TODO: deprecate as soon as scale sets are available.
Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error)
}
// TODO: deprecate as soon as scale sets are available.
type interfacesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.InterfacesClientCreateOrUpdateResponse, error)
}
type networkInterfacesAPI interface {
GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
// TODO: deprecate as soon as scale sets are available
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
interfacesClientCreateOrUpdatePollerResponse, error)
}
type resourceGroupsDeletePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armresources.ResourceGroupsClientDeleteResponse, error)
}
type resourceGroupAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error)
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
}
type applicationsAPI interface {
Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error)
Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error)
UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error)
}
type servicePrincipalsAPI interface {
Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error)
}
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
// TODO: switch to "armauthorization.RoleAssignmentsClient" if issue is resolved.
type roleAssignmentsAPI interface {
Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachinesClientCreateOrUpdateResponse, error)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions) (virtualMachinesClientCreateOrUpdatePollerResponse, error)
}

View File

@ -0,0 +1,388 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type stubNetworksAPI struct {
createErr error
stubResponse stubVirtualNetworksCreateOrUpdatePollerResponse
}
type stubVirtualNetworksCreateOrUpdatePollerResponse struct {
armnetwork.VirtualNetworksClientCreateOrUpdatePollerResponse
pollerErr error
}
func (r stubVirtualNetworksCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error) {
return armnetwork.VirtualNetworksClientCreateOrUpdateResponse{
VirtualNetworksClientCreateOrUpdateResult: armnetwork.VirtualNetworksClientCreateOrUpdateResult{
VirtualNetwork: armnetwork.VirtualNetwork{
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
Subnets: []*armnetwork.Subnet{
{
ID: to.StringPtr("virtual-network-subnet-id"),
},
},
},
},
},
}, r.pollerErr
}
func (a stubNetworksAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
type stubNetworkSecurityGroupsCreateOrUpdatePollerResponse struct {
armnetwork.SecurityGroupsClientCreateOrUpdatePollerResponse
pollerErr error
}
func (r stubNetworkSecurityGroupsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error) {
return armnetwork.SecurityGroupsClientCreateOrUpdateResponse{
SecurityGroupsClientCreateOrUpdateResult: armnetwork.SecurityGroupsClientCreateOrUpdateResult{
SecurityGroup: armnetwork.SecurityGroup{
ID: to.StringPtr("network-security-group-id"),
},
},
}, r.pollerErr
}
type stubNetworkSecurityGroupsAPI struct {
createErr error
stubPoller stubNetworkSecurityGroupsCreateOrUpdatePollerResponse
}
func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
) {
return a.stubPoller, a.createErr
}
type stubResourceGroupAPI struct {
terminateErr error
createErr error
getErr error
getResourceGroup armresources.ResourceGroup
stubResponse stubResourceGroupsDeletePollerResponse
}
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
) {
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
}
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
return armresources.ResourceGroupsClientGetResponse{
ResourceGroupsClientGetResult: armresources.ResourceGroupsClientGetResult{
ResourceGroup: a.getResourceGroup,
},
}, a.getErr
}
type stubResourceGroupsDeletePollerResponse struct {
armresources.ResourceGroupsClientDeletePollerResponse
pollerErr error
}
func (r stubResourceGroupsDeletePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armresources.ResourceGroupsClientDeleteResponse, error,
) {
return armresources.ResourceGroupsClientDeleteResponse{}, r.pollerErr
}
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error,
) {
return a.stubResponse, a.terminateErr
}
type stubScaleSetsAPI struct {
createErr error
stubResponse stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse
}
type stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse struct {
pollResponse armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse
pollErr error
}
func (r stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error,
) {
return r.pollResponse, r.pollErr
}
func (a stubScaleSetsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
type stubPublicIPAddressesAPI struct {
// TODO: deprecate as soon as scale sets are available.
createErr error
// TODO: deprecate as soon as scale sets are available.
getErr error
// TODO: deprecate as soon as scale sets are available.
stubCreateResponse stubPublicIPAddressesClientCreateOrUpdatePollerResponse
}
// TODO: deprecate as soon as scale sets are available.
type stubPublicIPAddressesClientCreateOrUpdatePollerResponse struct {
armnetwork.PublicIPAddressesClientCreateOrUpdatePollerResponse
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubPublicIPAddressesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error,
) {
return armnetwork.PublicIPAddressesClientCreateOrUpdateResponse{
PublicIPAddressesClientCreateOrUpdateResult: armnetwork.PublicIPAddressesClientCreateOrUpdateResult{
PublicIPAddress: armnetwork.PublicIPAddress{
ID: to.StringPtr("pubIP-id"),
},
},
}, r.pollErr
}
type stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager struct {
pagesCounter int
PagesMax int
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) NextPage(ctx context.Context) bool {
p.pagesCounter++
return p.pagesCounter <= p.PagesMax
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse {
return armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse{
PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult: armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult{
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
Value: []*armnetwork.PublicIPAddress{
{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
}
}
func (a stubPublicIPAddressesAPI) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
return &stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager{pagesCounter: 0, PagesMax: 1}
}
// TODO: deprecate as soon as scale sets are available.
func (a stubPublicIPAddressesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
) {
return a.stubCreateResponse, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubPublicIPAddressesAPI) Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error,
) {
return armnetwork.PublicIPAddressesClientGetResponse{
PublicIPAddressesClientGetResult: armnetwork.PublicIPAddressesClientGetResult{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
}, a.getErr
}
type stubNetworkInterfacesAPI struct {
getErr error
// TODO: deprecate as soon as scale sets are available
createErr error
// TODO: deprecate as soon as scale sets are available
stubResp stubInterfacesClientCreateOrUpdatePollerResponse
}
func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) {
if a.getErr != nil {
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{}, a.getErr
}
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{
InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult: armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult{
Interface: armnetwork.Interface{
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
}, nil
}
// TODO: deprecate as soon as scale sets are available.
type stubInterfacesClientCreateOrUpdatePollerResponse struct {
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubInterfacesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armnetwork.InterfacesClientCreateOrUpdateResponse, error,
) {
return armnetwork.InterfacesClientCreateOrUpdateResponse{
InterfacesClientCreateOrUpdateResult: armnetwork.InterfacesClientCreateOrUpdateResult{
Interface: armnetwork.Interface{
ID: to.StringPtr("interface-id"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
}, r.pollErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubNetworkInterfacesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
interfacesClientCreateOrUpdatePollerResponse, error,
) {
return a.stubResp, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
type stubVirtualMachinesAPI struct {
stubResponse stubVirtualMachinesClientCreateOrUpdatePollerResponse
createErr error
}
// TODO: deprecate as soon as scale sets are available.
type stubVirtualMachinesClientCreateOrUpdatePollerResponse struct {
pollResponse armcompute.VirtualMachinesClientCreateOrUpdateResponse
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubVirtualMachinesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armcompute.VirtualMachinesClientCreateOrUpdateResponse, error,
) {
return r.pollResponse, r.pollErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubVirtualMachinesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
return a.stubResponse, a.createErr
}
type stubApplicationsAPI struct {
createErr error
deleteErr error
updateCredentialsErr error
createApplication *graphrbac.Application
}
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
if a.createErr != nil {
return graphrbac.Application{}, a.createErr
}
if a.createApplication != nil {
return *a.createApplication, nil
}
return graphrbac.Application{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000001"),
}, nil
}
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
if a.deleteErr != nil {
return autorest.Response{}, a.deleteErr
}
return autorest.Response{}, nil
}
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
if a.updateCredentialsErr != nil {
return autorest.Response{}, a.updateCredentialsErr
}
return autorest.Response{}, nil
}
type stubServicePrincipalsAPI struct {
createErr error
createServicePrincipal *graphrbac.ServicePrincipal
}
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
if a.createErr != nil {
return graphrbac.ServicePrincipal{}, a.createErr
}
if a.createServicePrincipal != nil {
return *a.createServicePrincipal, nil
}
return graphrbac.ServicePrincipal{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000002"),
}, nil
}
type stubRoleAssignmentsAPI struct {
createCounter int
createErrors []error
}
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
a.createCounter += 1
if len(a.createErrors) == 0 {
return authorization.RoleAssignment{}, nil
}
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
}

View File

@ -0,0 +1,136 @@
package client
import (
"context"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type networksClient struct {
*armnetwork.VirtualNetworksClient
}
func (c *networksClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error,
) {
return c.VirtualNetworksClient.BeginCreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, parameters, options)
}
// TODO: deprecate as soon as scale sets are available.
type networkInterfacesClient struct {
*armnetwork.InterfacesClient
}
// TODO: deprecate as soon as scale sets are available.
func (c *networkInterfacesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions,
) (interfacesClientCreateOrUpdatePollerResponse, error) {
return c.InterfacesClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkInterfaceName, parameters, options)
}
type networkSecurityGroupsClient struct {
*armnetwork.SecurityGroupsClient
}
func (c *networkSecurityGroupsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
) {
return c.SecurityGroupsClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkSecurityGroupName, parameters, options)
}
type publicIPAddressesClient struct {
*armnetwork.PublicIPAddressesClient
}
func (c *publicIPAddressesClient) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
return c.PublicIPAddressesClient.ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName, virtualMachineScaleSetName,
virtualmachineIndex, networkInterfaceName, ipConfigurationName, options)
}
// TODO: deprecate as soon as scale sets are available.
func (c *publicIPAddressesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
) {
return c.PublicIPAddressesClient.BeginCreateOrUpdate(ctx, resourceGroupName, publicIPAddressName, parameters, options)
}
type virtualMachineScaleSetsClient struct {
*armcompute.VirtualMachineScaleSetsClient
}
func (c *virtualMachineScaleSetsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
) {
return c.VirtualMachineScaleSetsClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmScaleSetName, parameters, options)
}
type resourceGroupsClient struct {
*armresources.ResourceGroupsClient
}
func (c *resourceGroupsClient) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error,
) {
return c.ResourceGroupsClient.BeginDelete(ctx, resourceGroupName, options)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesClient struct {
*armcompute.VirtualMachinesClient
}
// TODO: deprecate as soon as scale sets are available.
func (c *virtualMachinesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
return c.VirtualMachinesClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmName, parameters, options)
}
type applicationsClient struct {
*graphrbac.ApplicationsClient
}
func (c *applicationsClient) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
return c.ApplicationsClient.Create(ctx, parameters)
}
func (c *applicationsClient) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
return c.ApplicationsClient.Delete(ctx, applicationObjectID)
}
func (c *applicationsClient) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
return c.ApplicationsClient.UpdatePasswordCredentials(ctx, objectID, parameters)
}
type servicePrincipalsClient struct {
*graphrbac.ServicePrincipalsClient
}
func (c *servicePrincipalsClient) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
return c.ServicePrincipalsClient.Create(ctx, parameters)
}
type roleAssignmentsClient struct {
*authorization.RoleAssignmentsClient
}
func (c *roleAssignmentsClient) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
return c.RoleAssignmentsClient.Create(ctx, scope, roleAssignmentName, parameters)
}

277
cli/azure/client/client.go Normal file
View File

@ -0,0 +1,277 @@
package client
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
)
const (
graphAPIResource = "https://graph.windows.net"
managementAPIResource = "https://management.azure.com"
)
// Client is a client for Azure.
type Client struct {
networksAPI
networkSecurityGroupsAPI
resourceGroupAPI
scaleSetsAPI
publicIPAddressesAPI
networkInterfacesAPI
virtualMachinesAPI
applicationsAPI
servicePrincipalsAPI
roleAssignmentsAPI
adReplicationLagCheckInterval time.Duration
adReplicationLagCheckMaxRetries int
nodes azure.Instances
coordinators azure.Instances
name string
uid string
resourceGroup string
location string
subscriptionID string
tenantID string
subnetID string
coordinatorsScaleSet string
nodesScaleSet string
networkSecurityGroup string
adAppObjectID string
}
// NewFromDefault creates a client with initialized clients.
func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, err
}
graphAuthorizer, err := getAuthorizer(graphAPIResource)
if err != nil {
return nil, err
}
managementAuthorizer, err := getAuthorizer(managementAPIResource)
if err != nil {
return nil, err
}
netAPI := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
netSecGrpAPI := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
resGroupAPI := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
scaleSetAPI := armcompute.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
publicIPAddressesAPI := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
networkInterfacesAPI := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
virtualMachinesAPI := armcompute.NewVirtualMachinesClient(subscriptionID, cred, nil)
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
servicePrincipalsAPI.Authorizer = graphAuthorizer
roleAssignmentsAPI := authorization.NewRoleAssignmentsClient(subscriptionID)
roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{
networksAPI: &networksClient{netAPI},
networkSecurityGroupsAPI: &networkSecurityGroupsClient{netSecGrpAPI},
resourceGroupAPI: &resourceGroupsClient{resGroupAPI},
scaleSetsAPI: &virtualMachineScaleSetsClient{scaleSetAPI},
publicIPAddressesAPI: &publicIPAddressesClient{publicIPAddressesAPI},
networkInterfacesAPI: &networkInterfacesClient{networkInterfacesAPI},
applicationsAPI: &applicationsClient{&applicationsAPI},
servicePrincipalsAPI: &servicePrincipalsClient{&servicePrincipalsAPI},
roleAssignmentsAPI: &roleAssignmentsClient{&roleAssignmentsAPI},
virtualMachinesAPI: &virtualMachinesClient{virtualMachinesAPI},
subscriptionID: subscriptionID,
tenantID: tenantID,
nodes: azure.Instances{},
coordinators: azure.Instances{},
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
}, nil
}
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil {
return nil, err
}
err = client.init(location, name)
return client, err
}
// init initializes the client.
func (c *Client) init(location, name string) error {
c.location = location
c.name = name
uid, err := c.generateUID()
if err != nil {
return err
}
c.uid = uid
return nil
}
// GetState returns the state of the client as ConstellationState.
func (c *Client) GetState() (state.ConstellationState, error) {
var stat state.ConstellationState
stat.CloudProvider = cloudprovider.Azure.String()
if len(c.resourceGroup) == 0 {
return state.ConstellationState{}, errors.New("client has no resource group")
}
stat.AzureResourceGroup = c.resourceGroup
if c.name == "" {
return state.ConstellationState{}, errors.New("client has no name")
}
stat.Name = c.name
if len(c.uid) == 0 {
return state.ConstellationState{}, errors.New("client has no uid")
}
stat.UID = c.uid
if len(c.location) == 0 {
return state.ConstellationState{}, errors.New("client has no location")
}
stat.AzureLocation = c.location
if len(c.subscriptionID) == 0 {
return state.ConstellationState{}, errors.New("client has no subscription")
}
stat.AzureSubscription = c.subscriptionID
if len(c.tenantID) == 0 {
return state.ConstellationState{}, errors.New("client has no tenant")
}
stat.AzureTenant = c.tenantID
if len(c.subnetID) == 0 {
return state.ConstellationState{}, errors.New("client has no subnet")
}
stat.AzureSubnet = c.subnetID
if len(c.networkSecurityGroup) == 0 {
return state.ConstellationState{}, errors.New("client has no network security group")
}
stat.AzureNetworkSecurityGroup = c.networkSecurityGroup
// TODO: un-deprecate as soon as scale sets are available
// if len(c.nodesScaleSet) == 0 {
// return state.ConstellationState{}, errors.New("client has no nodes scale set")
// }
// stat.AzureNodesScaleSet = c.nodesScaleSet
// if len(c.coordinatorsScaleSet) == 0 {
// return state.ConstellationState{}, errors.New("client has no coordinators scale set")
// }
// stat.AzureCoordinatorsScaleSet = c.coordinatorsScaleSet
if len(c.nodes) == 0 {
return state.ConstellationState{}, errors.New("client has no nodes")
}
stat.AzureNodes = c.nodes
if len(c.coordinators) == 0 {
return state.ConstellationState{}, errors.New("client has no coordinators")
}
stat.AzureCoordinators = c.coordinators
// AD App Object ID does not have to be set at all times
stat.AzureADAppObjectID = c.adAppObjectID
return stat, nil
}
// SetState sets the state of the client to the handed ConstellationState.
func (c *Client) SetState(stat state.ConstellationState) error {
if stat.CloudProvider != cloudprovider.Azure.String() {
return errors.New("state is not azure state")
}
if len(stat.AzureResourceGroup) == 0 {
return errors.New("state has no resource group")
}
c.resourceGroup = stat.AzureResourceGroup
if stat.Name == "" {
return errors.New("state has no name")
}
c.name = stat.Name
if len(stat.UID) == 0 {
return errors.New("state has no uuid")
}
c.uid = stat.UID
if len(stat.AzureLocation) == 0 {
return errors.New("state has no location")
}
c.location = stat.AzureLocation
if len(stat.AzureSubscription) == 0 {
return errors.New("state has no subscription")
}
c.subscriptionID = stat.AzureSubscription
if len(stat.AzureTenant) == 0 {
return errors.New("state has no tenant")
}
c.tenantID = stat.AzureTenant
if len(stat.AzureSubnet) == 0 {
return errors.New("state has no subnet")
}
c.subnetID = stat.AzureSubnet
if len(stat.AzureNetworkSecurityGroup) == 0 {
return errors.New("state has no subnet")
}
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
// TODO: un-deprecate as soon as scale sets are available
//if len(stat.AzureNodesScaleSet) == 0 {
// return errors.New("state has no nodes scale set")
//}
//c.nodesScaleSet = stat.AzureNodesScaleSet
//if len(stat.AzureCoordinatorsScaleSet) == 0 {
// return errors.New("state has no nodes scale set")
//}
//c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
if len(stat.AzureNodes) == 0 {
return errors.New("state has no coordinator scale set")
}
c.nodes = stat.AzureNodes
if len(stat.AzureCoordinators) == 0 {
return errors.New("state has no coordinators")
}
c.coordinators = stat.AzureCoordinators
// AD App Object ID does not have to be set at all times
c.adAppObjectID = stat.AzureADAppObjectID
return nil
}
func (c *Client) generateUID() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
const uidLen = 5
uid := make([]byte, uidLen)
for i := 0; i < uidLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
uid[i] = letters[n.Int64()]
}
return string(uid), nil
}
// getAuthorizer creates an autorest.Authorizer for different Azure AD APIs using either environment variables or azure cli credentials.
func getAuthorizer(resource string) (autorest.Authorizer, error) {
authorizer, cliErr := auth.NewAuthorizerFromCLIWithResource(resource)
if cliErr == nil {
return authorizer, nil
}
authorizer, envErr := auth.NewAuthorizerFromEnvironmentWithResource(resource)
if envErr == nil {
return authorizer, nil
}
return nil, fmt.Errorf("unable to create authorizer from env or cli: %v %v", envErr, cliErr)
}

View File

@ -0,0 +1,486 @@
package client
import (
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetGetState(t *testing.T) {
testCases := map[string]struct {
state state.ConstellationState
errExpected bool
}{
"valid state": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
// TODO: un-deprecate as soon as scale sets are available
// AzureNodesScaleSet: "node-scale-set",
// AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
},
"missing nodes": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing coordinator": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing name": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing uid": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing resource group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing location": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing subscription": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureTenant: "tenant",
AzureLocation: "location",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing tenant": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureSubscription: "subscription",
AzureLocation: "location",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing subnet": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing network security group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
// TODO: un-deprecate as soon as scale sets are available
// "missing node scale set": {
// state: state.ConstellationState{
// CloudProvider: cloudprovider.Azure.String(),
// AzureNodes: azure.Instances{
// "0": {
// PublicIP: "ip1",
// PrivateIP: "ip2",
// },
// },
// AzureCoordinators: azure.Instances{
// "0": {
// PublicIP: "ip3",
// PrivateIP: "ip4",
// },
// },
// Name: "name",
// UID: "uid",
// AzureResourceGroup: "resource-group",
// AzureLocation: "location",
// AzureSubscription: "subscription",
// AzureTenant: "tenant",
// AzureSubnet: "azure-subnet",
// AzureNetworkSecurityGroup: "network-security-group",
// AzureCoordinatorsScaleSet: "coordinator-scale-set",
// },
// errExpected: true,
// },
// "missing coordinator scale set": {
// state: state.ConstellationState{
// CloudProvider: cloudprovider.Azure.String(),
// AzureNodes: azure.Instances{
// "0": {
// PublicIP: "ip1",
// PrivateIP: "ip2",
// },
// },
// AzureCoordinators: azure.Instances{
// "0": {
// PublicIP: "ip3",
// PrivateIP: "ip4",
// },
// },
// Name: "name",
// UID: "uid",
// AzureResourceGroup: "resource-group",
// AzureLocation: "location",
// AzureSubscription: "subscription",
// AzureTenant: "tenant",
// AzureSubnet: "azure-subnet",
// AzureNetworkSecurityGroup: "network-security-group",
// AzureNodesScaleSet: "node-scale-set",
// },
// errExpected: true,
// },
}
t.Run("SetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{}
if tc.errExpected {
assert.Error(client.SetState(tc.state))
} else {
assert.NoError(client.SetState(tc.state))
assert.Equal(tc.state.AzureNodes, client.nodes)
assert.Equal(tc.state.AzureCoordinators, client.coordinators)
assert.Equal(tc.state.Name, client.name)
assert.Equal(tc.state.UID, client.uid)
assert.Equal(tc.state.AzureResourceGroup, client.resourceGroup)
assert.Equal(tc.state.AzureLocation, client.location)
assert.Equal(tc.state.AzureSubscription, client.subscriptionID)
assert.Equal(tc.state.AzureTenant, client.tenantID)
assert.Equal(tc.state.AzureSubnet, client.subnetID)
assert.Equal(tc.state.AzureNetworkSecurityGroup, client.networkSecurityGroup)
assert.Equal(tc.state.AzureNodesScaleSet, client.nodesScaleSet)
assert.Equal(tc.state.AzureCoordinatorsScaleSet, client.coordinatorsScaleSet)
}
})
}
})
t.Run("GetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{
nodes: tc.state.AzureNodes,
coordinators: tc.state.AzureCoordinators,
name: tc.state.Name,
uid: tc.state.UID,
resourceGroup: tc.state.AzureResourceGroup,
location: tc.state.AzureLocation,
subscriptionID: tc.state.AzureSubscription,
tenantID: tc.state.AzureTenant,
subnetID: tc.state.AzureSubnet,
networkSecurityGroup: tc.state.AzureNetworkSecurityGroup,
nodesScaleSet: tc.state.AzureNodesScaleSet,
coordinatorsScaleSet: tc.state.AzureCoordinatorsScaleSet,
}
if tc.errExpected {
_, err := client.GetState()
assert.Error(err)
} else {
state, err := client.GetState()
assert.NoError(err)
assert.Equal(tc.state, state)
}
})
}
})
}
func TestSetStateCloudProvider(t *testing.T) {
assert := assert.New(t)
client := Client{}
stateMissingCloudProvider := state.ConstellationState{
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
}
assert.Error(client.SetState(stateMissingCloudProvider))
stateIncorrectCloudProvider := state.ConstellationState{
CloudProvider: "incorrect",
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
}
assert.Error(client.SetState(stateIncorrectCloudProvider))
}
func TestInit(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{}
require.NoError(client.init("location", "name"))
assert.Equal("location", client.location)
assert.Equal("name", client.name)
assert.NotEmpty(client.uid)
}

281
cli/azure/client/compute.go Normal file
View File

@ -0,0 +1,281 @@
package client
import (
"context"
"errors"
"strconv"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/azure"
)
func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error {
// Create nodes scale set
createNodesInput := CreateScaleSetInput{
Name: "constellation-scale-set-nodes-" + c.uid,
NamePrefix: c.name + "-worker-" + c.uid + "-",
Count: input.Count - 1,
InstanceType: input.InstanceType,
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
}
if err := c.createScaleSet(ctx, createNodesInput); err != nil {
return err
}
c.nodesScaleSet = createNodesInput.Name
// Create coordinator scale set
createCoordinatorsInput := CreateScaleSetInput{
Name: "constellation-scale-set-coordinators-" + c.uid,
NamePrefix: c.name + "-control-plane-" + c.uid + "-",
Count: 1,
InstanceType: input.InstanceType,
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
}
if err := c.createScaleSet(ctx, createCoordinatorsInput); err != nil {
return err
}
// Get nodes IPs
instances, err := c.getInstanceIPs(ctx, createNodesInput.Name, createNodesInput.Count)
if err != nil {
return err
}
c.nodes = instances
// Get coordinators IPs
c.coordinatorsScaleSet = createCoordinatorsInput.Name
instances, err = c.getInstanceIPs(ctx, createCoordinatorsInput.Name, createCoordinatorsInput.Count)
if err != nil {
return err
}
c.coordinators = instances
return nil
}
// CreateInstancesInput is the input for a CreateInstances operation.
type CreateInstancesInput struct {
Count int
InstanceType string
Image string
UserAssingedIdentity string
}
// CreateInstancesVMs creates instances based on standalone VMs.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) CreateInstancesVMs(ctx context.Context, input CreateInstancesInput) error {
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
vm := azure.VMInstance{
Name: c.name + "-control-plane-" + c.uid,
Username: "constell",
Password: pw,
Location: c.location,
InstanceType: input.InstanceType,
Image: input.Image,
}
instance, err := c.createInstanceVM(ctx, vm)
if err != nil {
return err
}
c.coordinators = azure.Instances{"0": instance}
for i := 0; i < input.Count-1; i++ {
vm := azure.VMInstance{
Name: c.name + "-node-" + strconv.Itoa(i) + c.uid,
Username: "constell",
Password: pw,
Location: c.location,
InstanceType: input.InstanceType,
Image: input.Image,
}
instance, err := c.createInstanceVM(ctx, vm)
if err != nil {
return err
}
c.nodes[strconv.Itoa(i)] = instance
}
return nil
}
// createInstanceVM creates a single VM with a public IP address
// and a network interface.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createInstanceVM(ctx context.Context, input azure.VMInstance) (azure.Instance, error) {
pubIPName := input.Name + "-pubIP"
pubIPID, err := c.createPublicIPAddress(ctx, pubIPName)
if err != nil {
return azure.Instance{}, err
}
nicName := input.Name + "-NIC"
privIP, nicID, err := c.createNIC(ctx, nicName, pubIPID)
if err != nil {
return azure.Instance{}, err
}
input.NIC = nicID
poller, err := c.virtualMachinesAPI.BeginCreateOrUpdate(ctx, c.resourceGroup, input.Name, input.Azure(), nil)
if err != nil {
return azure.Instance{}, err
}
vm, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return azure.Instance{}, err
}
if vm.Identity == nil || vm.Identity.PrincipalID == nil {
return azure.Instance{}, errors.New("virtual machine was created without system managed identity")
}
if err := c.assignResourceGroupRole(ctx, *vm.Identity.PrincipalID, virtualMachineContributorRoleDefinitionID); err != nil {
return azure.Instance{}, err
}
res, err := c.publicIPAddressesAPI.Get(ctx, c.resourceGroup, pubIPName, nil)
if err != nil {
return azure.Instance{}, err
}
return azure.Instance{PublicIP: *res.PublicIPAddressesClientGetResult.PublicIPAddress.Properties.IPAddress, PrivateIP: privIP}, nil
}
func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) error {
// TODO: Generating a random password to be able
// to create the scale set. This is a temporary fix.
// We need to think about azure access at some point.
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
scaleSet := azure.ScaleSet{
Name: input.Name,
NamePrefix: input.NamePrefix,
Location: c.location,
InstanceType: input.InstanceType,
Count: int64(input.Count),
Username: "constellation",
SubnetID: c.subnetID,
NetworkSecurityGroup: c.networkSecurityGroup,
Image: input.Image,
Password: pw,
UserAssignedIdentity: input.UserAssingedIdentity,
}.Azure()
poller, err := c.scaleSetsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, input.Name,
scaleSet,
nil,
)
if err != nil {
return err
}
_, err = poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
return nil
}
func (c *Client) getInstanceIPs(ctx context.Context, scaleSet string, count int) (azure.Instances, error) {
instances := azure.Instances{}
for i := 0; i < count; i++ {
// get public ip address
var publicIPAddress string
pager := c.publicIPAddressesAPI.ListVirtualMachineScaleSetVMPublicIPAddresses(
c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, scaleSet, nil)
// We always need one pager.NextPage, since calling
// pager.PageResponse() directly return no result.
// We expect to get one page with one entry for each VM.
for pager.NextPage(ctx) {
for _, v := range pager.PageResponse().Value {
if v.Properties != nil && v.Properties.IPAddress != nil {
publicIPAddress = *v.Properties.IPAddress
break
}
}
}
// get private ip address
var privateIPAddress string
res, err := c.networkInterfacesAPI.GetVirtualMachineScaleSetNetworkInterface(
ctx, c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, nil)
if err != nil {
return nil, err
}
configs := res.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult.Interface.Properties.IPConfigurations
for _, config := range configs {
privateIPAddress = *config.Properties.PrivateIPAddress
break
}
instance := azure.Instance{
PrivateIP: privateIPAddress,
PublicIP: publicIPAddress,
}
instances[strconv.Itoa(i)] = instance
}
return instances, nil
}
// CreateScaleSetInput is the input for a CreateScaleSet operation.
type CreateScaleSetInput struct {
Name string
NamePrefix string
Count int
InstanceType string
Image string
UserAssingedIdentity string
}
// CreateResourceGroup creates a resource group.
func (c *Client) CreateResourceGroup(ctx context.Context) error {
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
armresources.ResourceGroup{
Location: &c.location,
}, nil)
if err != nil {
return err
}
c.resourceGroup = c.name + "-" + c.uid
return nil
}
// TerminateResourceGroup terminates a resource group.
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
if err != nil {
return err
}
if _, err = poller.PollUntilDone(ctx, 30*time.Second); err != nil {
return err
}
c.nodes = nil
c.coordinators = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.nodesScaleSet = ""
c.coordinatorsScaleSet = ""
return nil
}

View File

@ -0,0 +1,374 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
resourceGroupAPI resourceGroupAPI
errExpected bool
}{
"successful create": {
resourceGroupAPI: stubResourceGroupAPI{},
},
"failed create": {
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroupAPI: tc.resourceGroupAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateResourceGroup(ctx))
} else {
assert.NoError(client.CreateResourceGroup(ctx))
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
}
})
}
}
func TestTerminateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
clientWithResourceGroup := Client{
resourceGroup: "name",
location: "location",
name: "name",
uid: "uid",
subnetID: "subnet",
nodesScaleSet: "node-scale-set",
coordinatorsScaleSet: "coordinator-scale-set",
nodes: azure.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
coordinators: azure.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
}
testCases := map[string]struct {
resourceGroup string
resourceGroupAPI resourceGroupAPI
client Client
errExpected bool
}{
"successful terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: clientWithResourceGroup,
},
"no resource group to terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: Client{},
resourceGroup: "",
},
"failed terminate": {
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
client: clientWithResourceGroup,
errExpected: true,
},
"failed to poll terminate response": {
resourceGroupAPI: stubResourceGroupAPI{stubResponse: stubResourceGroupsDeletePollerResponse{pollerErr: someErr}},
client: clientWithResourceGroup,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.client.resourceGroupAPI = tc.resourceGroupAPI
ctx := context.Background()
if tc.errExpected {
assert.Error(tc.client.TerminateResourceGroup(ctx))
return
}
assert.NoError(tc.client.TerminateResourceGroup(ctx))
assert.Empty(tc.client.resourceGroup)
assert.Empty(tc.client.subnetID)
assert.Empty(tc.client.nodes)
assert.Empty(tc.client.coordinators)
assert.Empty(tc.client.nodesScaleSet)
assert.Empty(tc.client.coordinatorsScaleSet)
})
}
}
func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
errExpected bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{
stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{
pollResponse: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse{
VirtualMachineScaleSetsClientCreateOrUpdateResult: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResult{
VirtualMachineScaleSet: armcompute.VirtualMachineScaleSet{Identity: &armcompute.VirtualMachineScaleSetIdentity{PrincipalID: to.StringPtr("principal-id")}},
},
},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
errExpected: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{pollErr: someErr}},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
errExpected: true,
},
"error when retrieving private IPs": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateInstances(ctx, tc.createInstancesInput))
} else {
assert.NoError(client.CreateInstances(ctx, tc.createInstancesInput))
assert.Equal(1, len(client.coordinators))
assert.Equal(tc.createInstancesInput.Count-1, len(client.nodes))
assert.NotEmpty(client.nodes["0"].PrivateIP)
assert.NotEmpty(client.nodes["0"].PublicIP)
assert.NotEmpty(client.coordinators["0"].PrivateIP)
assert.NotEmpty(client.coordinators["0"].PublicIP)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreateInstancesVMs(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
virtualMachinesAPI virtualMachinesAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
errExpected bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{
stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{
pollResponse: armcompute.VirtualMachinesClientCreateOrUpdateResponse{VirtualMachinesClientCreateOrUpdateResult: armcompute.VirtualMachinesClientCreateOrUpdateResult{
VirtualMachine: armcompute.VirtualMachine{
Identity: &armcompute.VirtualMachineIdentity{PrincipalID: to.StringPtr("principal-id")},
},
}},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when creating NIC": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when creating public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when retrieving public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{getErr: someErr},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachinesAPI: tc.virtualMachinesAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
return
}
require.NoError(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
assert.Equal(1, len(client.coordinators))
assert.Equal(tc.createInstancesInput.Count-1, len(client.nodes))
assert.NotEmpty(client.nodes["0"].PrivateIP)
assert.NotEmpty(client.nodes["0"].PublicIP)
assert.NotEmpty(client.coordinators["0"].PrivateIP)
assert.NotEmpty(client.coordinators["0"].PublicIP)
})
}
}
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
return &stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
}
}

166
cli/azure/client/network.go Normal file
View File

@ -0,0 +1,166 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
)
type createNetworkInput struct {
name string
location string
addressSpace string
}
// CreateVirtualNetwork creates a virtual network.
func (c *Client) CreateVirtualNetwork(ctx context.Context) error {
createNetworkInput := createNetworkInput{
name: "constellation-" + c.uid,
location: c.location,
addressSpace: "172.20.0.0/16",
}
poller, err := c.networksAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkInput.name,
armnetwork.VirtualNetwork{
Name: to.StringPtr(createNetworkInput.name), // this is supposed to be read-only
Location: to.StringPtr(createNetworkInput.location),
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
AddressSpace: &armnetwork.AddressSpace{
AddressPrefixes: []*string{
to.StringPtr(createNetworkInput.addressSpace),
},
},
Subnets: []*armnetwork.Subnet{
{
Name: to.StringPtr("default"),
Properties: &armnetwork.SubnetPropertiesFormat{
AddressPrefix: to.StringPtr(createNetworkInput.addressSpace),
},
},
},
},
},
nil,
)
if err != nil {
return err
}
resp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.subnetID = *resp.VirtualNetworksClientCreateOrUpdateResult.VirtualNetwork.Properties.Subnets[0].ID
return nil
}
type createNetworkSecurityGroupInput struct {
name string
location string
rules []*armnetwork.SecurityRule
}
// CreateSecurityGroup creates a security group containing firewall rules.
func (c *Client) CreateSecurityGroup(ctx context.Context, input NetworkSecurityGroupInput) error {
rules := input.Ingress.Azure()
createNetworkSecurityGroupInput := createNetworkSecurityGroupInput{
name: "constellation-security-group-" + c.uid,
location: c.location,
rules: rules,
}
poller, err := c.networkSecurityGroupsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkSecurityGroupInput.name,
armnetwork.SecurityGroup{
Name: to.StringPtr(createNetworkSecurityGroupInput.name),
Location: to.StringPtr(createNetworkSecurityGroupInput.location),
Properties: &armnetwork.SecurityGroupPropertiesFormat{
SecurityRules: createNetworkSecurityGroupInput.rules,
},
},
nil,
)
if err != nil {
return err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.networkSecurityGroup = *pollerResp.SecurityGroupsClientCreateOrUpdateResult.SecurityGroup.ID
return nil
}
// createNIC creates a network interface that references a public IP address.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createNIC(ctx context.Context, name, publicIPAddressID string) (ip string, id string, err error) {
poller, err := c.networkInterfacesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.Interface{
Location: to.StringPtr(c.location),
Properties: &armnetwork.InterfacePropertiesFormat{
NetworkSecurityGroup: &armnetwork.SecurityGroup{
ID: to.StringPtr(c.networkSecurityGroup),
},
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Name: to.StringPtr(name),
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
Subnet: &armnetwork.Subnet{
ID: to.StringPtr(c.subnetID),
},
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.StringPtr(publicIPAddressID),
},
},
},
},
},
},
nil,
)
if err != nil {
return "", "", err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return "", "", err
}
netInterface := pollerResp.InterfacesClientCreateOrUpdateResult.Interface
return *netInterface.Properties.IPConfigurations[0].Properties.PrivateIPAddress,
*netInterface.ID,
nil
}
// createPublicIPAddress creates a public IP address.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createPublicIPAddress(ctx context.Context, name string) (string, error) {
poller, err := c.publicIPAddressesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.PublicIPAddress{
Location: to.StringPtr(c.location),
},
nil,
)
if err != nil {
return "", err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return "", err
}
return *pollerResp.PublicIPAddressesClientCreateOrUpdateResult.PublicIPAddress.ID, nil
}
// NetworkSecurityGroupInput defines firewall rules to be set.
type NetworkSecurityGroupInput struct {
Ingress cloudtypes.Firewall
Egress cloudtypes.Firewall
}

View File

@ -0,0 +1,220 @@
package client
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateVirtualNetwork(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networksAPI networksAPI
errExpected bool
}{
"successful create": {
networksAPI: stubNetworksAPI{},
},
"failed to get response from successful create": {
networksAPI: stubNetworksAPI{stubResponse: stubVirtualNetworksCreateOrUpdatePollerResponse{pollerErr: someErr}},
errExpected: true,
},
"failed create": {
networksAPI: stubNetworksAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
networksAPI: tc.networksAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateVirtualNetwork(ctx))
} else {
assert.NoError(client.CreateVirtualNetwork(ctx))
assert.NotEmpty(client.subnetID)
}
})
}
}
func TestCreateSecurityGroup(t *testing.T) {
someErr := errors.New("failed")
testNetworkSecurityGroupInput := NetworkSecurityGroupInput{
Ingress: cloudtypes.Firewall{
{
Name: "test-1",
Description: "test-1 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
Port: 9000,
},
{
Name: "test-2",
Description: "test-2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
Port: 51820,
},
},
Egress: cloudtypes.Firewall{},
}
testCases := map[string]struct {
networkSecurityGroupsAPI networkSecurityGroupsAPI
errExpected bool
}{
"successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{},
},
"failed to get response from successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{stubPoller: stubNetworkSecurityGroupsCreateOrUpdatePollerResponse{pollerErr: someErr}},
errExpected: true,
},
"failed create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
networkSecurityGroupsAPI: tc.networkSecurityGroupsAPI,
}
if tc.errExpected {
assert.Error(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
} else {
assert.NoError(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
assert.Equal("network-security-group-id", client.networkSecurityGroup)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreateNIC(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networkInterfacesAPI networkInterfacesAPI
name string
publicIPAddressID string
errExpected bool
}{
"successful create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{},
name: "nic-name",
publicIPAddressID: "pubIP-id",
},
"failed to get response from successful create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{stubResp: stubInterfacesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
errExpected: true,
},
"failed create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
networkInterfacesAPI: tc.networkInterfacesAPI,
}
ip, id, err := client.createNIC(ctx, tc.name, tc.publicIPAddressID)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(ip)
assert.NotEmpty(id)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreatePublicIPAddress(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
name string
errExpected bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
name: "nic-name",
},
"failed to get response from successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{stubCreateResponse: stubPublicIPAddressesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
errExpected: true,
},
"failed create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
id, err := client.createPublicIPAddress(ctx, tc.name)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(id)
}
})
}
}

135
cli/azure/instances.go Normal file
View File

@ -0,0 +1,135 @@
package azure
// copy of ec2/instances.go
// TODO(katexochen): refactor into mulitcloud package.
import (
"errors"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
)
// Instance is a azure instance.
type Instance struct {
PublicIP string
PrivateIP string
}
// Instances is a map of azure Instances. The ID of an instance is used as key.
type Instances map[string]Instance
// IDs returns the IDs of all instances of the Constellation.
func (i Instances) IDs() []string {
var ids []string
for id := range i {
ids = append(ids, id)
}
return ids
}
// PublicIPs returns the public IPs of all the instances of the Constellation.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}
// GetOne return anyone instance out of the instances and its ID.
func (i Instances) GetOne() (string, Instance, error) {
for id, instance := range i {
return id, instance, nil
}
return "", Instance{}, errors.New("map is empty")
}
// GetOthers returns all instances but the one with the handed ID.
func (i Instances) GetOthers(id string) Instances {
others := make(Instances)
for key, instance := range i {
if key != id {
others[key] = instance
}
}
return others
}
// TODO: deprecate as soon as scale sets are available.
type VMInstance struct {
Name string
Location string
InstanceType string
Username string
Password string
NIC string
Image string
}
// TODO: deprecate as soon as scale sets are available.
func (i VMInstance) Azure() armcompute.VirtualMachine {
return armcompute.VirtualMachine{
Name: to.StringPtr(i.Name),
Location: to.StringPtr(i.Location),
Properties: &armcompute.VirtualMachineProperties{
HardwareProfile: &armcompute.HardwareProfile{
VMSize: (*armcompute.VirtualMachineSizeTypes)(to.StringPtr(i.InstanceType)),
},
OSProfile: &armcompute.OSProfile{
ComputerName: to.StringPtr(i.Name),
AdminPassword: to.StringPtr(i.Password),
AdminUsername: to.StringPtr(i.Username),
},
SecurityProfile: &armcompute.SecurityProfile{
UefiSettings: &armcompute.UefiSettings{
SecureBootEnabled: to.BoolPtr(true),
VTpmEnabled: to.BoolPtr(true),
},
SecurityType: armcompute.SecurityTypesConfidentialVM.ToPtr(),
},
NetworkProfile: &armcompute.NetworkProfile{
NetworkInterfaces: []*armcompute.NetworkInterfaceReference{
{
ID: to.StringPtr(i.NIC),
},
},
},
StorageProfile: &armcompute.StorageProfile{
OSDisk: &armcompute.OSDisk{
CreateOption: armcompute.DiskCreateOptionTypesFromImage.ToPtr(),
ManagedDisk: &armcompute.ManagedDiskParameters{
StorageAccountType: armcompute.StorageAccountTypesPremiumLRS.ToPtr(),
SecurityProfile: &armcompute.VMDiskSecurityProfile{
SecurityEncryptionType: armcompute.SecurityEncryptionTypesVMGuestStateOnly.ToPtr(),
},
},
},
ImageReference: &armcompute.ImageReference{
Publisher: to.StringPtr("0001-com-ubuntu-confidential-vm-focal"),
Offer: to.StringPtr("canonical"),
SKU: to.StringPtr("20_04-lts-gen2"),
Version: to.StringPtr("latest"),
},
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.BoolPtr(true),
},
},
},
Identity: &armcompute.VirtualMachineIdentity{
Type: armcompute.ResourceIdentityTypeSystemAssigned.ToPtr(),
},
}
}

View File

@ -0,0 +1,71 @@
package azure
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIDs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
assert.ElementsMatch(expectedIDs, testState.IDs())
}
func TestPublicIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
}
func TestPrivateIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
}
func TestGetOne(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
id, instance, err := testState.GetOne()
assert.NoError(err)
assert.Contains(testState, id)
assert.Equal(testState[id], instance)
}
func TestGetOthers(t *testing.T) {
assert := assert.New(t)
testCases := testInstances().IDs()
for _, id := range testCases {
others := testInstances().GetOthers(id)
assert.NotContains(others, id)
expectedInstances := testInstances()
delete(expectedInstances, id)
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
}
}
func testInstances() Instances {
return Instances{
"id-9": {
PublicIP: "192.0.2.1",
PrivateIP: "192.0.2.2",
},
"id-10": {
PublicIP: "192.0.2.3",
PrivateIP: "192.0.2.4",
},
"id-11": {
PublicIP: "192.0.2.5",
PrivateIP: "192.0.2.6",
},
"id-12": {
PublicIP: "192.0.2.7",
PrivateIP: "192.0.2.8",
},
}
}

View File

@ -0,0 +1,13 @@
package azure
import "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
// InstanceTypes are valid Azure instance types.
// Normally, this would be string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
// but currently needed instances are not in SDK.
var InstanceTypes = []string{
string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
"Standard_DC2as_v5",
"Standard_DC4as_v5",
"Standard_DC8as_v5",
}

123
cli/azure/scaleset.go Normal file
View File

@ -0,0 +1,123 @@
package azure
import (
"crypto/rand"
"math/big"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
)
// ScaleSet defines a Azure scale set.
type ScaleSet struct {
Name string
NamePrefix string
Location string
InstanceType string
Count int64
Username string
SubnetID string
NetworkSecurityGroup string
Password string
Image string
UserAssignedIdentity string
}
// Azure returns the Azure representation of ScaleSet.
func (s ScaleSet) Azure() armcompute.VirtualMachineScaleSet {
return armcompute.VirtualMachineScaleSet{
Name: to.StringPtr(s.Name),
Location: to.StringPtr(s.Location),
SKU: &armcompute.SKU{
Name: to.StringPtr(s.InstanceType),
Capacity: to.Int64Ptr(s.Count),
},
Properties: &armcompute.VirtualMachineScaleSetProperties{
Overprovision: to.BoolPtr(false),
UpgradePolicy: &armcompute.UpgradePolicy{
Mode: armcompute.UpgradeModeManual.ToPtr(),
AutomaticOSUpgradePolicy: &armcompute.AutomaticOSUpgradePolicy{
EnableAutomaticOSUpgrade: to.BoolPtr(false),
DisableAutomaticRollback: to.BoolPtr(false),
},
},
VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{
OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{
ComputerNamePrefix: to.StringPtr(s.NamePrefix),
AdminUsername: to.StringPtr(s.Username),
AdminPassword: to.StringPtr(s.Password),
LinuxConfiguration: &armcompute.LinuxConfiguration{},
},
StorageProfile: &armcompute.VirtualMachineScaleSetStorageProfile{
ImageReference: &armcompute.ImageReference{
ID: to.StringPtr(s.Image),
},
},
NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{
NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{
{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{
Primary: to.BoolPtr(true),
EnableIPForwarding: to.BoolPtr(true),
IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{
{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{
Subnet: &armcompute.APIEntityReference{
ID: to.StringPtr(s.SubnetID),
},
PublicIPAddressConfiguration: &armcompute.VirtualMachineScaleSetPublicIPAddressConfiguration{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetPublicIPAddressConfigurationProperties{
IdleTimeoutInMinutes: to.Int32Ptr(15), // default per https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/virtual-machine-scale-sets-networking#creating-a-scale-set-with-public-ip-per-virtual-machine
},
},
},
},
},
NetworkSecurityGroup: &armcompute.SubResource{
ID: to.StringPtr(s.NetworkSecurityGroup),
},
},
},
},
},
SecurityProfile: &armcompute.SecurityProfile{
SecurityType: armcompute.SecurityTypesTrustedLaunch.ToPtr(),
UefiSettings: &armcompute.UefiSettings{VTpmEnabled: to.BoolPtr(true)},
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.BoolPtr(true),
},
},
},
},
Identity: &armcompute.VirtualMachineScaleSetIdentity{
Type: armcompute.ResourceIdentityTypeUserAssigned.ToPtr(),
UserAssignedIdentities: map[string]*armcompute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
s.UserAssignedIdentity: {},
},
},
}
}
// GeneratePassword is a helper function to generate a random password
// for Azure's scale set.
func GeneratePassword() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 16
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
// bypass password rules
pw = append(pw, []byte("Aa1!")...)
return string(pw), nil
}

111
cli/azure/scaleset_test.go Normal file
View File

@ -0,0 +1,111 @@
package azure
import (
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFirewallPermissions(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
scaleSet := ScaleSet{
Name: "name",
NamePrefix: "constellation-",
Location: "UK South",
InstanceType: "Standard_D2s_v3",
Count: 3,
Username: "constellation",
SubnetID: "subnet-id",
NetworkSecurityGroup: "network-security-group",
Password: "password",
Image: "image",
UserAssignedIdentity: "user-identity",
}
scaleSetAzure := scaleSet.Azure()
require.NotNil(scaleSetAzure.Name)
assert.Equal(scaleSet.Name, *scaleSetAzure.Name)
require.NotNil(scaleSetAzure.Location)
assert.Equal(scaleSet.Location, *scaleSetAzure.Location)
require.NotNil(scaleSetAzure.SKU)
require.NotNil(scaleSetAzure.SKU.Name)
assert.Equal(scaleSet.InstanceType, *scaleSetAzure.SKU.Name)
require.NotNil(scaleSetAzure.SKU.Capacity)
assert.Equal(scaleSet.Count, *scaleSetAzure.SKU.Capacity)
require.NotNil(scaleSetAzure.Properties)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
assert.Equal(scaleSet.NamePrefix, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
assert.Equal(scaleSet.Username, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
assert.Equal(scaleSet.Password, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
// Verify image
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
assert.Equal(scaleSet.Image, *scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
// Verify network
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile)
require.Len(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations, 1)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0])
networkConfig := scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0]
require.NotNil(networkConfig.Name)
assert.Equal(scaleSet.Name, *networkConfig.Name)
require.NotNil(networkConfig.Properties)
require.Len(networkConfig.Properties.IPConfigurations, 1)
require.NotNil(networkConfig.Properties.IPConfigurations[0])
ipConfig := networkConfig.Properties.IPConfigurations[0]
require.NotNil(ipConfig.Name)
assert.Equal(scaleSet.Name, *ipConfig.Name)
require.NotNil(ipConfig.Properties)
require.NotNil(ipConfig.Properties.Subnet)
require.NotNil(ipConfig.Properties.Subnet.ID)
assert.Equal(scaleSet.SubnetID, *ipConfig.Properties.Subnet.ID)
require.NotNil(networkConfig.Properties.NetworkSecurityGroup)
assert.Equal(scaleSet.NetworkSecurityGroup, *networkConfig.Properties.NetworkSecurityGroup.ID)
// Verify vTPM
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
assert.Equal(armcompute.SecurityTypesTrustedLaunch, *scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
assert.True(*scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
// Verify UserAssignedIdentity
require.NotNil(scaleSetAzure.Identity)
require.NotNil(scaleSetAzure.Identity.Type)
assert.Equal(armcompute.ResourceIdentityTypeUserAssigned, *scaleSetAzure.Identity.Type)
require.Len(scaleSetAzure.Identity.UserAssignedIdentities, 1)
assert.Contains(scaleSetAzure.Identity.UserAssignedIdentities, scaleSet.UserAssignedIdentity)
}
func TestGeneratePassword(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
pw, err := GeneratePassword()
require.NoError(err)
assert.Len(pw, 20)
}

View File

@ -0,0 +1,88 @@
package cloudtypes
import (
"fmt"
"strconv"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
"google.golang.org/protobuf/proto"
)
type FirewallRule struct {
Name string
Description string
Protocol string
IPRange string
Port int
}
type Firewall []FirewallRule
func (f Firewall) GCP() []*computepb.Firewall {
var fw []*computepb.Firewall
for _, rule := range f {
var destRange []string = nil
if rule.IPRange != "" {
destRange = append(destRange, rule.IPRange)
}
fw = append(fw, &computepb.Firewall{
Allowed: []*computepb.Allowed{
{
IPProtocol: proto.String(rule.Protocol),
Ports: []string{fmt.Sprint(rule.Port)},
},
},
Description: proto.String(rule.Description),
DestinationRanges: destRange,
Name: proto.String(rule.Name),
})
}
return fw
}
func (f Firewall) Azure() []*armnetwork.SecurityRule {
var fw []*armnetwork.SecurityRule
for i, rule := range f {
// format string according to armnetwork.SecurityRuleProtocol specification
protocol := strings.Title(strings.ToLower(rule.Protocol))
fw = append(fw, &armnetwork.SecurityRule{
Name: proto.String(rule.Name),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String(rule.Description),
Protocol: (*armnetwork.SecurityRuleProtocol)(proto.String(protocol)),
SourceAddressPrefix: proto.String(rule.IPRange),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String(rule.IPRange),
DestinationPortRange: proto.String(strconv.Itoa(rule.Port)),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
// Each security role needs a unique priority
Priority: proto.Int32(int32(100 * (i + 1))),
},
})
}
return fw
}
func (f Firewall) AWS() []ec2types.IpPermission {
var fw []ec2types.IpPermission
for _, rule := range f {
fw = append(fw, ec2types.IpPermission{
FromPort: proto.Int32(int32(rule.Port)),
ToPort: proto.Int32(int32(rule.Port)),
IpProtocol: proto.String(rule.Protocol),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String(rule.IPRange),
Description: proto.String(rule.Description),
},
},
})
}
return fw
}

View File

@ -0,0 +1,188 @@
package cloudtypes
import (
"strconv"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestFirewallGCP(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testFw := Firewall{
{
Name: "test-1",
Description: "This is the Test-1 Permission",
Protocol: "tcp",
IPRange: "",
Port: 9000,
},
{
Name: "test-2",
Description: "This is the Test-2 Permission",
Protocol: "udp",
IPRange: "",
Port: 51820,
},
}
firewalls := testFw.GCP()
assert.Equal(2, len(firewalls))
// Check permissions
for i := 0; i < len(testFw); i++ {
firewall1 := firewalls[i]
actualPermission1 := firewall1.Allowed[0]
actualPort, err := strconv.Atoi(actualPermission1.GetPorts()[0])
require.NoError(err)
assert.Equal(testFw[i].Port, actualPort)
assert.Equal(testFw[i].Protocol, actualPermission1.GetIPProtocol())
assert.Equal(testFw[i].Name, firewall1.GetName())
assert.Equal(testFw[i].Description, firewall1.GetDescription())
}
}
func TestFirewallAzure(t *testing.T) {
assert := assert.New(t)
input := Firewall{
{
Name: "perm1",
Description: "perm1 description",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 22,
},
{
Name: "perm2",
Description: "perm2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
Port: 4433,
},
{
Name: "perm3",
Description: "perm3 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
Port: 4433,
},
}
expectedOutput := []*armnetwork.SecurityRule{
{
Name: proto.String("perm1"),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String("perm1 description"),
Protocol: armnetwork.SecurityRuleProtocolTCP.ToPtr(),
SourceAddressPrefix: proto.String("192.0.2.0/24"),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
DestinationPortRange: proto.String("22"),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
Priority: proto.Int32(100),
},
},
{
Name: proto.String("perm2"),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String("perm2 description"),
Protocol: armnetwork.SecurityRuleProtocolUDP.ToPtr(),
SourceAddressPrefix: proto.String("192.0.2.0/24"),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
DestinationPortRange: proto.String("4433"),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
Priority: proto.Int32(200),
},
},
{
Name: proto.String("perm3"),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String("perm3 description"),
Protocol: armnetwork.SecurityRuleProtocolTCP.ToPtr(),
SourceAddressPrefix: proto.String("192.0.2.0/24"),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
DestinationPortRange: proto.String("4433"),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
Priority: proto.Int32(300),
},
},
}
out := input.Azure()
assert.Equal(expectedOutput, out)
}
func TestIPPermissonsToAWS(t *testing.T) {
assert := assert.New(t)
input := Firewall{
{
Description: "perm1",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 22,
},
{
Description: "perm2",
Protocol: "UDP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
{
Description: "perm3",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
}
expectedOutput := []ec2types.IpPermission{
{
FromPort: proto.Int32(int32(22)),
ToPort: proto.Int32(int32(22)),
IpProtocol: proto.String("TCP"),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String("192.0.2.0/24"),
Description: proto.String("perm1"),
},
},
},
{
FromPort: proto.Int32(int32(4433)),
ToPort: proto.Int32(int32(4433)),
IpProtocol: proto.String("UDP"),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String("192.0.2.0/24"),
Description: proto.String("perm2"),
},
},
},
{
FromPort: proto.Int32(int32(4433)),
ToPort: proto.Int32(int32(4433)),
IpProtocol: proto.String("TCP"),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String("192.0.2.0/24"),
Description: proto.String("perm3"),
},
},
},
}
out := input.AWS()
assert.Equal(expectedOutput, out)
}

View File

@ -0,0 +1,13 @@
package cloudprovider
//go:generate stringer -type=CloudProvider
// CloudProvider is cloud provider used by the CLI.
type CloudProvider uint32
const (
Unknown CloudProvider = iota
AWS
Azure
GCP
)

View File

@ -0,0 +1,26 @@
// Code generated by "stringer -type=CloudProvider"; DO NOT EDIT.
package cloudprovider
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[AWS-1]
_ = x[Azure-2]
_ = x[GCP-3]
}
const _CloudProvider_name = "UnknownAWSAzureGCP"
var _CloudProvider_index = [...]uint8{0, 7, 10, 15, 18}
func (i CloudProvider) String() string {
if i >= CloudProvider(len(_CloudProvider_index)-1) {
return "CloudProvider(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _CloudProvider_name[_CloudProvider_index[i]:_CloudProvider_index[i+1]]
}

22
cli/cmd/azureclient.go Normal file
View File

@ -0,0 +1,22 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/internal/state"
)
type azureclient interface {
GetState() (state.ConstellationState, error)
SetState(state.ConstellationState) error
CreateResourceGroup(ctx context.Context) error
CreateVirtualNetwork(ctx context.Context) error
CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error
CreateInstances(ctx context.Context, input client.CreateInstancesInput) error
// TODO: deprecate as soon as scale sets are available
CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error
CreateServicePrincipal(ctx context.Context) (string, error)
TerminateResourceGroup(ctx context.Context) error
TerminateServicePrincipal(ctx context.Context) error
}

194
cli/cmd/azureclient_test.go Normal file
View File

@ -0,0 +1,194 @@
package cmd
import (
"context"
"strconv"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
)
type fakeAzureClient struct {
nodes azure.Instances
coordinators azure.Instances
resourceGroup string
name string
uid string
location string
subscriptionID string
tenantID string
subnetID string
coordinatorsScaleSet string
nodesScaleSet string
networkSecurityGroup string
adAppObjectID string
}
func (c *fakeAzureClient) GetState() (state.ConstellationState, error) {
stat := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: c.nodes,
AzureCoordinators: c.coordinators,
Name: c.name,
UID: c.uid,
AzureResourceGroup: c.resourceGroup,
AzureLocation: c.location,
AzureSubscription: c.subscriptionID,
AzureTenant: c.tenantID,
AzureSubnet: c.subnetID,
AzureNetworkSecurityGroup: c.networkSecurityGroup,
AzureNodesScaleSet: c.nodesScaleSet,
AzureCoordinatorsScaleSet: c.coordinatorsScaleSet,
AzureADAppObjectID: c.adAppObjectID,
}
return stat, nil
}
func (c *fakeAzureClient) SetState(stat state.ConstellationState) error {
c.nodes = stat.AzureNodes
c.coordinators = stat.AzureCoordinators
c.name = stat.Name
c.uid = stat.UID
c.resourceGroup = stat.AzureResourceGroup
c.location = stat.AzureLocation
c.subscriptionID = stat.AzureSubscription
c.tenantID = stat.AzureTenant
c.subnetID = stat.AzureSubnet
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
c.nodesScaleSet = stat.AzureNodesScaleSet
c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
c.adAppObjectID = stat.AzureADAppObjectID
return nil
}
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
c.resourceGroup = "resource-group"
return nil
}
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
c.subnetID = "subnet"
return nil
}
func (c *fakeAzureClient) CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error {
c.networkSecurityGroup = "network-security-group"
return nil
}
func (c *fakeAzureClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
c.coordinatorsScaleSet = "coordinators-scale-set"
c.nodesScaleSet = "nodes-scale-set"
c.nodes = make(azure.Instances)
for i := 0; i < input.Count-1; i++ {
id := strconv.Itoa(i)
c.nodes[id] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
}
c.coordinators = make(azure.Instances)
c.coordinators["0"] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
return nil
}
// TODO: deprecate as soon as scale sets are available.
func (c *fakeAzureClient) CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error {
c.nodes = make(azure.Instances)
for i := 0; i < input.Count-1; i++ {
id := strconv.Itoa(i)
c.nodes[id] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
}
c.coordinators = make(azure.Instances)
c.coordinators["0"] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
return nil
}
func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, error) {
c.adAppObjectID = "00000000-0000-0000-0000-000000000001"
return client.ApplicationCredentials{
ClientID: "client-id",
ClientSecret: "client-secret",
}.ConvertToCloudServiceAccountURI(), nil
}
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
c.nodes = nil
c.coordinators = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.nodesScaleSet = ""
c.coordinatorsScaleSet = ""
return nil
}
func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
c.adAppObjectID = ""
return nil
}
type stubAzureClient struct {
terminateResourceGroupCalled bool
getStateErr error
setStateErr error
createResourceGroupErr error
createVirtualNetworkErr error
createSecurityGroupErr error
createInstancesErr error
createServicePrincipalErr error
terminateResourceGroupErr error
terminateServicePrincipalErr error
}
func (c *stubAzureClient) GetState() (state.ConstellationState, error) {
return state.ConstellationState{}, c.getStateErr
}
func (c *stubAzureClient) SetState(state.ConstellationState) error {
return c.setStateErr
}
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
return c.createResourceGroupErr
}
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
return c.createVirtualNetworkErr
}
func (c *stubAzureClient) CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error {
return c.createSecurityGroupErr
}
func (c *stubAzureClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
return c.createInstancesErr
}
// TODO: deprecate as soon as scale sets are available.
func (c *stubAzureClient) CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error {
return c.createInstancesErr
}
func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, error) {
return client.ApplicationCredentials{
ClientID: "00000000-0000-0000-0000-000000000000",
ClientSecret: "secret",
}.ConvertToCloudServiceAccountURI(), c.createServicePrincipalErr
}
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error {
c.terminateResourceGroupCalled = true
return c.terminateResourceGroupErr
}
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {
return c.terminateServicePrincipalErr
}

13
cli/cmd/constants.go Normal file
View File

@ -0,0 +1,13 @@
package cmd
// wireguardKeyLength is the length of a WireGuard key in byte.
const wireguardKeyLength = 32
// masterSecretLengthDefault is the default length in bytes for CLI generated master secrets.
const masterSecretLengthDefault = 32
// masterSecretLengthMin is the minimal length in bytes for user provided master secrets.
const masterSecretLengthMin = 16
// constellationNameLength is the maximum length of a Constellation's name.
const constellationNameLength = 37

41
cli/cmd/create.go Normal file
View File

@ -0,0 +1,41 @@
package cmd
import (
"errors"
"fmt"
"io/fs"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/cobra"
)
func newCreateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "create aws|gcp|azure",
Short: "Create instances on a cloud platform for your Constellation.",
Long: "Create instances on a cloud platform for your Constellation.",
}
cmd.PersistentFlags().String("name", "constell", "Set this flag to create the Constellation with the specified name.")
cmd.PersistentFlags().BoolP("yes", "y", false, "Set this flag to create the Constellation without further confirmation.")
cmd.AddCommand(newCreateAWSCmd())
cmd.AddCommand(newCreateGCPCmd())
cmd.AddCommand(newCreateAzureCmd())
return cmd
}
// checkDirClean checks if files of a previous Constellation are left in the current working dir.
func checkDirClean(fileHandler file.Handler, config *config.Config) error {
if _, err := fileHandler.Stat(*config.StatePath); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", *config.StatePath)
}
if _, err := fileHandler.Stat(*config.AdminConfPath); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", *config.AdminConfPath)
}
if _, err := fileHandler.Stat(*config.MasterSecretPath); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, clean it up first", *config.MasterSecretPath)
}
return nil
}

138
cli/cmd/create_aws.go Normal file
View File

@ -0,0 +1,138 @@
package cmd
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
)
func newCreateAWSCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "aws NUMBER SIZE",
Short: "Create a Constellation of NUMBER nodes of SIZE on AWS.",
Long: "Create a Constellation of NUMBER nodes of SIZE on AWS.",
Example: "aws 4 2xlarge",
Args: cobra.MatchAll(
cobra.ExactArgs(2),
isIntGreaterArg(0, 1),
isEC2InstanceType(1),
),
ValidArgsFunction: createAWSCompletion,
RunE: runCreateAWS,
}
return cmd
}
// runCreateAWS runs the create command.
func runCreateAWS(cmd *cobra.Command, args []string) error {
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
size := strings.ToLower(args[1])
name, err := cmd.Flags().GetString("name")
if err != nil {
return err
}
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
client, err := client.NewFromDefault(cmd.Context())
if err != nil {
return err
}
return createAWS(cmd, client, fileHandler, config, size, name, count)
}
// createAWS uses the given client to create 'count' instances of 'size'.
// After the instances are running, they are tagged with the default tags.
// On success, the state of the client is saved to the state file.
func createAWS(cmd *cobra.Command, cl ec2client, fileHandler file.Handler, config *config.Config, size, name string, count int) (retErr error) {
if err := checkDirClean(fileHandler, config); err != nil {
return err
}
const maxLength = 255
if len(name) > maxLength {
return fmt.Errorf("name for constellation too long, maximum length is %d: %s", maxLength, name)
}
ec2Tags := append([]ec2.Tag{}, *config.Provider.EC2.Tags...)
ec2Tags = append(ec2Tags, ec2.Tag{Key: "Name", Value: name})
ok, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
if !ok {
// Ask user to confirm action.
cmd.Printf("The following Constellation will be created:\n")
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the Constellation was aborted.")
return nil
}
}
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerAWS{client: cl})
if err := cl.CreateSecurityGroup(cmd.Context(), *config.Provider.EC2.SecurityGroupInput); err != nil {
return err
}
createInput := client.CreateInput{
ImageId: *config.Provider.EC2.Image,
InstanceType: size,
Count: count,
Tags: ec2Tags,
}
if err := cl.CreateInstances(cmd.Context(), createInput); err != nil {
return err
}
stat, err := cl.GetState()
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
return err
}
cmd.Println("Your Constellation was created successfully.")
return nil
}
// createAWSCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func createAWSCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{}, cobra.ShellCompDirectiveNoFileComp
case 1:
return []string{
"4xlarge",
"8xlarge",
"12xlarge",
"16xlarge",
"24xlarge",
}, cobra.ShellCompDirectiveDefault
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

205
cli/cmd/create_aws_test.go Normal file
View File

@ -0,0 +1,205 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateAWSCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid size 4XL": {[]string{"5", "4xlarge"}, false},
"valid size 8XL": {[]string{"4", "8xlarge"}, false},
"valid size 12XL": {[]string{"3", "12xlarge"}, false},
"valid size 16XL": {[]string{"2", "16xlarge"}, false},
"valid size 24XL": {[]string{"2", "24xlarge"}, false},
"valid short 12XL": {[]string{"4", "12xl"}, false},
"valid short 24XL": {[]string{"2", "24xl"}, false},
"valid capitalized": {[]string{"3", "24XlARge"}, false},
"valid short capitalized": {[]string{"4", "16XL"}, false},
"invalid to many arguments": {[]string{"2", "4xl", "2xl"}, true},
"invalid to many arguments 2": {[]string{"2", "4xl", "2"}, true},
"invalidOnlyOneInstance": {[]string{"1", "4xl"}, true},
"invalid first is no int": {[]string{"xl", "4xl"}, true},
"invalid second is no size": {[]string{"2", "2"}, true},
"invalid wrong order": {[]string{"4xl", "2"}, true},
}
cmd := newCreateAWSCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateAWS(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
someErr := errors.New("failed")
config := config.Default()
testCases := map[string]struct {
existingState *state.ConstellationState
client ec2client
interactive bool
interactiveStdin string
stateExpected state.ConstellationState
errExpected bool
}{
"create some instances": {
client: &fakeEc2Client{},
stateExpected: testState,
errExpected: false,
},
"state already exists": {
existingState: &testState,
client: &fakeEc2Client{},
errExpected: true,
},
"create some instances interactive": {
client: &fakeEc2Client{},
interactive: true,
interactiveStdin: "y\n",
stateExpected: testState,
errExpected: false,
},
"fail CreateSecurityGroup": {
client: &stubEc2Client{createSecurityGroupErr: someErr},
errExpected: true,
},
"fail CreateInstances": {
client: &stubEc2Client{createInstancesErr: someErr},
errExpected: true,
},
"fail GetState": {
client: &stubEc2Client{getStateErr: someErr},
errExpected: true,
},
"error on rollback": {
client: &stubEc2Client{createInstancesErr: someErr, deleteSecurityGroupErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateAWSCmd()
cmd.Flags().BoolP("yes", "y", false, "")
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.interactiveStdin)
cmd.SetIn(in)
if !tc.interactive {
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
}
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
if tc.existingState != nil {
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
}
err := createAWS(cmd, tc.client, fileHandler, config, "xlarge", "name", 3)
if tc.errExpected {
assert.Error(err)
if stubClient, ok := tc.client.(*stubEc2Client); ok {
// Should have made a rollback on error.
assert.True(stubClient.terminateInstancesCalled)
assert.True(stubClient.deleteSecurityGroupCalled)
}
} else {
assert.NoError(err)
var stat state.ConstellationState
err := fileHandler.ReadJSON(*config.StatePath, &stat)
assert.NoError(err)
assert.Equal(tc.stateExpected, stat)
}
})
}
}
func TestCreateAWSCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "21",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"23"},
toComplete: "4xl",
resultExpected: []string{
"4xlarge",
"8xlarge",
"12xlarge",
"16xlarge",
"24xlarge",
},
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"third arg": {
args: []string{"23", "4xlarge"},
toComplete: "xl",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createAWSCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

137
cli/cmd/create_azure.go Normal file
View File

@ -0,0 +1,137 @@
package cmd
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newCreateAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure",
Short: "Create a Constellation of NUMBER nodes of SIZE on Azure.",
Long: "Create a Constellation of NUMBER nodes of SIZE on Azure.",
Args: cobra.MatchAll(
cobra.ExactArgs(2),
isIntGreaterArg(0, 1),
isAzureInstanceType(1),
),
ValidArgsFunction: createAzureCompletion,
RunE: runCreateAzure,
}
return cmd
}
// runCreateAzure runs the create command.
func runCreateAzure(cmd *cobra.Command, args []string) error {
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
size := strings.ToLower(args[1])
subscriptionID := "0d202bbb-4fa7-4af8-8125-58c269a05435" // TODO: This will be user input
tenantID := "adb650a8-5da3-4b15-b4b0-3daf65ff7626" // TODO: This will be user input
location := "North Europe" // TODO: This will be user input
name, err := cmd.Flags().GetString("name")
if err != nil {
return err
}
if len(name) > constellationNameLength {
return fmt.Errorf("name for constellation too long, maximum length is %d got %d: %s", constellationNameLength, len(name), name)
}
client, err := client.NewInitialized(
subscriptionID,
tenantID,
name,
location,
)
if err != nil {
return err
}
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
return createAzure(cmd, client, fileHandler, config, size, count)
}
func createAzure(cmd *cobra.Command, cl azureclient, fileHandler file.Handler, config *config.Config, size string, count int) (retErr error) {
if err := checkDirClean(fileHandler, config); err != nil {
return err
}
ok, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
if !ok {
// Ask user to confirm action.
cmd.Printf("The following Constellation will be created:\n")
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the Constellation was aborted.")
return nil
}
}
// Create all azure resources
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerAzure{client: cl})
if err := cl.CreateResourceGroup(cmd.Context()); err != nil {
return err
}
if err := cl.CreateVirtualNetwork(cmd.Context()); err != nil {
return err
}
if err := cl.CreateSecurityGroup(cmd.Context(), *config.Provider.Azure.NetworkSecurityGroupInput); err != nil {
return err
}
if err := cl.CreateInstances(cmd.Context(), client.CreateInstancesInput{
Count: count,
InstanceType: size,
Image: *config.Provider.Azure.Image,
UserAssingedIdentity: *config.Provider.Azure.UserAssignedIdentity,
}); err != nil {
return err
}
stat, err := cl.GetState()
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
return err
}
cmd.Println("Your Constellation was created successfully.")
return nil
}
// createAzureCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func createAzureCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{}, cobra.ShellCompDirectiveNoFileComp
case 1:
return azure.InstanceTypes, cobra.ShellCompDirectiveDefault
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

View File

@ -0,0 +1,206 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateAzureCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid create 1": {[]string{"3", "Standard_DC2as_v5"}, false},
"valid create 2": {[]string{"7", "Standard_DC4as_v5"}, false},
"valid create 3": {[]string{"2", "Standard_DC8as_v5"}, false},
"invalid to many arguments": {[]string{"2", "Standard_DC2as_v5", "Standard_DC2as_v5"}, true},
"invalid to many arguments 2": {[]string{"2", "Standard_DC2as_v5", "2"}, true},
"invalidOnlyOneInstance": {[]string{"1", "Standard_DC2as_v5"}, true},
"invalid first is no int": {[]string{"Standard_DC2as_v5", "Standard_DC2as_v5"}, true},
"invalid second is no size": {[]string{"2", "2"}, true},
"invalid wrong order": {[]string{"Standard_DC2as_v5", "2"}, true},
}
cmd := newCreateAzureCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateAzure(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "resource-group",
AzureSubnet: "subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "nodes-scale-set",
AzureCoordinatorsScaleSet: "coordinators-scale-set",
}
someErr := errors.New("failed")
config := config.Default()
testCases := map[string]struct {
existingState *state.ConstellationState
client azureclient
interactive bool
interactiveStdin string
stateExpected state.ConstellationState
errExpected bool
}{
"create some instances": {
client: &fakeAzureClient{},
stateExpected: testState,
},
"state already exists": {
existingState: &testState,
client: &fakeAzureClient{},
errExpected: true,
},
"create some instances interactive": {
client: &fakeAzureClient{},
interactive: true,
interactiveStdin: "y\n",
stateExpected: testState,
errExpected: false,
},
"fail getState": {
client: &stubAzureClient{getStateErr: someErr},
errExpected: true,
},
"fail createVirtualNetwork": {
client: &stubAzureClient{createVirtualNetworkErr: someErr},
errExpected: true,
},
"fail createSecurityGroup": {
client: &stubAzureClient{createSecurityGroupErr: someErr},
errExpected: true,
},
"fail createInstances": {
client: &stubAzureClient{createInstancesErr: someErr},
errExpected: true,
},
"fail createResourceGroup": {
client: &stubAzureClient{createResourceGroupErr: someErr},
errExpected: true,
},
"error on rollback": {
client: &stubAzureClient{createInstancesErr: someErr, terminateResourceGroupErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateAzureCmd()
cmd.Flags().BoolP("yes", "y", false, "")
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.interactiveStdin)
cmd.SetIn(in)
if !tc.interactive {
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
}
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
if tc.existingState != nil {
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
}
err := createAzure(cmd, tc.client, fileHandler, config, "Standard_D2s_v3", 3)
if tc.errExpected {
assert.Error(err)
if stubClient, ok := tc.client.(*stubAzureClient); ok {
// Should have made a rollback on error.
assert.True(stubClient.terminateResourceGroupCalled)
}
} else {
assert.NoError(err)
var state state.ConstellationState
err := fileHandler.ReadJSON(*config.StatePath, &state)
assert.NoError(err)
assert.Equal(tc.stateExpected, state)
}
})
}
}
func TestCreateAzureCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "21",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"23"},
toComplete: "Standard_D",
resultExpected: azure.InstanceTypes,
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"third arg": {
args: []string{"23", "Standard_D2s_v3"},
toComplete: "Standard_D",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createAzureCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

132
cli/cmd/create_gcp.go Normal file
View File

@ -0,0 +1,132 @@
package cmd
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newCreateGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp",
Short: "Create a Constellation of NUMBER nodes of SIZE on Google Cloud Platform.",
Long: "Create a Constellation of NUMBER nodes of SIZE on Google Cloud Platform.",
Args: cobra.MatchAll(
cobra.ExactArgs(2),
isIntGreaterArg(0, 1),
isGCPInstanceType(1),
),
ValidArgsFunction: createGCPCompletion,
RunE: runCreateGCP,
}
return cmd
}
// runCreateGCP runs the create command.
func runCreateGCP(cmd *cobra.Command, args []string) error {
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
size := strings.ToLower(args[1])
project := "constellation-331613" // TODO: This will be user input
zone := "us-central1-c" // TODO: This will be user input
region := "us-central1" // TODO: This will be user input
name, err := cmd.Flags().GetString("name")
if err != nil {
return err
}
if len(name) > constellationNameLength {
return fmt.Errorf("name for constellation too long, maximum length is %d got %d: %s", constellationNameLength, len(name), name)
}
client, err := client.NewInitialized(cmd.Context(), project, zone, region, name)
if err != nil {
return err
}
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
return createGCP(cmd, client, fileHandler, config, size, count)
}
func createGCP(cmd *cobra.Command, cl gcpclient, fileHandler file.Handler, config *config.Config, size string, count int) (retErr error) {
if err := checkDirClean(fileHandler, config); err != nil {
return err
}
createInput := client.CreateInstancesInput{
Count: count,
ImageId: *config.Provider.GCP.Image,
InstanceType: size,
KubeEnv: gcp.KubeEnv,
DisableCVM: *config.Provider.GCP.DisableCVM,
}
ok, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
if !ok {
// Ask user to confirm action.
cmd.Printf("The following Constellation will be created:\n")
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the Constellation was aborted.")
return nil
}
}
// Create all gcp resources
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerGCP{client: cl})
if err := cl.CreateVPCs(cmd.Context(), *config.Provider.GCP.VPCsInput); err != nil {
return err
}
if err := cl.CreateFirewall(cmd.Context(), *config.Provider.GCP.FirewallInput); err != nil {
return err
}
if err := cl.CreateInstances(cmd.Context(), createInput); err != nil {
return err
}
stat, err := cl.GetState()
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
return err
}
cmd.Println("Your Constellation was created successfully.")
return nil
}
// createGCPCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func createGCPCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{}, cobra.ShellCompDirectiveNoFileComp
case 1:
return gcp.InstanceTypes, cobra.ShellCompDirectiveDefault
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

206
cli/cmd/create_gcp_test.go Normal file
View File

@ -0,0 +1,206 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateGCPCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid create 1": {[]string{"3", "n2d-standard-2"}, false},
"valid create 2": {[]string{"7", "n2d-standard-16"}, false},
"valid create 3": {[]string{"2", "n2d-standard-96"}, false},
"invalid to many arguments": {[]string{"2", "n2d-standard-2", "n2d-standard-2"}, true},
"invalid to many arguments 2": {[]string{"2", "n2d-standard-2", "2"}, true},
"invalidOnlyOneInstance": {[]string{"1", "n2d-standard-2"}, true},
"invalid first is no int": {[]string{"n2d-standard-2", "n2d-standard-2"}, true},
"invalid second is no size": {[]string{"2", "2"}, true},
"invalid wrong order": {[]string{"n2d-standard-2", "2"}, true},
}
cmd := newCreateGCPCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateGCP(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "node-template",
GCPCoordinatorInstanceTemplate: "coordinator-template",
GCPNetwork: "network",
GCPSubnetwork: "subnetwork",
GCPFirewalls: []string{"coordinator", "wireguard", "ssh"},
}
someErr := errors.New("failed")
config := config.Default()
testCases := map[string]struct {
existingState *state.ConstellationState
client gcpclient
interactive bool
interactiveStdin string
stateExpected state.ConstellationState
errExpected bool
}{
"create some instances": {
client: &fakeGcpClient{},
stateExpected: testState,
},
"state already exists": {
existingState: &testState,
client: &fakeGcpClient{},
errExpected: true,
},
"create some instances interactive": {
client: &fakeGcpClient{},
interactive: true,
interactiveStdin: "y\n",
stateExpected: testState,
errExpected: false,
},
"fail getState": {
client: &stubGcpClient{getStateErr: someErr},
errExpected: true,
},
"fail createVPCs": {
client: &stubGcpClient{createVPCsErr: someErr},
errExpected: true,
},
"fail createFirewall": {
client: &stubGcpClient{createFirewallErr: someErr},
errExpected: true,
},
"fail createInstances": {
client: &stubGcpClient{createInstancesErr: someErr},
errExpected: true,
},
"error on rollback": {
client: &stubGcpClient{createInstancesErr: someErr, terminateVPCsErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateGCPCmd()
cmd.Flags().BoolP("yes", "y", false, "")
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.interactiveStdin)
cmd.SetIn(in)
if !tc.interactive {
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
}
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
if tc.existingState != nil {
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
}
err := createGCP(cmd, tc.client, fileHandler, config, "n2d-standard-2", 3)
if tc.errExpected {
assert.Error(err)
if stubClient, ok := tc.client.(*stubGcpClient); ok {
// Should have made a rollback on error.
assert.True(stubClient.terminateFirewallCalled)
assert.True(stubClient.terminateInstancesCalled)
assert.True(stubClient.terminateVPCsCalled)
}
} else {
assert.NoError(err)
var stat state.ConstellationState
err := fileHandler.ReadJSON(*config.StatePath, &stat)
assert.NoError(err)
assert.Equal(tc.stateExpected, stat)
}
})
}
}
func TestCreateGCPCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "21",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"23"},
toComplete: "n2d-stan",
resultExpected: gcp.InstanceTypes,
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"third arg": {
args: []string{"23", "n2d-standard-2"},
toComplete: "n2d-stan",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createGCPCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

64
cli/cmd/create_test.go Normal file
View File

@ -0,0 +1,64 @@
package cmd
import (
"testing"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCheckDirClean(t *testing.T) {
config := config.Default()
testCases := map[string]struct {
fileHandler file.Handler
existingFiles []string
wantErr bool
}{
"no file exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
},
"adminconf exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.AdminConfPath},
wantErr: true,
},
"master secret exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.MasterSecretPath},
wantErr: true,
},
"state file exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.StatePath},
wantErr: true,
},
"multiple exist": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.AdminConfPath, *config.MasterSecretPath, *config.StatePath},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
for _, f := range tc.existingFiles {
require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, false))
}
err := checkDirClean(tc.fileHandler, config)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

17
cli/cmd/ec2client.go Normal file
View File

@ -0,0 +1,17 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/internal/state"
)
type ec2client interface {
GetState() (state.ConstellationState, error)
SetState(stat state.ConstellationState) error
CreateInstances(ctx context.Context, input client.CreateInput) error
TerminateInstances(ctx context.Context) error
CreateSecurityGroup(ctx context.Context, input client.SecurityGroupInput) error
DeleteSecurityGroup(ctx context.Context) error
}

139
cli/cmd/ec2client_test.go Normal file
View File

@ -0,0 +1,139 @@
package cmd
import (
"context"
"errors"
"strconv"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/internal/state"
)
type fakeEc2Client struct {
instances ec2.Instances
securityGroup string
ec2state []fakeEc2Instance
}
func (c *fakeEc2Client) GetState() (state.ConstellationState, error) {
if len(c.instances) == 0 {
return state.ConstellationState{}, errors.New("client has no instances")
}
stat := state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: c.instances,
EC2SecurityGroup: c.securityGroup,
}
for id, instance := range c.instances {
instance.PrivateIP = "192.0.2.1"
instance.PublicIP = "192.0.2.2"
c.instances[id] = instance
}
return stat, nil
}
func (c *fakeEc2Client) SetState(stat state.ConstellationState) error {
if len(stat.EC2Instances) == 0 {
return errors.New("state has no instances")
}
c.instances = stat.EC2Instances
c.securityGroup = stat.EC2SecurityGroup
return nil
}
func (c *fakeEc2Client) CreateInstances(_ context.Context, input client.CreateInput) error {
if c.securityGroup == "" {
return errors.New("client has no security group")
}
if c.instances == nil {
c.instances = make(ec2.Instances)
}
for i := 0; i < input.Count; i++ {
id := "id-" + strconv.Itoa(len(c.ec2state))
c.ec2state = append(c.ec2state, fakeEc2Instance{
state: running,
instanceID: id,
securityGroup: c.securityGroup,
tags: input.Tags,
})
c.instances[id] = ec2.Instance{}
}
return nil
}
func (c *fakeEc2Client) TerminateInstances(_ context.Context) error {
if len(c.instances) == 0 {
return nil
}
for _, instance := range c.ec2state {
instance.state = terminated
}
return nil
}
func (c *fakeEc2Client) CreateSecurityGroup(_ context.Context, input client.SecurityGroupInput) error {
if c.securityGroup != "" {
return errors.New("client already has a security group")
}
c.securityGroup = "sg-test"
return nil
}
func (c *fakeEc2Client) DeleteSecurityGroup(_ context.Context) error {
c.securityGroup = ""
return nil
}
type ec2InstanceState int
const (
running = iota
terminated
)
type fakeEc2Instance struct {
state ec2InstanceState
instanceID string
tags ec2.Tags
securityGroup string
}
type stubEc2Client struct {
terminateInstancesCalled bool
deleteSecurityGroupCalled bool
getStateErr error
setStateErr error
createInstancesErr error
terminateInstancesErr error
createSecurityGroupErr error
deleteSecurityGroupErr error
}
func (c *stubEc2Client) GetState() (state.ConstellationState, error) {
return state.ConstellationState{}, c.getStateErr
}
func (c *stubEc2Client) SetState(stat state.ConstellationState) error {
return c.setStateErr
}
func (c *stubEc2Client) CreateInstances(_ context.Context, input client.CreateInput) error {
return c.createInstancesErr
}
func (c *stubEc2Client) TerminateInstances(_ context.Context) error {
c.terminateInstancesCalled = true
return c.terminateInstancesErr
}
func (c *stubEc2Client) CreateSecurityGroup(_ context.Context, input client.SecurityGroupInput) error {
return c.createSecurityGroupErr
}
func (c *stubEc2Client) DeleteSecurityGroup(_ context.Context) error {
c.deleteSecurityGroupCalled = true
return c.deleteSecurityGroupErr
}

22
cli/cmd/gcpclient.go Normal file
View File

@ -0,0 +1,22 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/state"
)
type gcpclient interface {
GetState() (state.ConstellationState, error)
SetState(state.ConstellationState) error
CreateVPCs(ctx context.Context, input client.VPCsInput) error
CreateFirewall(ctx context.Context, input client.FirewallInput) error
CreateInstances(ctx context.Context, input client.CreateInstancesInput) error
CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error)
TerminateFirewall(ctx context.Context) error
TerminateVPCs(context.Context) error
TerminateInstances(context.Context) error
TerminateServiceAccount(ctx context.Context) error
Close() error
}

219
cli/cmd/gcpclient_test.go Normal file
View File

@ -0,0 +1,219 @@
package cmd
import (
"context"
"errors"
"strconv"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/state"
)
type fakeGcpClient struct {
nodes gcp.Instances
coordinators gcp.Instances
nodesInstanceGroup string
coordinatorInstanceGroup string
coordinatorTemplate string
nodeTemplate string
network string
subnetwork string
firewalls []string
project string
uid string
name string
zone string
serviceAccount string
}
func (c *fakeGcpClient) GetState() (state.ConstellationState, error) {
stat := state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: c.nodes,
GCPCoordinators: c.coordinators,
GCPNodeInstanceGroup: c.nodesInstanceGroup,
GCPCoordinatorInstanceGroup: c.coordinatorInstanceGroup,
GCPNodeInstanceTemplate: c.nodeTemplate,
GCPCoordinatorInstanceTemplate: c.coordinatorTemplate,
GCPNetwork: c.network,
GCPSubnetwork: c.subnetwork,
GCPFirewalls: c.firewalls,
GCPProject: c.project,
Name: c.name,
UID: c.uid,
GCPZone: c.zone,
GCPServiceAccount: c.serviceAccount,
}
return stat, nil
}
func (c *fakeGcpClient) SetState(stat state.ConstellationState) error {
c.nodes = stat.GCPNodes
c.coordinators = stat.GCPCoordinators
c.nodesInstanceGroup = stat.GCPNodeInstanceGroup
c.coordinatorInstanceGroup = stat.GCPCoordinatorInstanceGroup
c.nodeTemplate = stat.GCPNodeInstanceTemplate
c.coordinatorTemplate = stat.GCPCoordinatorInstanceTemplate
c.network = stat.GCPNetwork
c.subnetwork = stat.GCPSubnetwork
c.firewalls = stat.GCPFirewalls
c.project = stat.GCPProject
c.name = stat.Name
c.uid = stat.UID
c.zone = stat.GCPZone
c.serviceAccount = stat.GCPServiceAccount
return nil
}
func (c *fakeGcpClient) CreateVPCs(ctx context.Context, input client.VPCsInput) error {
c.network = "network"
c.subnetwork = "subnetwork"
return nil
}
func (c *fakeGcpClient) CreateFirewall(ctx context.Context, input client.FirewallInput) error {
if c.network == "" {
return errors.New("client has not network")
}
var firewalls []string
for _, rule := range input.Ingress {
firewalls = append(firewalls, rule.Name)
}
c.firewalls = firewalls
return nil
}
func (c *fakeGcpClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
c.coordinatorInstanceGroup = "coordinator-group"
c.nodesInstanceGroup = "nodes-group"
c.nodeTemplate = "node-template"
c.coordinatorTemplate = "coordinator-template"
c.nodes = make(gcp.Instances)
for i := 0; i < input.Count-1; i++ {
id := "id-" + strconv.Itoa(len(c.nodes))
c.nodes[id] = gcp.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
}
c.coordinators = make(gcp.Instances)
c.coordinators["id-c"] = gcp.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
return nil
}
func (c *fakeGcpClient) CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error) {
c.serviceAccount = "service-account@" + c.project + ".iam.gserviceaccount.com"
return client.ServiceAccountKey{
Type: "service_account",
ProjectID: c.project,
PrivateKeyID: "key-id",
PrivateKey: "-----BEGIN PRIVATE KEY-----\nprivate-key\n-----END PRIVATE KEY-----\n",
ClientEmail: c.serviceAccount,
ClientID: "client-id",
AuthURI: "https://accounts.google.com/o/oauth2/auth",
TokenURI: "https://accounts.google.com/o/oauth2/token",
AuthProviderX509CertURL: "https://www.googleapis.com/oauth2/v1/certs",
ClientX509CertURL: "https://www.googleapis.com/robot/v1/metadata/x509/service-account-email",
}.ConvertToCloudServiceAccountURI(), nil
}
func (c *fakeGcpClient) TerminateFirewall(ctx context.Context) error {
if len(c.firewalls) == 0 {
return nil
}
c.firewalls = nil
return nil
}
func (c *fakeGcpClient) TerminateVPCs(context.Context) error {
if len(c.firewalls) != 0 {
return errors.New("client has firewalls, which must be deleted first")
}
c.network = ""
c.subnetwork = ""
return nil
}
func (c *fakeGcpClient) TerminateInstances(context.Context) error {
c.nodeTemplate = ""
c.coordinatorTemplate = ""
c.nodesInstanceGroup = ""
c.coordinatorInstanceGroup = ""
c.nodes = nil
c.coordinators = nil
return nil
}
func (c *fakeGcpClient) TerminateServiceAccount(context.Context) error {
c.serviceAccount = ""
return nil
}
func (c *fakeGcpClient) Close() error {
return nil
}
type stubGcpClient struct {
terminateFirewallCalled bool
terminateInstancesCalled bool
terminateVPCsCalled bool
getStateErr error
setStateErr error
createVPCsErr error
createFirewallErr error
createInstancesErr error
createServiceAccountErr error
terminateFirewallErr error
terminateVPCsErr error
terminateInstancesErr error
terminateServiceAccountErr error
closeErr error
}
func (c *stubGcpClient) GetState() (state.ConstellationState, error) {
return state.ConstellationState{}, c.getStateErr
}
func (c *stubGcpClient) SetState(state.ConstellationState) error {
return c.setStateErr
}
func (c *stubGcpClient) CreateVPCs(ctx context.Context, input client.VPCsInput) error {
return c.createVPCsErr
}
func (c *stubGcpClient) CreateFirewall(ctx context.Context, input client.FirewallInput) error {
return c.createFirewallErr
}
func (c *stubGcpClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
return c.createInstancesErr
}
func (c *stubGcpClient) CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error) {
return client.ServiceAccountKey{}.ConvertToCloudServiceAccountURI(), c.createServiceAccountErr
}
func (c *stubGcpClient) TerminateFirewall(ctx context.Context) error {
c.terminateFirewallCalled = true
return c.terminateFirewallErr
}
func (c *stubGcpClient) TerminateVPCs(context.Context) error {
c.terminateVPCsCalled = true
return c.terminateVPCsErr
}
func (c *stubGcpClient) TerminateInstances(context.Context) error {
c.terminateInstancesCalled = true
return c.terminateInstancesErr
}
func (c *stubGcpClient) TerminateServiceAccount(context.Context) error {
return c.terminateServiceAccountErr
}
func (c *stubGcpClient) Close() error {
return c.closeErr
}

484
cli/cmd/init.go Normal file
View File

@ -0,0 +1,484 @@
package cmd
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"io/fs"
"net"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/proto"
"github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/cli/vpn"
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
func newInitCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "init",
Short: "Initialize the Constellation. Start your confidential Kubernetes cluster.",
Long: "Initialize the Constellation. Start your confidential Kubernetes cluster.",
ValidArgsFunction: initCompletion,
Args: cobra.ExactArgs(0),
RunE: runInitialize,
}
cmd.Flags().String("privatekey", "", "path to your private key.")
cmd.Flags().String("publickey", "", "path to your public key.")
cmd.Flags().String("master-secret", "", "path to base64 encoded master secret.")
cmd.Flags().Bool("autoscale", false, "enable kubernetes cluster-autoscaler")
return cmd
}
// runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
defer protoClient.Close()
vpnClient, err := vpn.NewWithDefaults()
if err != nil {
return err
}
// We have to parse the context separately, since cmd.Context()
// returns nil during the tests otherwise.
return initialize(cmd.Context(), cmd, protoClient, vpnClient, serviceAccountClient{}, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs))
}
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
// themself activate the other peers as nodes.
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpnCl vpnConfigurer, serviceAccountCr serviceAccountCreator,
fileHandler file.Handler, config *config.Config, waiter statusWaiter,
) error {
flagArgs, err := evalFlagArgs(cmd, fileHandler, config)
if err != nil {
return err
}
var stat state.ConstellationState
err = fileHandler.ReadJSON(*config.StatePath, &stat)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("nothing to initialize: %w", err)
} else if err != nil {
return err
}
switch stat.CloudProvider {
case "GCP":
if err := warnAboutPCRs(cmd, *config.Provider.GCP.PCRs, true); err != nil {
return err
}
case "Azure":
if err := warnAboutPCRs(cmd, *config.Provider.Azure.PCRs, true); err != nil {
return err
}
}
serviceAccount, stat, err := serviceAccountCr.createServiceAccount(ctx, stat, config)
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, true); err != nil {
return err
}
coordinators, nodes, err := getScalingGroupsFromConfig(stat, config)
if err != nil {
return err
}
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil {
return fmt.Errorf("failed to wait for peer status: %w", err)
}
var autoscalingNodeGroups []string
if flagArgs.autoscale {
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
}
input := activationInput{
coordinatorPubIP: coordinators.PublicIPs()[0],
pubKey: flagArgs.userPubKey,
masterSecret: flagArgs.masterSecret,
nodePrivIPs: nodes.PrivateIPs(),
autoscalingNodeGroups: autoscalingNodeGroups,
cloudServiceAccountURI: serviceAccount,
}
result, err := activate(ctx, cmd, protCl, input, config)
if err != nil {
return err
}
err = result.writeOutput(cmd.OutOrStdout(), fileHandler, config)
if err != nil {
return err
}
if flagArgs.autoconfigureWG {
if err := configureVpn(vpnCl, result.clientVpnIP, result.coordinatorPubKey, result.coordinatorPubIP, flagArgs.userPrivKey); err != nil {
return err
}
}
return nil
}
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
return activationResult{}, err
}
respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, ipsToEndpoints(input.nodePrivIPs, *config.CoordinatorPort), input.autoscalingNodeGroups, input.cloudServiceAccountURI)
if err != nil {
return activationResult{}, err
}
if err := respCl.WriteLogStream(cmd.OutOrStdout()); err != nil {
return activationResult{}, err
}
clientVpnIp, err := respCl.GetClientVpnIp()
if err != nil {
return activationResult{}, err
}
coordinatorPubKey, err := respCl.GetCoordinatorVpnKey()
if err != nil {
return activationResult{}, err
}
kubeconfig, err := respCl.GetKubeconfig()
if err != nil {
return activationResult{}, err
}
ownerID, err := respCl.GetOwnerID()
if err != nil {
return activationResult{}, err
}
clusterID, err := respCl.GetClusterID()
if err != nil {
return activationResult{}, err
}
return activationResult{
clientVpnIP: clientVpnIp,
coordinatorPubKey: coordinatorPubKey,
coordinatorPubIP: input.coordinatorPubIP,
kubeconfig: kubeconfig,
ownerID: ownerID,
clusterID: clusterID,
}, nil
}
type activationInput struct {
coordinatorPubIP string
pubKey []byte
masterSecret []byte
nodePrivIPs []string
autoscalingNodeGroups []string
cloudServiceAccountURI string
}
type activationResult struct {
clientVpnIP string
coordinatorPubKey string
coordinatorPubIP string
kubeconfig string
ownerID string
clusterID string
}
func (res activationResult) writeOutput(w io.Writer, fileHandler file.Handler, config *config.Config) error {
fmt.Fprintln(w, "Your Constellation was successfully initialized.")
fmt.Fprintf(w, "Your WireGuard IP is %s\n", res.clientVpnIP)
fmt.Fprintf(w, "The Coordinator's public IP is %s\n", res.coordinatorPubIP)
fmt.Fprintf(w, "The Coordinator's public key is %s\n", res.coordinatorPubKey)
fmt.Fprintf(w, "The Constellation's owner identifier is %s\n", res.ownerID)
fmt.Fprintf(w, "The Constellation's unique identifier is %s\n", res.clusterID)
if err := fileHandler.Write(*config.AdminConfPath, []byte(res.kubeconfig), false); err != nil {
return err
}
fmt.Fprintf(w, "Your Constellation Kubernetes configuration was successfully written to %s\n", *config.AdminConfPath)
return nil
}
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler, config *config.Config) (flagArgs, error) {
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
if err != nil {
return flagArgs{}, err
}
userPublicKeyPath, err := cmd.Flags().GetString("publickey")
if err != nil {
return flagArgs{}, err
}
userPrivKey, userPubKey, err := readVpnKey(fileHandler, userPrivKeyPath, userPublicKeyPath)
if err != nil {
return flagArgs{}, err
}
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return flagArgs{}, err
}
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath, config)
if err != nil {
return flagArgs{}, err
}
autoscale, err := cmd.Flags().GetBool("autoscale")
if err != nil {
return flagArgs{}, err
}
return flagArgs{
userPrivKey: userPrivKey,
userPubKey: userPubKey,
autoconfigureWG: userPrivKeyPath != "",
autoscale: autoscale,
masterSecret: masterSecret,
}, nil
}
// flagArgs are the resulting values of flag preprocessing.
type flagArgs struct {
userPrivKey []byte
userPubKey []byte
masterSecret []byte
autoconfigureWG bool
autoscale bool
}
func readVpnKey(fileHandler file.Handler, privKeyPath, publicKeyPath string) (privKey, pubKey []byte, err error) {
if privKeyPath != "" {
privKey, err = fileHandler.Read(privKeyPath)
if err != nil {
return nil, nil, err
}
privKeyParsed, err := wgtypes.ParseKey(string(privKey))
if err != nil {
return nil, nil, err
}
pubKey = []byte(privKeyParsed.PublicKey().String())
} else if publicKeyPath != "" {
pubKey, err = fileHandler.Read(publicKeyPath)
if err != nil {
return nil, nil, err
}
if err := checkBase64WGKey(pubKey); err != nil {
return nil, nil, fmt.Errorf("wireguard public key is invalid: %w", err)
}
} else {
return nil, nil, errors.New("neither path to public nor private key provided")
}
return privKey, pubKey, nil
}
func configureVpn(vpnCl vpnConfigurer, clientVpnIp, coordinatorPubKey, coordinatorPublicIp string, privKey []byte) error {
err := vpnCl.Configure(clientVpnIp, coordinatorPubKey, coordinatorPublicIp, string(privKey))
if err != nil {
return fmt.Errorf("could not configure WireGuard automatically: %w", err)
}
return nil
}
func ipsToEndpoints(ips []string, port string) []string {
var endpoints []string
for _, ip := range ips {
endpoints = append(endpoints, net.JoinHostPort(ip, port))
}
return endpoints
}
func checkBase64WGKey(b []byte) error {
keyStr, err := base64.StdEncoding.DecodeString(string(b))
if err != nil {
return errors.New("key can't be decoded")
}
if length := len(keyStr); length != wireguardKeyLength {
return fmt.Errorf("key has invalid length %d", length)
}
return nil
}
// readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string, config *config.Config) ([]byte, error) {
if filename != "" {
// Try to read the base64 secret from file
encodedSecret, err := fileHandler.Read(filename)
if err != nil {
return nil, err
}
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
if err != nil {
return nil, err
}
if len(decoded) < masterSecretLengthMin {
return nil, errors.New("provided master secret is smaller than the required minimum of 16 Bytes")
}
return decoded, nil
}
// No file given, generate a new secret, and save it to disk
masterSecret, err := util.GenerateRandomBytes(masterSecretLengthDefault)
if err != nil {
return nil, err
}
if err := fileHandler.Write(*config.MasterSecretPath, []byte(base64.StdEncoding.EncodeToString(masterSecret)), false); err != nil {
return nil, err
}
fmt.Fprintf(w, "Your Constellation master secret was successfully written to ./%s\n", *config.MasterSecretPath)
return masterSecret, nil
}
func getScalingGroupsFromConfig(stat state.ConstellationState, config *config.Config) (coordinators, nodes ScalingGroup, err error) {
switch {
case len(stat.EC2Instances) != 0:
return getAWSInstances(stat)
case len(stat.GCPCoordinators) != 0:
return getGCPInstances(stat, config)
case len(stat.AzureCoordinators) != 0:
return getAzureInstances(stat)
default:
return ScalingGroup{}, ScalingGroup{}, errors.New("no instances to init")
}
}
func getAWSInstances(stat state.ConstellationState) (coordinators, nodes ScalingGroup, err error) {
coordinatorID, coordinator, err := stat.EC2Instances.GetOne()
if err != nil {
return
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
nodeMap := stat.EC2Instances.GetOthers(coordinatorID)
if len(nodeMap) == 0 {
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
var nodeInstances Instances
for _, node := range nodeMap {
nodeInstances = append(nodeInstances, Instance(node))
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
// TODO: GroupID of nodes is empty, since they currently do not scale.
nodes = ScalingGroup{Instances: nodeInstances, GroupID: ""}
return
}
func getGCPInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes ScalingGroup, err error) {
_, coordinator, err := stat.GCPCoordinators.GetOne()
if err != nil {
return
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
nodeMap := stat.GCPNodes
if len(nodeMap) == 0 {
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
var nodeInstances Instances
for _, node := range nodeMap {
nodeInstances = append(nodeInstances, Instance(node))
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
nodes = ScalingGroup{
Instances: nodeInstances,
GroupID: gcp.AutoscalingNodeGroup(stat.GCPProject, stat.GCPZone, stat.GCPNodeInstanceGroup, *config.AutoscalingNodeGroupsMin, *config.AutoscalingNodeGroupsMax),
}
return
}
func getAzureInstances(stat state.ConstellationState) (coordinators, nodes ScalingGroup, err error) {
_, coordinator, err := stat.AzureCoordinators.GetOne()
if err != nil {
return
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
nodeMap := stat.AzureNodes
if len(nodeMap) == 0 {
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
var nodeInstances Instances
for _, node := range nodeMap {
nodeInstances = append(nodeInstances, Instance(node))
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
nodes = ScalingGroup{
Instances: nodeInstances,
GroupID: "",
}
return
}
// initCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) != 0 {
return []string{}, cobra.ShellCompDirectiveError
}
return []string{}, cobra.ShellCompDirectiveDefault
}
//
// TODO: Code below is target of multicloud refactoring.
//
// Instance is a cloud instance.
type Instance struct {
PublicIP string
PrivateIP string
}
type Instances []Instance
type ScalingGroup struct {
Instances
GroupID string
}
// PublicIPs returns the public IPs of all the instances.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}

686
cli/cmd/init_test.go Normal file
View File

@ -0,0 +1,686 @@
package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"strconv"
"strings"
"testing"
"time"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitArgumentValidation(t *testing.T) {
assert := assert.New(t)
cmd := newInitCmd()
assert.NoError(cmd.ValidateArgs(nil))
assert.Error(cmd.ValidateArgs([]string{"something"}))
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
}
func TestInitialize(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
}
testAzureState := state.ConstellationState{
CloudProvider: "Azure",
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "test",
}
testActivationResps := []fakeActivationRespMessage{
{log: "testlog1"},
{log: "testlog2"},
{
kubeconfig: "kubeconfig",
clientVpnIp: "vpnIp",
coordinatorVpnKey: "coordKey",
ownerID: "ownerID",
clusterID: "clusterID",
},
{log: "testlog3"},
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client protoClient
serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter
pubKey string
errExpected bool
}{
"initialize some ec2 instances": {
existingState: testEc2State,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some gcp instances": {
existingState: testGcpState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some azure instances": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"no state exists": {
existingState: state.ConstellationState{},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"no instances to pick one": {
existingState: state.ConstellationState{
EC2Instances: ec2.Instances{},
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"only one instance": {
existingState: state.ConstellationState{
EC2Instances: ec2.Instances{"id-1": {}},
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"public key to short": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
errExpected: true,
},
"public key to long": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
errExpected: true,
},
"public key not base64": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: "this is not base64 encoded",
errExpected: true,
},
"fail Connect": {
existingState: testEc2State,
client: &stubProtoClient{connectErr: someErr},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail Activate": {
existingState: testEc2State,
client: &stubProtoClient{activateErr: someErr},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient WriteLogStream": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getKubeconfig": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getCoordinatorVpnKey": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getClientVpnIp": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getOwnerID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getClusterID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail to wait for required status": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: stubStatusWaiter{waitForAllErr: someErr},
pubKey: testKey,
errExpected: true,
},
"fail to create service account": {
existingState: testGcpState,
client: &stubProtoClient{},
serviceAccountCreator: stubServiceAccountCreator{
createErr: someErr,
},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
// Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600))
require.NoError(cmd.Flags().Set("publickey", "pubKPath"))
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
err := initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter)
if tc.errExpected {
assert.Error(err)
} else {
require.NoError(err)
assert.Contains(out.String(), "vpnIp")
assert.Contains(out.String(), "coordKey")
assert.Contains(out.String(), "ownerID")
assert.Contains(out.String(), "clusterID")
}
})
}
}
func TestConfigureVPN(t *testing.T) {
assert := assert.New(t)
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
ip := "192.0.2.1"
someErr := errors.New("failed")
configurer := stubVPNConfigurer{}
assert.NoError(configureVpn(&configurer, ip, string(key), ip, key))
assert.True(configurer.configured)
configurer = stubVPNConfigurer{configureErr: someErr}
assert.Error(configureVpn(&configurer, ip, string(key), ip, key))
}
func TestWriteOutput(t *testing.T) {
assert := assert.New(t)
result := activationResult{
clientVpnIP: "foo-qq",
coordinatorPubKey: "bar-qq",
coordinatorPubIP: "baz-qq",
kubeconfig: "foo-bar-baz-qq",
}
var out bytes.Buffer
testFs := afero.NewMemMapFs()
fileHandler := file.NewHandler(testFs)
config := config.Default()
err := result.writeOutput(&out, fileHandler, config)
assert.NoError(err)
assert.Contains(out.String(), result.clientVpnIP)
assert.Contains(out.String(), result.coordinatorPubIP)
assert.Contains(out.String(), result.coordinatorPubKey)
afs := afero.Afero{Fs: testFs}
adminConf, err := afs.ReadFile(*config.AdminConfPath)
assert.NoError(err)
assert.Equal(result.kubeconfig, string(adminConf))
}
func TestIpsToEndpoints(t *testing.T) {
assert := assert.New(t)
ips := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
port := "8080"
endpoints := ipsToEndpoints(ips, port)
assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints)
}
func TestCheckBase64WGKEy(t *testing.T) {
assert := assert.New(t)
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
assert.NoError(checkBase64WGKey(key))
key = []byte(base64.StdEncoding.EncodeToString([]byte("shortKey")))
assert.Error(checkBase64WGKey(key))
key = []byte(base64.StdEncoding.EncodeToString([]byte("looooooooooongKeyWithMoreThan32Bytes")))
assert.Error(checkBase64WGKey(key))
key = []byte("noBase 64")
assert.Error(checkBase64WGKey(key))
}
func TestInitCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "hello",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"secnod arg": {
args: []string{"23"},
toComplete: "/test/h",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
"third arg": {
args: []string{"./file", "sth"},
toComplete: "./file",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := initCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}
func TestReadVpnKey(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testKeyA := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
testKeyB := []byte(base64.StdEncoding.EncodeToString([]byte("anotherWireGuardKeyForTheTesting")))
fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write("testKeyA", testKeyA, false))
require.NoError(fileHandler.Write("testKeyB", testKeyB, false))
// provide privK
privK, _, err := readVpnKey(fileHandler, "testKeyA", "")
assert.NoError(err)
assert.Equal(testKeyA, privK)
// provide pubK
_, pubK, err := readVpnKey(fileHandler, "", "testKeyA")
assert.NoError(err)
assert.Equal(testKeyA, pubK)
// provide both, privK should be used, pubK ignored
privK, pubK, err = readVpnKey(fileHandler, "testKeyA", "testKeyB")
assert.NoError(err)
assert.Equal(testKeyA, privK)
assert.NotEqual(testKeyB, pubK)
// no path provided
_, _, err = readVpnKey(fileHandler, "", "")
assert.Error(err)
}
func TestReadOrGeneratedMasterSecret(t *testing.T) {
testCases := map[string]struct {
filename string
filecontent string
createFile bool
fs func() afero.Fs
errExpected bool
}{
"file with secret exists": {
filename: "someSecret",
filecontent: base64.StdEncoding.EncodeToString([]byte("ConstellationSecret")),
createFile: true,
fs: afero.NewMemMapFs,
errExpected: false,
},
"no file given": {
filename: "",
filecontent: "",
fs: afero.NewMemMapFs,
errExpected: false,
},
"file does not exist": {
filename: "nonExistingSecret",
filecontent: "",
createFile: false,
fs: afero.NewMemMapFs,
errExpected: true,
},
"file is empty": {
filename: "emptySecret",
filecontent: "",
createFile: true,
fs: afero.NewMemMapFs,
errExpected: true,
},
"secret too short": {
filename: "shortSecret",
filecontent: base64.StdEncoding.EncodeToString([]byte("short")),
createFile: true,
fs: afero.NewMemMapFs,
errExpected: true,
},
"secret not encoded": {
filename: "unencodedSecret",
filecontent: "Constellation",
createFile: true,
fs: afero.NewMemMapFs,
errExpected: true,
},
"file not writeable": {
filename: "",
filecontent: "",
createFile: false,
fs: func() afero.Fs { return afero.NewReadOnlyFs(afero.NewMemMapFs()) },
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(tc.fs())
config := config.Default()
if tc.createFile {
require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), false))
}
var out bytes.Buffer
secret, err := readOrGeneratedMasterSecret(&out, fileHandler, tc.filename, config)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
if tc.filename == "" {
require.Contains(out.String(), *config.MasterSecretPath)
filename := strings.Split(out.String(), "./")
tc.filename = strings.Trim(filename[1], "\n")
}
content, err := fileHandler.Read(tc.filename)
require.NoError(err)
assert.Equal(content, []byte(base64.StdEncoding.EncodeToString(secret)))
}
})
}
}
func TestAutoscaleFlag(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
}
testAzureState := state.ConstellationState{
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "test",
}
testActivationResps := []fakeActivationRespMessage{
{log: "testlog1"},
{log: "testlog2"},
{
kubeconfig: "kubeconfig",
clientVpnIp: "vpnIp",
coordinatorVpnKey: "coordKey",
ownerID: "ownerID",
clusterID: "clusterID",
},
{log: "testlog3"},
}
testCases := map[string]struct {
autoscaleFlag bool
existingState state.ConstellationState
client *stubProtoClient
serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter
pubKey string
}{
"initialize some ec2 instances without autoscale flag": {
autoscaleFlag: false,
existingState: testEc2State,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some gcp instances without autoscale flag": {
autoscaleFlag: false,
existingState: testGcpState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some azure instances without autoscale flag": {
autoscaleFlag: false,
existingState: testAzureState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some ec2 instances with autoscale flag": {
autoscaleFlag: true,
existingState: testEc2State,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some gcp instances with autoscale flag": {
autoscaleFlag: true,
existingState: testGcpState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some azure instances with autoscale flag": {
autoscaleFlag: true,
existingState: testAzureState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
// Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600))
require.NoError(cmd.Flags().Set("publickey", "pubKPath"))
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background()
require.NoError(initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter))
if tc.autoscaleFlag {
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
} else {
assert.Len(tc.client.activateAutoscalingNodeGroups, 0)
}
})
}
}

13
cli/cmd/protoclient.go Normal file
View File

@ -0,0 +1,13 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/proto"
)
type protoClient interface {
Connect(ip string, port string) error
Close() error
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
}

188
cli/cmd/protoclient_test.go Normal file
View File

@ -0,0 +1,188 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/cli/proto"
)
type stubProtoClient struct {
conn bool
respClient proto.ActivationResponseClient
connectErr error
closeErr error
activateErr error
activateUserPublicKey []byte
activateMasterSecret []byte
activateEndpoints []string
activateAutoscalingNodeGroups []string
cloudServiceAccountURI string
}
func (c *stubProtoClient) Connect(ip string, port string) error {
c.conn = true
return c.connectErr
}
func (c *stubProtoClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints []string, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) {
c.activateUserPublicKey = userPublicKey
c.activateMasterSecret = masterSecret
c.activateEndpoints = endpoints
c.activateAutoscalingNodeGroups = autoscalingNodeGroups
c.cloudServiceAccountURI = cloudServiceAccountURI
return c.respClient, c.activateErr
}
type stubActivationRespClient struct {
nextLogErr *error
getKubeconfigErr error
getCoordinatorVpnKeyErr error
getClientVpnIpErr error
getOwnerIDErr error
getClusterIDErr error
writeLogStreamErr error
}
func (s *stubActivationRespClient) NextLog() (string, error) {
if s.nextLogErr == nil {
return "", io.EOF
}
return "", *s.nextLogErr
}
func (s *stubActivationRespClient) WriteLogStream(io.Writer) error {
return s.writeLogStreamErr
}
func (s *stubActivationRespClient) GetKubeconfig() (string, error) {
return "", s.getKubeconfigErr
}
func (s *stubActivationRespClient) GetCoordinatorVpnKey() (string, error) {
return "", s.getCoordinatorVpnKeyErr
}
func (s *stubActivationRespClient) GetClientVpnIp() (string, error) {
return "", s.getClientVpnIpErr
}
func (s *stubActivationRespClient) GetOwnerID() (string, error) {
return "", s.getOwnerIDErr
}
func (s *stubActivationRespClient) GetClusterID() (string, error) {
return "", s.getClusterIDErr
}
type fakeProtoClient struct {
conn bool
respClient proto.ActivationResponseClient
}
func (c *fakeProtoClient) Connect(ip string, port string) error {
c.conn = true
return nil
}
func (c *fakeProtoClient) Close() error {
c.conn = false
return nil
}
func (c *fakeProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints []string, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) {
if !c.conn {
return nil, errors.New("client is not connected")
}
return c.respClient, nil
}
type fakeActivationRespClient struct {
responses []fakeActivationRespMessage
kubeconfig string
coordinatorVpnKey string
clientVpnIp string
ownerID string
clusterID string
}
func (c *fakeActivationRespClient) NextLog() (string, error) {
for len(c.responses) > 0 {
resp := c.responses[0]
c.responses = c.responses[1:]
if len(resp.log) > 0 {
return resp.log, nil
}
c.kubeconfig = resp.kubeconfig
c.coordinatorVpnKey = resp.coordinatorVpnKey
c.clientVpnIp = resp.clientVpnIp
c.ownerID = resp.ownerID
c.clusterID = resp.clusterID
}
return "", io.EOF
}
func (c *fakeActivationRespClient) WriteLogStream(w io.Writer) error {
log, err := c.NextLog()
for err == nil {
fmt.Fprint(w, log)
log, err = c.NextLog()
}
if !errors.Is(err, io.EOF) {
return err
}
return nil
}
func (c *fakeActivationRespClient) GetKubeconfig() (string, error) {
if c.kubeconfig == "" {
return "", errors.New("kubeconfig is empty")
}
return c.kubeconfig, nil
}
func (c *fakeActivationRespClient) GetCoordinatorVpnKey() (string, error) {
if c.coordinatorVpnKey == "" {
return "", errors.New("coordinator public VPN key is empty")
}
return c.coordinatorVpnKey, nil
}
func (c *fakeActivationRespClient) GetClientVpnIp() (string, error) {
if c.clientVpnIp == "" {
return "", errors.New("client VPN IP is empty")
}
return c.clientVpnIp, nil
}
func (c *fakeActivationRespClient) GetOwnerID() (string, error) {
if c.ownerID == "" {
return "", errors.New("init secret is empty")
}
return c.ownerID, nil
}
func (c *fakeActivationRespClient) GetClusterID() (string, error) {
if c.clusterID == "" {
return "", errors.New("cluster identifier is empty")
}
return c.clusterID, nil
}
type fakeActivationRespMessage struct {
log string
kubeconfig string
coordinatorVpnKey string
clientVpnIp string
ownerID string
clusterID string
}

60
cli/cmd/rollback.go Normal file
View File

@ -0,0 +1,60 @@
package cmd
import (
"context"
"fmt"
"io"
"go.uber.org/multierr"
)
// rollbacker does a rollback.
type rollbacker interface {
rollback(ctx context.Context) error
}
// rollbackOnError calls rollback on the rollbacker if the handed error is not nil,
// and writes logs to the writer w.
func rollbackOnError(ctx context.Context, w io.Writer, onErr *error, roll rollbacker) {
if *onErr == nil {
return
}
fmt.Fprintf(w, "An error occurred: %s\n", *onErr)
fmt.Fprintln(w, "Attempting to roll back.")
if err := roll.rollback(ctx); err != nil {
*onErr = multierr.Append(*onErr, fmt.Errorf("on rollback: %w", err)) // TODO: print the error, or retrun it?
return
}
fmt.Fprintln(w, "Rollback succeeded.")
}
type rollbackerGCP struct {
client gcpclient
}
func (r *rollbackerGCP) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.TerminateInstances(ctx))
err = multierr.Append(err, r.client.TerminateFirewall(ctx))
err = multierr.Append(err, r.client.TerminateVPCs(ctx))
return err
}
type rollbackerAzure struct {
client azureclient
}
func (r *rollbackerAzure) rollback(ctx context.Context) error {
return r.client.TerminateResourceGroup(ctx)
}
type rollbackerAWS struct {
client ec2client
}
func (r *rollbackerAWS) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.TerminateInstances(ctx))
err = multierr.Append(err, r.client.DeleteSecurityGroup(ctx))
return err
}

64
cli/cmd/root.go Normal file
View File

@ -0,0 +1,64 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "constellation",
Short: "Set up your Constellation cluster.",
Long: "Set up your Constellation cluster.",
SilenceUsage: true,
}
// Execute starts the CLI.
func Execute() error {
ctx, cancel := signalContext(context.Background(), os.Interrupt)
defer cancel()
return rootCmd.ExecuteContext(ctx)
}
// signalContext returns a context that is canceled on the handed signal.
// The signal isn't watched after its first occurrence. Call the cancel
// function to ensure the internal goroutine is stopped and the signal isn't
// watched any longer.
func signalContext(ctx context.Context, sig os.Signal) (context.Context, context.CancelFunc) {
sigCtx, stop := signal.NotifyContext(ctx, sig)
done := make(chan struct{}, 1)
stopDone := make(chan struct{}, 1)
go func() {
defer func() { stopDone <- struct{}{} }()
defer stop()
select {
case <-sigCtx.Done():
fmt.Println(" Signal caught. Press ctrl+c again to terminate the program immediately.")
case <-done:
}
}()
cancelFunc := func() {
done <- struct{}{}
<-stopDone
}
return sigCtx, cancelFunc
}
func init() {
// Set output of cmd.Print to stdout.
rootCmd.SetOut(os.Stdout)
// Disable --no-description flag for completion command.
rootCmd.CompletionOptions.DisableNoDescFlag = true
rootCmd.PersistentFlags().String("dev-config", "", "Set this flag to create the Constellation using settings from a development config.")
rootCmd.AddCommand(newVersionCmd())
rootCmd.AddCommand(newCreateCmd())
rootCmd.AddCommand(newInitCmd())
rootCmd.AddCommand(newTerminateCmd())
rootCmd.AddCommand(newVerifyCmd())
}

View File

@ -0,0 +1,95 @@
package cmd
import (
"context"
"fmt"
azurecl "github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/cli/cloudprovider"
ec2cl "github.com/edgelesssys/constellation/cli/ec2/client"
gcpcl "github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
type serviceAccountCreator interface {
createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error)
}
type serviceAccountClient struct{}
// createServiceAccount creates a new cloud provider service account with access to the created resources.
func (c serviceAccountClient) createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
switch stat.CloudProvider {
case cloudprovider.AWS.String():
// TODO: implement
ec2client, err := ec2cl.NewFromDefault(ctx)
if err != nil {
return "", state.ConstellationState{}, err
}
return c.createServiceAccountEC2(ctx, ec2client, stat, config)
case cloudprovider.GCP.String():
gcpclient, err := gcpcl.NewFromDefault(ctx)
if err != nil {
return "", state.ConstellationState{}, err
}
serviceAccount, stat, err := c.createServiceAccountGCP(ctx, gcpclient, stat, config)
if err != nil {
return "", state.ConstellationState{}, err
}
return serviceAccount, stat, gcpclient.Close()
case cloudprovider.Azure.String():
azureclient, err := azurecl.NewFromDefault(stat.AzureSubscription, stat.AzureTenant)
if err != nil {
return "", state.ConstellationState{}, err
}
return c.createServiceAccountAzure(ctx, azureclient, stat)
}
return "", state.ConstellationState{}, fmt.Errorf("unknown cloud provider %v", stat.CloudProvider)
}
func (c serviceAccountClient) createServiceAccountAzure(ctx context.Context, cl azureclient, stat state.ConstellationState) (string, state.ConstellationState, error) {
if err := cl.SetState(stat); err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
}
serviceAccount, err := cl.CreateServicePrincipal(ctx)
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to create service account: %w", err)
}
stat, err = cl.GetState()
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to get state after creating service account: %w", err)
}
return serviceAccount, stat, nil
}
func (c serviceAccountClient) createServiceAccountGCP(ctx context.Context, cl gcpclient, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
if err := cl.SetState(stat); err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
}
input := gcpcl.ServiceAccountInput{
Roles: *config.Provider.GCP.ServiceAccountRoles,
}
serviceAccount, err := cl.CreateServiceAccount(ctx, input)
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to create service account: %w", err)
}
stat, err = cl.GetState()
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to get state after creating service account: %w", err)
}
return serviceAccount, stat, nil
}
//nolint:unparam
func (c serviceAccountClient) createServiceAccountEC2(ctx context.Context, cl ec2client, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
// TODO: implement
if err := cl.SetState(stat); err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
}
return "", stat, nil
}

View File

@ -0,0 +1,136 @@
package cmd
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestCreateServiceAccountAzure(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client azureclient
errExpected bool
}{
"create service account works": {
existingState: testState,
client: &fakeAzureClient{},
},
"fail setState": {
existingState: testState,
client: &stubAzureClient{setStateErr: someErr},
errExpected: true,
},
"fail create": {
existingState: testState,
client: &stubAzureClient{createServicePrincipalErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := serviceAccountClient{}
serviceAccount, _, err := client.createServiceAccountAzure(context.Background(), tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotNil(serviceAccount)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
}, stat)
}
})
}
}
func TestCreateServiceAccountGCP(t *testing.T) {
testState := state.ConstellationState{
GCPProject: "project",
GCPNodes: gcp.Instances{},
GCPCoordinators: gcp.Instances{},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "template",
GCPCoordinatorInstanceTemplate: "template",
GCPNetwork: "network",
GCPFirewalls: []string{},
}
config := config.Default()
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client gcpclient
errExpected bool
}{
"create service account works": {
existingState: testState,
client: &fakeGcpClient{},
},
"fail setState": {
existingState: testState,
client: &stubGcpClient{setStateErr: someErr},
errExpected: true,
},
"fail create": {
existingState: testState,
client: &stubGcpClient{createServiceAccountErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := serviceAccountClient{}
serviceAccount, _, err := client.createServiceAccountGCP(context.Background(), tc.client, tc.existingState, config)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotNil(serviceAccount)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPProject: "project",
GCPNodes: gcp.Instances{},
GCPCoordinators: gcp.Instances{},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "template",
GCPCoordinatorInstanceTemplate: "template",
GCPNetwork: "network",
GCPFirewalls: []string{},
GCPServiceAccount: "service-account@project.iam.gserviceaccount.com",
}, stat)
}
})
}
}
type stubServiceAccountCreator struct {
cloudServiceAccountURI string
createErr error
}
func (c *stubServiceAccountCreator) createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
return c.cloudServiceAccountURI, stat, c.createErr
}

11
cli/cmd/statuswaiter.go Normal file
View File

@ -0,0 +1,11 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/coordinator/state"
)
type statusWaiter interface {
WaitForAll(ctx context.Context, status state.State, endpoints []string) error
}

View File

@ -0,0 +1,15 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/coordinator/state"
)
type stubStatusWaiter struct {
waitForAllErr error
}
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
return w.waitForAllErr
}

131
cli/cmd/terminate.go Normal file
View File

@ -0,0 +1,131 @@
package cmd
import (
"errors"
"fmt"
"io/fs"
"github.com/spf13/afero"
"github.com/spf13/cobra"
azure "github.com/edgelesssys/constellation/cli/azure/client"
ec2 "github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/cli/file"
gcp "github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
func newTerminateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "terminate",
Short: "Terminate an existing Constellation.",
Long: "Terminate an existing Constellation. The Constellation can't be started again, and all persistent storage will be lost.",
Args: cobra.NoArgs,
RunE: runTerminate,
}
return cmd
}
// runTerminate runs the terminate command.
func runTerminate(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
return terminate(cmd, fileHandler, config)
}
func terminate(cmd *cobra.Command, fileHandler file.Handler, config *config.Config) error {
var stat state.ConstellationState
if err := fileHandler.ReadJSON(*config.StatePath, &stat); err != nil {
return err
}
if len(stat.EC2Instances) != 0 || stat.EC2SecurityGroup != "" {
ec2client, err := ec2.NewFromDefault(cmd.Context())
if err != nil {
return err
}
if err := terminateEC2(cmd, ec2client, stat); err != nil {
return err
}
}
// TODO: improve check, also look for other resources that might need to be terminated
if len(stat.GCPNodes) != 0 {
gcpclient, err := gcp.NewFromDefault(cmd.Context())
if err != nil {
return err
}
if err := terminateGCP(cmd, gcpclient, stat); err != nil {
return err
}
}
if len(stat.AzureResourceGroup) != 0 {
azureclient, err := azure.NewFromDefault(stat.AzureSubscription, stat.AzureTenant)
if err != nil {
return err
}
if err := terminateAzure(cmd, azureclient, stat); err != nil {
return err
}
}
cmd.Println("Your Constellation was terminated successfully.")
if err := fileHandler.Remove(*config.StatePath); err != nil {
return fmt.Errorf("failed to remove file '%s', please remove manually", *config.StatePath)
}
if err := fileHandler.Remove(*config.AdminConfPath); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to remove file '%s', please remove manually", *config.AdminConfPath)
}
return nil
}
func terminateAzure(cmd *cobra.Command, cl azureclient, stat state.ConstellationState) error {
if err := cl.SetState(stat); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
if err := cl.TerminateServicePrincipal(cmd.Context()); err != nil {
return err
}
return cl.TerminateResourceGroup(cmd.Context())
}
func terminateGCP(cmd *cobra.Command, cl gcpclient, stat state.ConstellationState) error {
if err := cl.SetState(stat); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
if err := cl.TerminateInstances(cmd.Context()); err != nil {
return err
}
if err := cl.TerminateFirewall(cmd.Context()); err != nil {
return err
}
if err := cl.TerminateVPCs(cmd.Context()); err != nil {
return err
}
return cl.TerminateServiceAccount(cmd.Context())
}
// terminateEC2 and remove the existing Constellation form the state file.
func terminateEC2(cmd *cobra.Command, cl ec2client, stat state.ConstellationState) error {
if err := cl.SetState(stat); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
if err := cl.TerminateInstances(cmd.Context()); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
return cl.DeleteSecurityGroup(cmd.Context())
}

288
cli/cmd/terminate_test.go Normal file
View File

@ -0,0 +1,288 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestTerminateCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"no args": {[]string{}, false},
"some args": {[]string{"hello", "test"}, true},
"some other args": {[]string{"12", "2"}, true},
}
cmd := newTerminateCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTerminateEC2(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-3": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
EC2SecurityGroup: "sg-test",
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client ec2client
errExpected bool
}{
"terminate existing instances": {
existingState: testState,
client: &fakeEc2Client{},
errExpected: false,
},
"state without instances": {
existingState: state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: ec2.Instances{},
},
client: &fakeEc2Client{},
errExpected: true,
},
"fail TerminateInstances": {
existingState: testState,
client: &stubEc2Client{terminateInstancesErr: someErr},
errExpected: true,
},
"fail DeleteSecurityGroup": {
existingState: testState,
client: &stubEc2Client{deleteSecurityGroupErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
err := terminateEC2(cmd, tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTerminateGCP(t *testing.T) {
testState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "template",
GCPCoordinatorInstanceTemplate: "template",
GCPNetwork: "network",
GCPFirewalls: []string{"coordinator", "wireguard", "ssh"},
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client gcpclient
errExpected bool
}{
"terminate existing instances": {
existingState: testState,
client: &fakeGcpClient{},
},
"state without instances": {
existingState: state.ConstellationState{EC2Instances: ec2.Instances{}},
client: &fakeGcpClient{},
},
"state not found": {
existingState: testState,
client: &fakeGcpClient{},
},
"fail setState": {
existingState: testState,
client: &stubGcpClient{setStateErr: someErr},
errExpected: true,
},
"fail terminateFirewall": {
existingState: testState,
client: &stubGcpClient{terminateFirewallErr: someErr},
errExpected: true,
},
"fail terminateVPC": {
existingState: testState,
client: &stubGcpClient{terminateVPCsErr: someErr},
errExpected: true,
},
"fail terminateInstances": {
existingState: testState,
client: &stubGcpClient{terminateInstancesErr: someErr},
errExpected: true,
},
"fail terminateServiceAccount": {
existingState: testState,
client: &stubGcpClient{terminateServiceAccountErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
err := terminateGCP(cmd, tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
}, stat)
}
})
}
}
func TestTerminateAzure(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "test",
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client azureclient
errExpected bool
}{
"terminate existing instances": {
existingState: testState,
client: &fakeAzureClient{},
},
"state resource group": {
existingState: state.ConstellationState{AzureResourceGroup: ""},
client: &fakeAzureClient{},
},
"state not found": {
existingState: testState,
client: &fakeAzureClient{},
},
"fail setState": {
existingState: testState,
client: &stubAzureClient{setStateErr: someErr},
errExpected: true,
},
"fail resource group termination": {
existingState: testState,
client: &stubAzureClient{terminateResourceGroupErr: someErr},
errExpected: true,
},
"fail service principal termination": {
existingState: testState,
client: &stubAzureClient{terminateServicePrincipalErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
err := terminateAzure(cmd, tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
}, stat)
}
})
}
}

View File

@ -0,0 +1,80 @@
package cmd
import (
"bufio"
"errors"
"fmt"
"strings"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/spf13/cobra"
)
var (
// ErrInvalidInput is an error where user entered invalid input.
ErrInvalidInput = errors.New("user made invalid input")
warningStr = "Warning: not verifying the Constellation's %s measurements\n"
)
// askToConfirm asks user to confirm an action.
// The user will be asked the handed question and can answer with
// yes or no.
func askToConfirm(cmd *cobra.Command, question string) (bool, error) {
reader := bufio.NewReader(cmd.InOrStdin())
cmd.Printf("%s [y/n]: ", question)
for i := 0; i < 3; i++ {
resp, err := reader.ReadString('\n')
if err != nil {
return false, err
}
resp = strings.ToLower(strings.TrimSpace(resp))
if resp == "n" || resp == "no" {
return false, nil
}
if resp == "y" || resp == "yes" {
return true, nil
}
cmd.Printf("Type 'y' or 'yes' to confirm, or abort action with 'n' or 'no': ")
}
return false, ErrInvalidInput
}
// warnAboutPCRs displays warnings if specifc PCR values are not verified.
//
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func warnAboutPCRs(cmd *cobra.Command, pcrs map[uint32][]byte, checkInit bool) error {
for k, v := range pcrs {
if len(v) != 32 {
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
}
}
if pcrs[0] == nil || pcrs[1] == nil {
cmd.PrintErrf(warningStr, "BIOS")
}
if pcrs[2] == nil || pcrs[3] == nil {
cmd.PrintErrf(warningStr, "OPROM")
}
if pcrs[4] == nil || pcrs[5] == nil {
cmd.PrintErrf(warningStr, "MBR")
}
// GRUB measures kernel command line and initrd into pcrs 8 and 9
if pcrs[8] == nil {
cmd.PrintErrf(warningStr, "kernel command line")
}
if pcrs[9] == nil {
cmd.PrintErrf(warningStr, "initrd")
}
// Only warn about initialization PCRs if necessary
if checkInit {
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
cmd.PrintErrf(warningStr, "initialization status")
}
}
return nil
}

View File

@ -0,0 +1,249 @@
package cmd
import (
"bytes"
"errors"
"io"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestAskToConfirm(t *testing.T) {
// errAborted is an error where the user aborted the action.
errAborted := errors.New("user aborted")
cmd := &cobra.Command{
Use: "test",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ok, err := askToConfirm(cmd, "777")
if err != nil {
return err
}
if !ok {
return errAborted
}
return nil
},
}
testCases := map[string]struct {
input string
expectedErr error
}{
"user confirms": {"y\n", nil},
"user confirms long": {"yes\n", nil},
"user disagrees": {"n\n", errAborted},
"user disagrees long": {"no\n", errAborted},
"user is first unsure, but agrees": {"what?\ny\n", nil},
"user is first unsure, but disagrees": {"wait.\nn\n", errAborted},
"repeated invalid input": {"h\nb\nq\n", ErrInvalidInput},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.input)
cmd.SetIn(in)
err := cmd.Execute()
assert.ErrorIs(err, tc.expectedErr)
output, err := io.ReadAll(out)
assert.NoError(err)
assert.Contains(string(output), "777")
})
}
}
func TestWarnAboutPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
testCases := map[string]struct {
pcrs map[uint32][]byte
dontWarnInit bool
expectedWarnings []string
errExpected bool
}{
"no warnings": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
},
},
"no warnings for missing non critical values": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
},
"warn for BIOS": {
pcrs: map[uint32][]byte{
0: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"BIOS"},
},
"warn for OPROM": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"OPROM"},
},
"warn for MBR": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"MBR"},
},
"warn for kernel": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"kernel"},
},
"warn for initrd": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"initrd"},
},
"warn for initialization": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
},
dontWarnInit: false,
expectedWarnings: []string{"initialization"},
},
"don't warn for initialization": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
},
dontWarnInit: true,
},
"multi warning": {
pcrs: map[uint32][]byte{},
expectedWarnings: []string{
"BIOS",
"OPROM",
"MBR",
"initialization",
"initrd",
"kernel",
},
},
"bad config": {
pcrs: map[uint32][]byte{
0: []byte("000"),
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
err := warnAboutPCRs(cmd, tc.pcrs, !tc.dontWarnInit)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
if len(tc.expectedWarnings) == 0 {
assert.Empty(errOut.String())
} else {
for _, warning := range tc.expectedWarnings {
assert.Contains(errOut.String(), warning)
}
}
}
})
}
}

70
cli/cmd/validargs.go Normal file
View File

@ -0,0 +1,70 @@
package cmd
import (
"fmt"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/spf13/cobra"
)
// isIntArg checks if argument at position arg is an integer.
func isIntArg(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if _, err := strconv.Atoi(args[arg]); err != nil {
return fmt.Errorf("argument %d must be an integer", arg)
}
return nil
}
}
// isIntGreaterArg checks if argument at position arg is and integer and greater i.
func isIntGreaterArg(arg int, i int) cobra.PositionalArgs {
return cobra.MatchAll(isIntArg(arg), func(cmd *cobra.Command, args []string) error {
if v, _ := strconv.Atoi(args[arg]); v <= i {
return fmt.Errorf("argument %d must be greater %d, but it's %d", arg, i, v)
}
return nil
})
}
// isIntGreaterZeroArg checks if argument at position arg is a positive non zero integer.
func isIntGreaterZeroArg(arg int) cobra.PositionalArgs {
return cobra.MatchAll(isIntGreaterArg(arg, 0))
}
// isEC2InstanceType checks if argument at position arg is a key in m.
// The argument will always be converted to lower case letters.
func isEC2InstanceType(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if _, ok := ec2.InstanceTypes[strings.ToLower(args[arg])]; !ok {
return fmt.Errorf("'%s' isn't an AWS EC2 instance type", args[arg])
}
return nil
}
}
func isGCPInstanceType(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
for _, instanceType := range gcp.InstanceTypes {
if args[arg] == instanceType {
return nil
}
}
return fmt.Errorf("argument %s isn't a valid GCP instance type", args[arg])
}
}
func isAzureInstanceType(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
for _, instanceType := range azure.InstanceTypes {
if args[arg] == instanceType {
return nil
}
}
return fmt.Errorf("argument %s isn't a valid Azure instance type", args[arg])
}
}

197
cli/cmd/validargs_test.go Normal file
View File

@ -0,0 +1,197 @@
package cmd
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestIsIntArg(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isIntArg(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid int 1": {[]string{"1"}, false},
"valid int 2": {[]string{"42"}, false},
"valid int 3": {[]string{"987987498"}, false},
"valid int and other args": {[]string{"3", "hello"}, false},
"valid int and other args 2": {[]string{"3", "4"}, false},
"invalid 1": {[]string{"hello world"}, true},
"invalid 2": {[]string{"98798d749f8"}, true},
"invalid 3": {[]string{"three"}, true},
"invalid 4": {[]string{"0.3"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsIntGreaterArg(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isIntGreaterArg(0, 12),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid int 1": {[]string{"13"}, false},
"valid int 2": {[]string{"42"}, false},
"valid int 3": {[]string{"987987498"}, false},
"invalid int 1": {[]string{"1"}, true},
"invalid int and other args": {[]string{"3", "hello"}, true},
"invalid int and other args 2": {[]string{"-14", "4"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsIntGreaterZeroArg(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isIntGreaterZeroArg(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid int 1": {[]string{"13"}, false},
"valid int 2": {[]string{"42"}, false},
"valid int 3": {[]string{"987987498"}, false},
"invalid": {[]string{"0"}, true},
"invalid int 1": {[]string{"-42", "hello"}, true},
"invalid int and other args": {[]string{"-9487239847", "4"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsEC2InstanceType(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isEC2InstanceType(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"is instance type 1": {[]string{"4xl"}, false},
"is instance type 2": {[]string{"12xlarge", "something else"}, false},
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
"isn't instance type 2": {[]string{"Hello World!"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsGCPInstanceType(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isGCPInstanceType(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"is instance type 1": {[]string{"n2d-standard-4"}, false},
"is instance type 2": {[]string{"n2d-standard-16", "something else"}, false},
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
"isn't instance type 2": {[]string{"Hello World!"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsAzureInstanceType(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isAzureInstanceType(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"is instance type 1": {[]string{"Standard_DC2as_v5"}, false},
"is instance type 2": {[]string{"Standard_DC8as_v5", "something else"}, false},
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
"isn't instance type 2": {[]string{"Hello World!"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

141
cli/cmd/verify.go Normal file
View File

@ -0,0 +1,141 @@
package cmd
import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
rpcStatus "google.golang.org/grpc/status"
)
func newVerifyCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "verify azure|gcp",
Short: "Verify the confidential properties of your Constellation.",
Long: "Verify the confidential properties of your Constellation.",
}
cmd.PersistentFlags().String("owner-id", "", "verify the Constellation using the owner identity derived from the master secret.")
cmd.PersistentFlags().String("unique-id", "", "verify the Constellation using the unique cluster identity.")
cmd.AddCommand(newVerifyGCPCmd())
cmd.AddCommand(newVerifyAzureCmd())
cmd.AddCommand(newVerifyGCPNonCVMCmd())
return cmd
}
func runVerify(cmd *cobra.Command, args []string, pcrs map[uint32][]byte, validator atls.Validator) error {
if err := warnAboutPCRs(cmd, pcrs, false); err != nil {
return err
}
verifier := verifier{
newConn: newVerifiedConn,
newClient: pubproto.NewAPIClient,
}
return verify(cmd.Context(), cmd.OutOrStdout(), net.JoinHostPort(args[0], args[1]), []atls.Validator{validator}, verifier)
}
func verify(ctx context.Context, w io.Writer, target string, validators []atls.Validator, verifier verifier) error {
conn, err := verifier.newConn(ctx, target, validators)
if err != nil {
return err
}
defer conn.Close()
client := verifier.newClient(conn)
if _, err := client.GetState(ctx, &pubproto.GetStateRequest{}); err != nil {
if err, ok := rpcStatus.FromError(err); ok {
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
}
return err
}
fmt.Fprintln(w, "OK")
return nil
}
// prepareValidator parses parameters and updates the PCR map.
func prepareValidator(cmd *cobra.Command, pcrs map[uint32][]byte) error {
ownerID, err := cmd.Flags().GetString("owner-id")
if err != nil {
return err
}
clusterID, err := cmd.Flags().GetString("unique-id")
if err != nil {
return err
}
if ownerID == "" && clusterID == "" {
return errors.New("neither owner identity nor unique identity provided to verify the Constellation")
}
return updatePCRMap(pcrs, ownerID, clusterID)
}
func updatePCRMap(pcrs map[uint32][]byte, ownerID, clusterID string) error {
if err := addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
return err
}
return addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexClusterID), clusterID)
}
// addOrSkipPCR adds a new entry to the map, or removes the key if the input is an empty string.
//
// When adding, the input is first decoded from base64.
// We then calculate the expected PCR by hashing the input using SHA256,
// appending expected PCR for initialization, and then hashing once more.
func addOrSkipPCR(toAdd map[uint32][]byte, pcrIndex uint32, encoded string) error {
if encoded == "" {
delete(toAdd, pcrIndex)
return nil
}
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err)
}
// new_pcr_value := hash(old_pcr_value || data_to_extend)
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
hashedInput := sha256.Sum256(decoded)
expectedPcr := sha256.Sum256(append(toAdd[pcrIndex], hashedInput[:]...))
toAdd[pcrIndex] = expectedPcr[:]
return nil
}
type verifier struct {
newConn func(context.Context, string, []atls.Validator) (status.ClientConn, error)
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
}
// newVerifiedConn creates a grpc over aTLS connection to the target, using the provided PCR values to verify the server.
func newVerifiedConn(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil {
return nil, err
}
return grpc.DialContext(
ctx, target, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
// verifyCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0, 1:
return []string{}, cobra.ShellCompDirectiveNoFileComp
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

51
cli/cmd/verify_azure.go Normal file
View File

@ -0,0 +1,51 @@
package cmd
import (
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newVerifyAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure IP PORT",
Short: "Verify the confidential properties of your Constellation on Azure.",
Long: "Verify the confidential properties of your Constellation on Azure.",
Args: cobra.ExactArgs(2),
ValidArgsFunction: verifyCompletion,
RunE: runVerifyAzure,
}
return cmd
}
func runVerifyAzure(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
validators, err := getAzureValidator(cmd, *config.Provider.GCP.PCRs)
if err != nil {
return err
}
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
}
// getAzureValidator returns an Azure validator.
func getAzureValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
if err := prepareValidator(cmd, pcrs); err != nil {
return nil, err
}
return azure.NewValidator(pcrs), nil
}

View File

@ -0,0 +1,66 @@
package cmd
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetAzureValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyAzureCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
_, err := getAzureValidator(cmd, map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

51
cli/cmd/verify_gcp.go Normal file
View File

@ -0,0 +1,51 @@
package cmd
import (
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newVerifyGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp IP PORT",
Short: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
Long: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
Args: cobra.ExactArgs(2),
ValidArgsFunction: verifyCompletion,
RunE: runVerifyGCP,
}
return cmd
}
func runVerifyGCP(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
validators, err := getGCPValidator(cmd, *config.Provider.GCP.PCRs)
if err != nil {
return err
}
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
}
// getValidators returns a GCP validator.
func getGCPValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
if err := prepareValidator(cmd, pcrs); err != nil {
return nil, err
}
return gcp.NewValidator(pcrs), nil
}

View File

@ -0,0 +1,40 @@
package cmd
import (
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/spf13/cobra"
)
// TODO: Remove this command once we no longer use non cvms.
func newVerifyGCPNonCVMCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp-non-cvm IP PORT",
Short: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
Long: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
Args: cobra.ExactArgs(2),
ValidArgsFunction: verifyCompletion,
RunE: runVerifyGCPNonCVM,
}
return cmd
}
func runVerifyGCPNonCVM(cmd *cobra.Command, args []string) error {
pcrs := map[uint32][]byte{}
validator, err := getGCPNonCVMValidator(cmd, pcrs)
if err != nil {
return err
}
return runVerify(cmd, args, pcrs, validator)
}
// getGCPNonCVMValidator returns a GCP validator for regular shielded VMs.
func getGCPNonCVMValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
if err := prepareValidator(cmd, pcrs); err != nil {
return nil, err
}
return gcp.NewNonCVMValidator(pcrs), nil
}

View File

@ -0,0 +1,63 @@
package cmd
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetGCPNonCVMValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyGCPNonCVMCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
_, err := getGCPNonCVMValidator(cmd, map[uint32][]byte{})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

View File

@ -0,0 +1,66 @@
package cmd
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetGCPValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyGCPCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
_, err := getGCPValidator(cmd, map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

325
cli/cmd/verify_test.go Normal file
View File

@ -0,0 +1,325 @@
package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
rpcStatus "google.golang.org/grpc/status"
)
func TestVerify(t *testing.T) {
testCases := map[string]struct {
connErr error
checkErr error
state state.State
errExpected bool
}{
"connection error": {
connErr: errors.New("connection error"),
checkErr: nil,
state: 0,
errExpected: true,
},
"check error": {
connErr: nil,
checkErr: errors.New("check error"),
state: 0,
errExpected: true,
},
"check error, rpc status": {
connErr: nil,
checkErr: rpcStatus.Error(codes.Unavailable, "check error"),
state: 0,
errExpected: true,
},
"verify on worker node": {
connErr: nil,
checkErr: nil,
state: state.IsNode,
errExpected: false,
},
"verify on master node": {
connErr: nil,
checkErr: nil,
state: state.ActivatingNodes,
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
var out bytes.Buffer
verifier := verifier{
newConn: stubNewConnFunc(tc.connErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{
state: tc.state,
checkErr: tc.checkErr,
}),
}
pcrs := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
err := verify(ctx, &out, "", []atls.Validator{gcp.NewValidator(pcrs)}, verifier)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Contains(out.String(), "OK")
}
})
}
}
func stubNewConnFunc(errStub error) func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
return func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
return &stubClientConn{}, errStub
}
}
type stubClientConn struct{}
func (c *stubClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
return nil
}
func (c *stubClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, nil
}
func (c *stubClientConn) Close() error {
return nil
}
func stubNewClientFunc(stubClient pubproto.APIClient) func(cc grpc.ClientConnInterface) pubproto.APIClient {
return func(cc grpc.ClientConnInterface) pubproto.APIClient {
return stubClient
}
}
type stubPeerStatusClient struct {
state state.State
checkErr error
pubproto.APIClient
}
func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetStateRequest, opts ...grpc.CallOption) (*pubproto.GetStateResponse, error) {
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
return resp, c.checkErr
}
func TestPrepareValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
pcrs := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
err := prepareValidator(cmd, pcrs)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
if tc.clusterID != "" {
assert.Len(pcrs[uint32(vtpm.PCRIndexClusterID)], 32)
} else {
assert.Nil(pcrs[uint32(vtpm.PCRIndexClusterID)])
}
if tc.ownerID != "" {
assert.Len(pcrs[uint32(vtpm.PCRIndexOwnerID)], 32)
} else {
assert.Nil(pcrs[uint32(vtpm.PCRIndexOwnerID)])
}
}
})
}
}
func TestAddOrSkipPcr(t *testing.T) {
emptyMap := map[uint32][]byte{}
defaultMap := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
testCases := map[string]struct {
pcrMap map[uint32][]byte
pcrIndex uint32
encoded string
expectedEntries int
errExpected bool
}{
"empty input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: "",
expectedEntries: 0,
errExpected: false,
},
"empty input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: "",
expectedEntries: len(defaultMap),
errExpected: false,
},
"correct input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
expectedEntries: 1,
errExpected: false,
},
"correct input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
expectedEntries: len(defaultMap) + 1,
errExpected: false,
},
"unencoded input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: "Constellation",
expectedEntries: 0,
errExpected: true,
},
"unencoded input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: "Constellation",
expectedEntries: len(defaultMap),
errExpected: true,
},
"empty input at occupied index": {
pcrMap: defaultMap,
pcrIndex: 0,
encoded: "",
expectedEntries: len(defaultMap) - 1,
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
res := make(map[uint32][]byte)
for k, v := range tc.pcrMap {
res[k] = v
}
err := addOrSkipPCR(res, tc.pcrIndex, tc.encoded)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
assert.Len(res, tc.expectedEntries)
for _, v := range res {
assert.Len(v, 32)
}
})
}
}
func TestVerifyCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "192.0.2.1",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"192.0.2.1"},
toComplete: "443",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"third arg": {
args: []string{"192.0.2.1", "443"},
toComplete: "./file",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := verifyCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

19
cli/cmd/version.go Normal file
View File

@ -0,0 +1,19 @@
package cmd
import (
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/cobra"
)
func newVersionCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "version",
Short: "Display version of this CLI",
Long: `Display version of this CLI`,
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
cmd.Printf("CLI Version: v%s \n", config.Version)
},
}
return cmd
}

25
cli/cmd/version_test.go Normal file
View File

@ -0,0 +1,25 @@
package cmd
import (
"bytes"
"io"
"testing"
"github.com/edgelesssys/constellation/internal/config"
"github.com/stretchr/testify/assert"
)
func TestVersionCmd(t *testing.T) {
assert := assert.New(t)
cmd := newVersionCmd()
b := bytes.NewBufferString("")
cmd.SetOut(b)
err := cmd.Execute()
assert.NoError(err)
s, err := io.ReadAll(b)
assert.NoError(err)
assert.Contains(string(s), config.Version)
}

5
cli/cmd/vpnconfigurer.go Normal file
View File

@ -0,0 +1,5 @@
package cmd
type vpnConfigurer interface {
Configure(clientVpnIp string, coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string) error
}

View File

@ -0,0 +1,17 @@
package cmd
type stubVPNConfigurer struct {
configured bool
configureErr error
}
func (c *stubVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error {
c.configured = true
return c.configureErr
}
type dummyVPNConfigurer struct{}
func (c *dummyVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error {
panic("dummy doesn't implement this function")
}

42
cli/ec2/client/api.go Normal file
View File

@ -0,0 +1,42 @@
package client
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/ec2"
)
// api collects used functions of AWS' ec2.Client as interfaces to enable testing.
type api interface {
ec2.DescribeInstancesAPIClient
// Instances
RunInstances(ctx context.Context,
params *ec2.RunInstancesInput,
optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
TerminateInstances(ctx context.Context,
params *ec2.TerminateInstancesInput,
optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
CreateTags(ctx context.Context,
params *ec2.CreateTagsInput,
optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
// SecurityGroup
CreateSecurityGroup(ctx context.Context,
params *ec2.CreateSecurityGroupInput,
optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error)
DeleteSecurityGroup(ctx context.Context,
params *ec2.DeleteSecurityGroupInput,
optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error)
AuthorizeSecurityGroupIngress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupIngressInput,
optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error)
AuthorizeSecurityGroupEgress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupEgressInput,
optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupEgressOutput, error)
}

137
cli/ec2/client/api_test.go Normal file
View File

@ -0,0 +1,137 @@
package client
import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/smithy-go"
)
// stubAPI is a stub ec2 api for testing.
type stubAPI struct {
instances []types.Instance
securityGroup types.SecurityGroup
describeInstancesErr error
runInstancesErr error
runInstancesDryRunErr *error
terminateInstancesErr error
terminateInstancesDryRunErr *error
createTagsErr error
createSecurityGroupErr error
createSecurityGroupDryRunErr *error
deleteSecurityGroupErr error
deleteSecurityGroupDryRunErr *error
authorizeSecurityGroupIngressErr error
authorizeSecurityGroupIngressDryRunErr *error
authorizeSecurityGroupEgressErr error
authorizeSecurityGroupEgressDryRunErr *error
}
func (a stubAPI) DescribeInstances(ctx context.Context,
params *ec2.DescribeInstancesInput,
optFns ...func(*ec2.Options),
) (*ec2.DescribeInstancesOutput, error) {
return &ec2.DescribeInstancesOutput{
Reservations: []types.Reservation{
{Instances: a.instances},
},
}, a.describeInstancesErr
}
func (a stubAPI) RunInstances(ctx context.Context,
params *ec2.RunInstancesInput,
optFns ...func(*ec2.Options),
) (*ec2.RunInstancesOutput, error) {
if err := getDryRunErr(params.DryRun, a.runInstancesDryRunErr); err != nil {
return nil, err
}
return &ec2.RunInstancesOutput{Instances: a.instances}, a.runInstancesErr
}
func (a stubAPI) CreateTags(ctx context.Context,
params *ec2.CreateTagsInput,
optFns ...func(*ec2.Options),
) (*ec2.CreateTagsOutput, error) {
return nil, a.createTagsErr
}
func (a stubAPI) TerminateInstances(ctx context.Context,
params *ec2.TerminateInstancesInput,
optFns ...func(*ec2.Options),
) (*ec2.TerminateInstancesOutput, error) {
if err := getDryRunErr(params.DryRun, a.terminateInstancesDryRunErr); err != nil {
return nil, err
}
return nil, a.terminateInstancesErr
}
func (a stubAPI) CreateSecurityGroup(ctx context.Context,
params *ec2.CreateSecurityGroupInput,
optFns ...func(*ec2.Options),
) (*ec2.CreateSecurityGroupOutput, error) {
if err := getDryRunErr(params.DryRun, a.createSecurityGroupDryRunErr); err != nil {
return nil, err
}
return &ec2.CreateSecurityGroupOutput{
GroupId: a.securityGroup.GroupId,
}, a.createSecurityGroupErr
}
func (a stubAPI) DeleteSecurityGroup(ctx context.Context,
params *ec2.DeleteSecurityGroupInput,
optFns ...func(*ec2.Options),
) (*ec2.DeleteSecurityGroupOutput, error) {
if err := getDryRunErr(params.DryRun, a.deleteSecurityGroupDryRunErr); err != nil {
return nil, err
}
return nil, a.deleteSecurityGroupErr
}
func (a stubAPI) AuthorizeSecurityGroupIngress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupIngressInput,
optFns ...func(*ec2.Options),
) (*ec2.AuthorizeSecurityGroupIngressOutput, error) {
if err := getDryRunErr(params.DryRun, a.authorizeSecurityGroupIngressDryRunErr); err != nil {
return nil, err
}
return nil, a.authorizeSecurityGroupIngressErr
}
func (a stubAPI) AuthorizeSecurityGroupEgress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupEgressInput,
optFns ...func(*ec2.Options),
) (*ec2.AuthorizeSecurityGroupEgressOutput, error) {
if err := getDryRunErr(params.DryRun, a.authorizeSecurityGroupEgressDryRunErr); err != nil {
return nil, err
}
return nil, a.authorizeSecurityGroupEgressErr
}
func getDryRunErr(dryRun *bool, stubErr *error) error {
if dryRun == nil || !*dryRun {
return nil
}
if stubErr != nil {
return *stubErr
}
return &smithy.GenericAPIError{Code: "DryRunOperation"}
}
var stateRunning = types.InstanceState{
Code: aws.Int32(int32(16)),
Name: types.InstanceStateNameRunning,
}
var stateTerminated = types.InstanceState{
Code: aws.Int32(48),
Name: types.InstanceStateNameTerminated,
}

71
cli/ec2/client/client.go Normal file
View File

@ -0,0 +1,71 @@
package client
import (
"context"
"errors"
"time"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/internal/state"
)
// Client for the AWS EC2 API.
type Client struct {
api api
instances ec2.Instances
securityGroup string
timeout time.Duration
}
func newClient(api api) (*Client, error) {
return &Client{
api: api,
instances: make(map[string]ec2.Instance),
timeout: 2 * time.Minute,
}, nil
}
// NewFromDefault creates a Client from the default config.
func NewFromDefault(ctx context.Context) (*Client, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}
return newClient(awsec2.NewFromConfig(cfg))
}
// GetState returns the current configuration of the Constellation,
// which can be stored and used through later CLI commands.
func (c *Client) GetState() (state.ConstellationState, error) {
if len(c.instances) == 0 {
return state.ConstellationState{}, errors.New("client has no instances")
}
if c.securityGroup == "" {
return state.ConstellationState{}, errors.New("client has no security group")
}
return state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: c.instances,
EC2SecurityGroup: c.securityGroup,
}, nil
}
// SetState sets a Client to an existing configuration.
func (c *Client) SetState(stat state.ConstellationState) error {
if stat.CloudProvider != cloudprovider.AWS.String() {
return errors.New("state is not aws state")
}
if len(stat.EC2Instances) == 0 {
return errors.New("state has no instances")
}
if stat.EC2SecurityGroup == "" {
return errors.New("state has no security group")
}
c.instances = stat.EC2Instances
c.securityGroup = stat.EC2SecurityGroup
return nil
}

View File

@ -0,0 +1,120 @@
package client
import (
"testing"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestGetState(t *testing.T) {
testCases := map[string]struct {
client Client
wantState state.ConstellationState
wantErr bool
}{
"successful get": {
client: Client{
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
securityGroup: "sg",
},
wantState: state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
EC2SecurityGroup: "sg",
},
},
"client without security group": {
client: Client{
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
"client without instances": {
client: Client{
securityGroup: "sg",
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
stat, err := tc.client.GetState()
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantState, stat)
}
})
}
}
func TestSetState(t *testing.T) {
testCases := map[string]struct {
state state.ConstellationState
wantInstances ec2.Instances
wantSecurityGroup string
wantErr bool
}{
"successful set": {
state: state.ConstellationState{
CloudProvider: "AWS",
EC2SecurityGroup: "sg-test",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantInstances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
wantSecurityGroup: "sg-test",
},
"state without cloudprovider": {
state: state.ConstellationState{
EC2SecurityGroup: "sg-test",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
"state with incorrect cloudprovider": {
state: state.ConstellationState{
CloudProvider: "incorrect",
EC2SecurityGroup: "sg-test",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
"state without instances": {
state: state.ConstellationState{
CloudProvider: "AWS",
EC2SecurityGroup: "sg-test",
},
wantErr: true,
},
"state without security group": {
state: state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{}
err := client.SetState(tc.state)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantInstances, client.instances)
assert.Equal(tc.wantSecurityGroup, client.securityGroup)
}
})
}
}

199
cli/ec2/client/instances.go Normal file
View File

@ -0,0 +1,199 @@
package client
import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/ec2"
)
// CreateInstances creates the instances defined in input.
//
// An existing security group is needed to create instances.
func (c *Client) CreateInstances(ctx context.Context, input CreateInput) error {
if c.securityGroup == "" {
return errors.New("no security group set")
}
input.securityGroupIds = []string{c.securityGroup}
if err := c.createDryRun(ctx, input); err != nil {
return err
}
resp, err := c.api.RunInstances(ctx, input.AWS())
if err != nil {
return fmt.Errorf("failed to create instances: %w", err)
}
for _, instance := range resp.Instances {
id := instance.InstanceId
if id == nil {
return errors.New("instanceId is nil pointer")
}
c.instances[*id] = ec2.Instance{}
}
if err := c.waitStateRunning(ctx); err != nil {
return err
}
if err := c.tagInstances(ctx, input.Tags); err != nil {
return err
}
if err := c.getInstanceIPs(ctx); err != nil {
return err
}
return nil
}
// TerminateInstances terminates all instances of a Client.
func (c *Client) TerminateInstances(ctx context.Context) error {
if len(c.instances) == 0 {
return nil
}
input := &awsec2.TerminateInstancesInput{
InstanceIds: c.instances.IDs(),
}
if err := c.terminateDryRun(ctx, *input); err != nil {
return err
}
if _, err := c.api.TerminateInstances(ctx, input); err != nil {
return err
}
if err := c.waitStateTerminated(ctx); err != nil {
return err
}
c.instances = ec2.Instances{}
return nil
}
// waitStateRunning waits until all the client's instances reached the running state.
//
// A set of instances is also considered to be running if at least one of the
// instances' state is 'running' and the other instances have a nil state.
func (c *Client) waitStateRunning(ctx context.Context) error {
if len(c.instances) == 0 {
return errors.New("client has no instances")
}
describeInput := &awsec2.DescribeInstancesInput{
InstanceIds: c.instances.IDs(),
}
waiter := awsec2.NewInstanceRunningWaiter(c.api)
return waiter.Wait(ctx, describeInput, c.timeout)
}
// waitStateTerminated waits until all the client's instances reached the terminated state.
//
// A set of instances is also considered to be terminated if at least one of the
// instances' state is 'terminated' and the other instances have a nil state.
func (c *Client) waitStateTerminated(ctx context.Context) error {
if len(c.instances) == 0 {
return errors.New("client has no instances")
}
describeInput := &awsec2.DescribeInstancesInput{
InstanceIds: c.instances.IDs(),
}
waiter := awsec2.NewInstanceTerminatedWaiter(c.api)
return waiter.Wait(ctx, describeInput, c.timeout)
}
// tagInstances tags all instances of a client with a given set of tags.
func (c *Client) tagInstances(ctx context.Context, tags ec2.Tags) error {
if len(c.instances) == 0 {
return errors.New("client has no instances")
}
tagInput := &awsec2.CreateTagsInput{
Resources: c.instances.IDs(),
Tags: tags.AWS(),
}
if _, err := c.api.CreateTags(ctx, tagInput); err != nil {
return fmt.Errorf("failed to tag instances: %w", err)
}
return nil
}
// createDryRun checks if user has the privilege to create the instances
// which were defined in input.
func (c *Client) createDryRun(ctx context.Context, input CreateInput) error {
runInput := input.AWS()
runInput.DryRun = aws.Bool(true)
_, err := c.api.RunInstances(ctx, runInput)
return checkDryRunError(err)
}
// terminateDryRun checks if user has the privilege to terminate the instances
// which were defined in input.
func (c *Client) terminateDryRun(ctx context.Context, input awsec2.TerminateInstancesInput) error {
input.DryRun = aws.Bool(true)
_, err := c.api.TerminateInstances(ctx, &input)
return checkDryRunError(err)
}
// getInstanceIPs queries the private and public IP addresses
// and adds the information to each instance.
//
// The instances must be in 'running' state.
func (c *Client) getInstanceIPs(ctx context.Context) error {
describeInput := &awsec2.DescribeInstancesInput{
InstanceIds: c.instances.IDs(),
}
paginator := awsec2.NewDescribeInstancesPaginator(c.api, describeInput)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
return err
}
for _, reservation := range output.Reservations {
for _, instanceDescription := range reservation.Instances {
if instanceDescription.InstanceId == nil {
return errors.New("instanceId is nil pointer")
}
if instanceDescription.PublicIpAddress == nil {
return errors.New("publicIpAddress is nil pointer")
}
if instanceDescription.PrivateIpAddress == nil {
return errors.New("privateIpAddress is nil pointer")
}
instance, ok := c.instances[*instanceDescription.InstanceId]
if !ok {
return errors.New("got an instance description to an unknown instanceId")
}
instance.PublicIP = *instanceDescription.PublicIpAddress
instance.PrivateIP = *instanceDescription.PrivateIpAddress
c.instances[*instanceDescription.InstanceId] = instance
}
}
}
return nil
}
// CreateInput defines the propertis of the instances to create.
type CreateInput struct {
ImageId string
InstanceType string
Count int
Tags ec2.Tags
securityGroupIds []string
}
// AWS creates a AWS ec2.RunInstancesInput from an CreateInput.
func (ci *CreateInput) AWS() *awsec2.RunInstancesInput {
return &awsec2.RunInstancesInput{
ImageId: aws.String(ci.ImageId),
InstanceType: ec2.InstanceTypes[ci.InstanceType],
MaxCount: aws.Int32(int32(ci.Count)),
MinCount: aws.Int32(int32(ci.Count)),
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
SecurityGroupIds: ci.securityGroupIds,
}
}

View File

@ -0,0 +1,493 @@
package client
import (
"context"
"errors"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/stretchr/testify/assert"
)
func TestCreateInstances(t *testing.T) {
testInstances := []types.Instance{
{
InstanceId: aws.String("id-1"),
PublicIpAddress: aws.String("192.0.2.1"),
PrivateIpAddress: aws.String("192.0.2.2"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-2"),
PublicIpAddress: aws.String("192.0.2.3"),
PrivateIpAddress: aws.String("192.0.2.4"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-3"),
PublicIpAddress: aws.String("192.0.2.5"),
PrivateIpAddress: aws.String("192.0.2.6"),
State: &stateRunning,
},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
instances ec2.Instances
securityGroup string
errExpected bool
wantInstances ec2.Instances
}{
"create": {
api: stubAPI{instances: testInstances},
securityGroup: "sg-test",
wantInstances: ec2.Instances{
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
},
},
"client already has instances": {
api: stubAPI{instances: testInstances},
instances: ec2.Instances{"id-4": {}, "id-5": {}},
securityGroup: "sg-test",
wantInstances: ec2.Instances{
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
"id-4": {},
"id-5": {},
},
},
"client already has same instance id": {
api: stubAPI{instances: testInstances},
instances: ec2.Instances{"id-1": {}, "id-4": {}, "id-5": {}},
securityGroup: "sg-test",
errExpected: false,
wantInstances: ec2.Instances{
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
"id-4": {},
"id-5": {},
},
},
"client has no security group": {
api: stubAPI{},
errExpected: true,
},
"run API error": {
api: stubAPI{runInstancesErr: someErr},
securityGroup: "sg-test",
errExpected: true,
},
"runDryRun API error": {
api: stubAPI{runInstancesDryRunErr: &someErr},
securityGroup: "sg-test",
errExpected: true,
},
"runDryRun missing expected API error": {
api: stubAPI{runInstancesDryRunErr: &noErr},
securityGroup: "sg-test",
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
securityGroup: tc.securityGroup,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
input := CreateInput{
ImageId: "test-image",
InstanceType: "",
Count: 13,
}
err := client.CreateInstances(context.Background(), input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.ElementsMatch(tc.wantInstances.IDs(), client.instances.IDs())
assert.ElementsMatch(tc.wantInstances.PublicIPs(), client.instances.PublicIPs())
assert.ElementsMatch(tc.wantInstances.PrivateIPs(), client.instances.PrivateIPs())
}
})
}
}
func TestTerminateInstances(t *testing.T) {
testAWSInstances := []types.Instance{
{InstanceId: aws.String("id-1"), State: &stateTerminated},
{InstanceId: aws.String("id-2"), State: &stateTerminated},
{InstanceId: aws.String("id-3"), State: &stateTerminated},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
instances ec2.Instances
errExpected bool
}{
"client with instances": {
api: stubAPI{instances: testAWSInstances},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"client no instances set": {
api: stubAPI{},
},
"terminate API error": {
api: stubAPI{terminateInstancesErr: someErr},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"terminateDryRun API error": {
api: stubAPI{terminateInstancesDryRunErr: &someErr},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"terminateDryRun miss expected API error": {
api: stubAPI{terminateInstancesDryRunErr: &noErr},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.TerminateInstances(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Empty(client.instances)
}
})
}
}
func TestWaitStateRunning(t *testing.T) {
testCases := map[string]struct {
api api
instances ec2.Instances
errExpected bool
}{
"instances are running": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-2"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-3"),
State: &stateRunning,
},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance running, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateRunning,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance terminated, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"instances with different state": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-2"),
State: &stateRunning,
},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"all instances have nil state": {
api: stubAPI{instances: []types.Instance{
{InstanceId: aws.String("id-1")},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"client has no instances": {
api: &stubAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.waitStateRunning(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestWaitStateTerminated(t *testing.T) {
testCases := map[string]struct {
api api
instances ec2.Instances
errExpected bool
}{
"instances are terminated": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-2"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-3"),
State: &stateTerminated,
},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance terminated, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance running, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateRunning,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"instances with different state": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-2"),
State: &stateRunning,
},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"all instances have nil state": {
api: stubAPI{instances: []types.Instance{
{InstanceId: aws.String("id-1")},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"client has no instances": {
api: &stubAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.waitStateTerminated(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTagInstances(t *testing.T) {
testTags := ec2.Tags{
{Key: "Name", Value: "Test"},
{Key: "Foo", Value: "Bar"},
}
testCases := map[string]struct {
api stubAPI
instances ec2.Instances
errExpected bool
}{
"tag": {
api: stubAPI{},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"client without instances": {
api: stubAPI{createTagsErr: errors.New("failed")},
errExpected: true,
},
"tag API error": {
api: stubAPI{createTagsErr: errors.New("failed")},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.tagInstances(context.Background(), testTags)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestEc2RunInstanceInput(t *testing.T) {
assert := assert.New(t)
testCases := []struct {
in CreateInput
outExpected awsec2.RunInstancesInput
}{
{
in: CreateInput{
ImageId: "test-image",
InstanceType: "4xlarge",
Count: 13,
securityGroupIds: []string{"test-sec-group"},
},
outExpected: awsec2.RunInstancesInput{
ImageId: aws.String("test-image"),
InstanceType: types.InstanceTypeC5a4xlarge,
MinCount: aws.Int32(int32(13)),
MaxCount: aws.Int32(int32(13)),
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
SecurityGroupIds: []string{"test-sec-group"},
},
},
{
in: CreateInput{
ImageId: "test-image-2",
InstanceType: "12xlarge",
Count: 2,
securityGroupIds: []string{"test-sec-group-2"},
},
outExpected: awsec2.RunInstancesInput{
ImageId: aws.String("test-image-2"),
InstanceType: types.InstanceTypeC5a12xlarge,
MinCount: aws.Int32(int32(2)),
MaxCount: aws.Int32(int32(2)),
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
SecurityGroupIds: []string{"test-sec-group-2"},
},
},
}
for _, tc := range testCases {
out := tc.in.AWS()
assert.Equal(tc.outExpected, *out)
}
}

View File

@ -0,0 +1,136 @@
package client
import (
"context"
"errors"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/google/uuid"
)
// CreateSecurityGroup creates a AWS security group with the handed properties.
func (c *Client) CreateSecurityGroup(ctx context.Context, input SecurityGroupInput) error {
if c.securityGroup != "" {
return errors.New("client already has a security group")
}
id := uuid.New()
createInput := &awsec2.CreateSecurityGroupInput{
Description: aws.String("Security group of Constellation. This group was generated through the Constellation CLI."),
GroupName: aws.String("Constellation-" + id.String()),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.CreateSecurityGroup(ctx, createInput)
if err := checkDryRunError(err); err != nil {
return err
}
createInput.DryRun = aws.Bool(false)
// Create
out, err := c.api.CreateSecurityGroup(ctx, createInput)
if err != nil {
return err
}
if out.GroupId == nil {
return errors.New("security group creation didn't return an id")
}
c.securityGroup = *out.GroupId
// Authorize.
return c.authorizeSecurityGroup(ctx, input)
}
// DeleteSecurityGroup deletes the security group of the client.
// The deletion is idempotent, no error is returned if the client has
// an empty securityGroupID.
func (c *Client) DeleteSecurityGroup(ctx context.Context) error {
if c.securityGroup == "" {
return nil
}
input := &awsec2.DeleteSecurityGroupInput{
GroupId: aws.String(c.securityGroup),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.DeleteSecurityGroup(ctx, input)
if err := checkDryRunError(err); err != nil {
return err
}
input.DryRun = aws.Bool(false)
// Delete
if _, err := c.api.DeleteSecurityGroup(ctx, input); err != nil {
return err
}
c.securityGroup = ""
return nil
}
func (c *Client) authorizeSecurityGroup(ctx context.Context, input SecurityGroupInput) error {
if c.securityGroup == "" {
return errors.New("client hasn't got a security group id")
}
if err := c.authorizeSecurityGroupIngress(ctx, input.Inbound); err != nil {
return err
}
return c.authorizeSecurityGroupEgress(ctx, input.Outbound)
}
func (c *Client) authorizeSecurityGroupIngress(ctx context.Context, perms cloudtypes.Firewall) error {
if len(perms) == 0 {
return nil
}
authInput := &awsec2.AuthorizeSecurityGroupIngressInput{
GroupId: aws.String(c.securityGroup),
IpPermissions: perms.AWS(),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.AuthorizeSecurityGroupIngress(ctx, authInput)
if err := checkDryRunError(err); err != nil {
return err
}
authInput.DryRun = aws.Bool(false)
// Authorize
_, err = c.api.AuthorizeSecurityGroupIngress(ctx, authInput)
return err
}
func (c *Client) authorizeSecurityGroupEgress(ctx context.Context, perms cloudtypes.Firewall) error {
if len(perms) == 0 {
return nil
}
authInput := &awsec2.AuthorizeSecurityGroupEgressInput{
GroupId: aws.String(c.securityGroup),
IpPermissions: perms.AWS(),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.AuthorizeSecurityGroupEgress(ctx, authInput)
if err := checkDryRunError(err); err != nil {
return err
}
authInput.DryRun = aws.Bool(false)
// Authorize
_, err = c.api.AuthorizeSecurityGroupEgress(ctx, authInput)
return err
}
type SecurityGroupInput struct {
Inbound cloudtypes.Firewall
Outbound cloudtypes.Firewall
}

View File

@ -0,0 +1,269 @@
package client
import (
"context"
"errors"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateSecurityGroup(t *testing.T) {
testInput := SecurityGroupInput{
Inbound: cloudtypes.Firewall{
{
Description: "perm1",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 22,
},
{
Description: "perm2",
Protocol: "UDP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
},
Outbound: cloudtypes.Firewall{
{
Description: "perm3",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 4040,
},
},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
securityGroup string
input SecurityGroupInput
errExpected bool
securityGroupExpected string
}{
"create security group": {
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")}},
input: testInput,
securityGroupExpected: "sg-test",
},
"create security group without permissions": {
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")}},
input: SecurityGroupInput{},
securityGroupExpected: "sg-test",
},
"client already has security group": {
api: stubAPI{},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"create returns nil security group ID": {
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: nil}},
input: testInput,
errExpected: true,
},
"create API error": {
api: stubAPI{createSecurityGroupErr: someErr},
input: testInput,
errExpected: true,
},
"create DryRun API error": {
api: stubAPI{createSecurityGroupDryRunErr: &someErr},
input: testInput,
errExpected: true,
},
"create DryRun missing expected error": {
api: stubAPI{createSecurityGroupDryRunErr: &noErr},
input: testInput,
errExpected: true,
},
"authorize error": {
api: stubAPI{
securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")},
authorizeSecurityGroupIngressErr: someErr,
},
input: testInput,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client, err := newClient(tc.api)
require.NoError(err)
client.securityGroup = tc.securityGroup
err = client.CreateSecurityGroup(context.Background(), tc.input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.securityGroupExpected, client.securityGroup)
}
})
}
}
func TestDeleteSecurityGroup(t *testing.T) {
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
securityGroup string
errExpected bool
}{
"delete security group": {
api: stubAPI{},
securityGroup: "sg-test",
},
"client without security group": {
api: stubAPI{},
},
"delete API error": {
api: stubAPI{deleteSecurityGroupErr: someErr},
securityGroup: "sg-test",
errExpected: true,
},
"delete DryRun API error": {
api: stubAPI{deleteSecurityGroupDryRunErr: &someErr},
securityGroup: "sg-test",
errExpected: true,
},
"delete DryRun missing expected error": {
api: stubAPI{deleteSecurityGroupDryRunErr: &noErr},
securityGroup: "sg-test",
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client, err := newClient(tc.api)
require.NoError(err)
client.securityGroup = tc.securityGroup
err = client.DeleteSecurityGroup(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Empty(client.securityGroup)
}
})
}
}
func TestAuthorizeSecurityGroup(t *testing.T) {
testInput := SecurityGroupInput{
Inbound: cloudtypes.Firewall{
{
Description: "perm1",
Protocol: "TCP",
IPRange: " 192.0.2.0/24",
Port: 22,
},
{
Description: "perm2",
Protocol: "UDP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
},
Outbound: cloudtypes.Firewall{
{
Description: "perm3",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 4040,
},
},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
securityGroup string
input SecurityGroupInput
errExpected bool
}{
"authorize": {
api: stubAPI{},
securityGroup: "sg-test",
input: testInput,
errExpected: false,
},
"client without security group": {
api: stubAPI{},
input: testInput,
errExpected: true,
},
"authorizeIngress API error": {
api: stubAPI{authorizeSecurityGroupIngressErr: someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeIngress DryRun API error": {
api: stubAPI{authorizeSecurityGroupIngressDryRunErr: &someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeIngress DryRun missing expected error": {
api: stubAPI{authorizeSecurityGroupIngressDryRunErr: &noErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeEgress API error": {
api: stubAPI{authorizeSecurityGroupEgressErr: someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeEgress DryRun API error": {
api: stubAPI{authorizeSecurityGroupEgressDryRunErr: &someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeEgress DryRun missing expected error": {
api: stubAPI{authorizeSecurityGroupEgressDryRunErr: &noErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client, err := newClient(tc.api)
require.NoError(err)
client.securityGroup = tc.securityGroup
err = client.authorizeSecurityGroup(context.Background(), tc.input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

21
cli/ec2/client/util.go Normal file
View File

@ -0,0 +1,21 @@
package client
import (
"errors"
"github.com/aws/smithy-go"
)
// checkDryRunError error checks if an error is a DryRun error.
// If the error is nil, an error is returned, since a DryRun error
// is the expected result of a DryRun operation.
func checkDryRunError(err error) error {
var apiErr smithy.APIError
if errors.As(err, &apiErr) && apiErr.ErrorCode() == "DryRunOperation" {
return nil
}
if err != nil {
return err
}
return errors.New("expected APIError: DryRunOperation, but got no error at all")
}

View File

@ -0,0 +1,22 @@
package client
import (
"errors"
"testing"
"github.com/aws/smithy-go"
"github.com/stretchr/testify/assert"
)
func TestCheckDryRunError(t *testing.T) {
assert := assert.New(t)
someErr := errors.New("failed")
assert.ErrorIs(checkDryRunError(someErr), someErr)
dryRunErr := smithy.GenericAPIError{Code: "DryRunOperation"}
assert.NoError(checkDryRunError(&dryRunErr))
var nilErr error
assert.Error(checkDryRunError(nilErr))
}

58
cli/ec2/instances.go Normal file
View File

@ -0,0 +1,58 @@
package ec2
import "errors"
// Instance is an ec2 instance.
type Instance struct {
PublicIP string
PrivateIP string
}
// Instances is a map of ec2 Instances. The ID of an instance is used as key.
type Instances map[string]Instance
// IDs returns the IDs of all instances of the Constellation.
func (i Instances) IDs() []string {
var ids []string
for id := range i {
ids = append(ids, id)
}
return ids
}
// PublicIPs returns the public IPs of all the instances of the Constellation.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}
// GetOne return anyone instance out of the instances and its ID.
func (i Instances) GetOne() (string, Instance, error) {
for id, instance := range i {
return id, instance, nil
}
return "", Instance{}, errors.New("map is empty")
}
// GetOthers returns all instances but the one with the handed ID.
func (i Instances) GetOthers(id string) Instances {
others := make(Instances)
for key, instance := range i {
if key != id {
others[key] = instance
}
}
return others
}

71
cli/ec2/instances_test.go Normal file
View File

@ -0,0 +1,71 @@
package ec2
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIDs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
assert.ElementsMatch(expectedIDs, testState.IDs())
}
func TestPublicIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
}
func TestPrivateIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
}
func TestGetOne(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
id, instance, err := testState.GetOne()
assert.NoError(err)
assert.Contains(testState, id)
assert.Equal(testState[id], instance)
}
func TestGetOthers(t *testing.T) {
assert := assert.New(t)
testCases := testInstances().IDs()
for _, id := range testCases {
others := testInstances().GetOthers(id)
assert.NotContains(others, id)
expectedInstances := testInstances()
delete(expectedInstances, id)
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
}
}
func testInstances() Instances {
return Instances{
"id-9": {
PublicIP: "192.0.2.1",
PrivateIP: "192.0.2.2",
},
"id-10": {
PublicIP: "192.0.2.3",
PrivateIP: "192.0.2.4",
},
"id-11": {
PublicIP: "192.0.2.5",
PrivateIP: "192.0.2.6",
},
"id-12": {
PublicIP: "192.0.2.7",
PrivateIP: "192.0.2.8",
},
}
}

18
cli/ec2/instancetypes.go Normal file
View File

@ -0,0 +1,18 @@
package ec2
import "github.com/aws/aws-sdk-go-v2/service/ec2/types"
// InstanceTypes defines possible values for the SIZE positional argument.
var InstanceTypes = map[string]types.InstanceType{
"4xlarge": types.InstanceTypeC5a4xlarge,
"8xlarge": types.InstanceTypeC5a8xlarge,
"12xlarge": types.InstanceTypeC5a12xlarge,
"16xlarge": types.InstanceTypeC5a16xlarge,
"24xlarge": types.InstanceTypeC5a24xlarge,
// shorthands
"4xl": types.InstanceTypeC5a4xlarge,
"8xl": types.InstanceTypeC5a8xlarge,
"12xl": types.InstanceTypeC5a12xlarge,
"16xl": types.InstanceTypeC5a16xlarge,
"24xl": types.InstanceTypeC5a24xlarge,
}

Some files were not shown because too many files have changed in this diff Show More