From c603b547dbf752cb21051080006e5ff4c144cb4b Mon Sep 17 00:00:00 2001 From: Otto Bittner Date: Mon, 9 Oct 2023 15:18:12 +0200 Subject: [PATCH] s3proxy: add allow-multipart flag (#2420) This flag allows users to control whether multipart uploads are blocked or allowed. At the moment s3proxy doesn't encrypt multipart uploads, so there is a potential for inadvertent data leakage. With this flag the default behavior is changed to a more secure one: block multipart uploads. The previous behavior can be enabled by setting allow-multipart. --- s3proxy/cmd/main.go | 31 ++-- s3proxy/deploy/deployment-s3proxy.yaml | 2 +- s3proxy/internal/router/BUILD.bazel | 1 + s3proxy/internal/router/handler.go | 191 +++++++++++++++++++++++++ s3proxy/internal/router/router.go | 179 ++++++----------------- 5 files changed, 253 insertions(+), 151 deletions(-) create mode 100644 s3proxy/internal/router/handler.go diff --git a/s3proxy/cmd/main.go b/s3proxy/cmd/main.go index 5b471d8af..12a72622d 100644 --- a/s3proxy/cmd/main.go +++ b/s3proxy/cmd/main.go @@ -49,6 +49,10 @@ func main() { logger := logger.New(logger.JSONLog, logger.VerbosityFromInt(flags.logLevel)) + if flags.forwardMultipartReqs { + logger.Warnf("configured to forward multipart uploads, this may leak data to AWS") + } + if err := runServer(flags, logger); err != nil { panic(err) } @@ -57,7 +61,7 @@ func main() { func runServer(flags cmdFlags, log *logger.Logger) error { log.With(zap.String("ip", flags.ip), zap.Int("port", defaultPort), zap.String("region", flags.region)).Infof("listening") - router, err := router.New(flags.region, flags.kmsEndpoint, log) + router, err := router.New(flags.region, flags.kmsEndpoint, flags.forwardMultipartReqs, log) if err != nil { return fmt.Errorf("creating router: %w", err) } @@ -96,6 +100,7 @@ func parseFlags() (cmdFlags, error) { region := flag.String("region", defaultRegion, "AWS region in which target bucket is located") certLocation := flag.String("cert", defaultCertLocation, 
"location of TLS certificate") kmsEndpoint := flag.String("kms", "key-service.kube-system:9000", "endpoint of the KMS service to get key encryption keys from") + forwardMultipartReqs := flag.Bool("allow-multipart", false, "forward multipart requests to the target bucket; beware: this may store unencrypted data on AWS. See the documentation for more information") level := flag.Int("level", defaultLogLevel, "log level") flag.Parse() @@ -112,21 +117,23 @@ func parseFlags() (cmdFlags, error) { // } return cmdFlags{ - noTLS: *noTLS, - ip: netIP.String(), - region: *region, - certLocation: *certLocation, - kmsEndpoint: *kmsEndpoint, - logLevel: *level, + noTLS: *noTLS, + ip: netIP.String(), + region: *region, + certLocation: *certLocation, + kmsEndpoint: *kmsEndpoint, + forwardMultipartReqs: *forwardMultipartReqs, + logLevel: *level, }, nil } type cmdFlags struct { - noTLS bool - ip string - region string - certLocation string - kmsEndpoint string + noTLS bool + ip string + region string + certLocation string + kmsEndpoint string + forwardMultipartReqs bool // TODO(derpsteb): enable once we are on go 1.21. 
// logLevel slog.Level logLevel int diff --git a/s3proxy/deploy/deployment-s3proxy.yaml b/s3proxy/deploy/deployment-s3proxy.yaml index 441770eff..c3a5f38d9 100644 --- a/s3proxy/deploy/deployment-s3proxy.yaml +++ b/s3proxy/deploy/deployment-s3proxy.yaml @@ -47,7 +47,7 @@ spec: - name: regcred containers: - name: s3proxy - image: ghcr.io/edgelesssys/constellation/s3proxy@sha256:2394a804e8b5ff487a55199dd83138885322a4de8e71ac7ce67b79d4ffc842b2 + image: ghcr.io/edgelesssys/constellation/s3proxy:v2.12.0-pre.0.20231009141917-226cb427d0b1 args: - "--level=-1" ports: diff --git a/s3proxy/internal/router/BUILD.bazel b/s3proxy/internal/router/BUILD.bazel index c60568bce..8768bc493 100644 --- a/s3proxy/internal/router/BUILD.bazel +++ b/s3proxy/internal/router/BUILD.bazel @@ -4,6 +4,7 @@ load("//bazel/go:go_test.bzl", "go_test") go_library( name = "router", srcs = [ + "handler.go", "object.go", "router.go", ], diff --git a/s3proxy/internal/router/handler.go b/s3proxy/internal/router/handler.go new file mode 100644 index 000000000..75cc0fbe2 --- /dev/null +++ b/s3proxy/internal/router/handler.go @@ -0,0 +1,191 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package router + +import ( + "encoding/xml" + "fmt" + "io" + "net/http" + + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/s3proxy/internal/s3" + "go.uber.org/zap" +) + +func handleGetObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting") + if req.Header.Get("Range") != "" { + log.Errorf("GetObject Range header unsupported") + http.Error(w, "s3proxy currently does not support Range headers", http.StatusNotImplemented) + return + } + + obj := object{ + client: client, + key: key, + bucket: bucket, + query: 
req.URL.Query(), + sseCustomerAlgorithm: req.Header.Get("x-amz-server-side-encryption-customer-algorithm"), + sseCustomerKey: req.Header.Get("x-amz-server-side-encryption-customer-key"), + sseCustomerKeyMD5: req.Header.Get("x-amz-server-side-encryption-customer-key-MD5"), + log: log, + } + get(obj.get)(w, req) + } +} + +func handlePutObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting") + body, err := io.ReadAll(req.Body) + if err != nil { + log.With(zap.Error(err)).Errorf("PutObject") + http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError) + return + } + + clientDigest := req.Header.Get("x-amz-content-sha256") + serverDigest := sha256sum(body) + + // There may be a client that wants to test that incorrect content digests result in API errors. + // For encrypting the body we have to recalculate the content digest. + // If the client intentionally sends a mismatching content digest, we would take the client request, rewrap it, + // calculate the correct digest for the new body and NOT get an error. + // Thus we have to check incoming requets for matching content digests. + // UNSIGNED-PAYLOAD can be used to disabled payload signing. In that case we don't check the content digest. + if clientDigest != "" && clientDigest != "UNSIGNED-PAYLOAD" && clientDigest != serverDigest { + log.Debugf("PutObject", "error", "x-amz-content-sha256 mismatch") + // The S3 API responds with an XML formatted error message. 
+ mismatchErr := NewContentSHA256MismatchError(clientDigest, serverDigest) + marshalled, err := xml.Marshal(mismatchErr) + if err != nil { + log.With(zap.Error(err)).Errorf("PutObject") + http.Error(w, fmt.Sprintf("marshalling error: %s", err.Error()), http.StatusInternalServerError) + return + } + + http.Error(w, string(marshalled), http.StatusBadRequest) + return + } + + metadata := getMetadataHeaders(req.Header) + + raw := req.Header.Get("x-amz-object-lock-retain-until-date") + retentionTime, err := parseRetentionTime(raw) + if err != nil { + log.With(zap.String("data", raw), zap.Error(err)).Errorf("parsing lock retention time") + http.Error(w, fmt.Sprintf("parsing x-amz-object-lock-retain-until-date: %s", err.Error()), http.StatusInternalServerError) + return + } + + err = validateContentMD5(req.Header.Get("content-md5"), body) + if err != nil { + log.With(zap.Error(err)).Errorf("validating content md5") + http.Error(w, fmt.Sprintf("validating content md5: %s", err.Error()), http.StatusBadRequest) + return + } + + obj := object{ + client: client, + key: key, + bucket: bucket, + data: body, + query: req.URL.Query(), + tags: req.Header.Get("x-amz-tagging"), + contentType: req.Header.Get("Content-Type"), + metadata: metadata, + objectLockLegalHoldStatus: req.Header.Get("x-amz-object-lock-legal-hold"), + objectLockMode: req.Header.Get("x-amz-object-lock-mode"), + objectLockRetainUntilDate: retentionTime, + sseCustomerAlgorithm: req.Header.Get("x-amz-server-side-encryption-customer-algorithm"), + sseCustomerKey: req.Header.Get("x-amz-server-side-encryption-customer-key"), + sseCustomerKeyMD5: req.Header.Get("x-amz-server-side-encryption-customer-key-MD5"), + log: log, + } + + put(obj.put)(w, req) + } +} + +func handleForwards(log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("forwarding") + + newReq := 
repackage(req) + + httpClient := http.DefaultClient + resp, err := httpClient.Do(&newReq) + if err != nil { + log.With(zap.Error(err)).Errorf("do request") + http.Error(w, fmt.Sprintf("do request: %s", err.Error()), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + for key := range resp.Header { + w.Header().Set(key, resp.Header.Get(key)) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + log.With(zap.Error(err)).Errorf("ReadAll") + http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError) + return + } + w.WriteHeader(resp.StatusCode) + if body == nil { + return + } + + if _, err := w.Write(body); err != nil { + log.With(zap.Error(err)).Errorf("Write") + http.Error(w, fmt.Sprintf("writing body: %s", err.Error()), http.StatusInternalServerError) + return + } + } +} + +// handleCreateMultipartUpload logs the request and blocks with an error message. +func handleCreateMultipartUpload(log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting CreateMultipartUpload") + + log.Errorf("Blocking CreateMultipartUpload request") + http.Error(w, "s3proxy is configured to block CreateMultipartUpload requests", http.StatusNotImplemented) + } +} + +// handleUploadPart logs the request and blocks with an error message. +func handleUploadPart(log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting UploadPart") + + log.Errorf("Blocking UploadPart request") + http.Error(w, "s3proxy is configured to block UploadPart requests", http.StatusNotImplemented) + } +} + +// handleCompleteMultipartUpload logs the request and blocks with an error message. 
+func handleCompleteMultipartUpload(log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting CompleteMultipartUpload") + + log.Errorf("Blocking CompleteMultipartUpload request") + http.Error(w, "s3proxy is configured to block CompleteMultipartUpload requests", http.StatusNotImplemented) + } +} + +// handleAbortMultipartUpload logs the request and blocks with an error message. +func handleAbortMultipartUpload(log *logger.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting AbortMultipartUpload") + + log.Errorf("Blocking AbortMultipartUpload request") + http.Error(w, "s3proxy is configured to block AbortMultipartUpload requests", http.StatusNotImplemented) + } +} diff --git a/s3proxy/internal/router/router.go b/s3proxy/internal/router/router.go index bd9b84427..1680e44de 100644 --- a/s3proxy/internal/router/router.go +++ b/s3proxy/internal/router/router.go @@ -27,7 +27,6 @@ import ( "encoding/base64" "encoding/xml" "fmt" - "io" "net/http" "net/url" "regexp" @@ -37,7 +36,6 @@ import ( "github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/s3proxy/internal/kms" "github.com/edgelesssys/constellation/v2/s3proxy/internal/s3" - "go.uber.org/zap" ) const ( @@ -55,11 +53,15 @@ var ( type Router struct { region string kek [32]byte - log *logger.Logger + // forwardMultipartReqs controls whether we forward the following requests: CreateMultipartUpload, UploadPart, CompleteMultipartUpload, AbortMultipartUpload. + // s3proxy does not implement those yet. + // Setting forwardMultipartReqs to true will forward those requests to the S3 API, otherwise we block them (secure defaults). 
+ forwardMultipartReqs bool + log *logger.Logger } // New creates a new Router. -func New(region, endpoint string, log *logger.Logger) (Router, error) { +func New(region, endpoint string, forwardMultipartReqs bool, log *logger.Logger) (Router, error) { kms := kms.New(log, endpoint) // Get the key encryption key that encrypts all DEKs. @@ -73,7 +75,7 @@ func New(region, endpoint string, log *logger.Logger) (Router, error) { return Router{}, fmt.Errorf("converting KEK to byte array: %w", err) } - return Router{region: region, kek: kekArray, log: log}, nil + return Router{region: region, kek: kekArray, forwardMultipartReqs: forwardMultipartReqs, log: log}, nil } // Serve implements the routing logic for the s3 proxy. @@ -103,6 +105,7 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) { } var h http.Handler + switch { // intercept GetObject. case matchingPath && req.Method == "GET" && !isUnwantedGetEndpoint(req.URL.Query()): @@ -110,6 +113,14 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) { // intercept PutObject. case matchingPath && req.Method == "PUT" && !isUnwantedPutEndpoint(req.Header, req.URL.Query()): h = handlePutObject(client, key, bucket, r.log) + case !r.forwardMultipartReqs && matchingPath && isUploadPart(req.Method, req.URL.Query()): + h = handleUploadPart(r.log) + case !r.forwardMultipartReqs && matchingPath && isCreateMultipartUpload(req.Method, req.URL.Query()): + h = handleCreateMultipartUpload(r.log) + case !r.forwardMultipartReqs && matchingPath && isCompleteMultipartUpload(req.Method, req.URL.Query()): + h = handleCompleteMultipartUpload(r.log) + case !r.forwardMultipartReqs && matchingPath && isAbortMultipartUpload(req.Method, req.URL.Query()): + h = handleAbortMultipartUpload(r.log) // Forward all other requests. 
default: h = handleForwards(r.log) @@ -118,6 +129,31 @@ func (r Router) Serve(w http.ResponseWriter, req *http.Request) { h.ServeHTTP(w, req) } +func isAbortMultipartUpload(method string, query url.Values) bool { + _, uploadID := query["uploadId"] + + return method == "DELETE" && uploadID +} + +func isCompleteMultipartUpload(method string, query url.Values) bool { + _, multipart := query["uploadId"] + + return method == "POST" && multipart +} + +func isCreateMultipartUpload(method string, query url.Values) bool { + _, multipart := query["uploads"] + + return method == "POST" && multipart +} + +func isUploadPart(method string, query url.Values) bool { + _, partNumber := query["partNumber"] + _, uploadID := query["uploadId"] + + return method == "PUT" && partNumber && uploadID +} + // ContentSHA256MismatchError is a helper struct to create an XML formatted error message. // s3 clients might try to parse error messages, so we need to serve correctly formatted messages. type ContentSHA256MismatchError struct { @@ -138,139 +174,6 @@ func NewContentSHA256MismatchError(clientComputedContentSHA256, s3ComputedConten } } -func handleGetObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting") - if req.Header.Get("Range") != "" { - log.Errorf("GetObject Range header unsupported") - http.Error(w, "s3proxy currently does not support Range headers", http.StatusNotImplemented) - return - } - - obj := object{ - client: client, - key: key, - bucket: bucket, - query: req.URL.Query(), - sseCustomerAlgorithm: req.Header.Get("x-amz-server-side-encryption-customer-algorithm"), - sseCustomerKey: req.Header.Get("x-amz-server-side-encryption-customer-key"), - sseCustomerKeyMD5: req.Header.Get("x-amz-server-side-encryption-customer-key-MD5"), - log: log, - } - 
get(obj.get)(w, req) - } -} - -func handlePutObject(client *s3.Client, key string, bucket string, log *logger.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("intercepting") - body, err := io.ReadAll(req.Body) - if err != nil { - log.With(zap.Error(err)).Errorf("PutObject") - http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError) - return - } - - clientDigest := req.Header.Get("x-amz-content-sha256") - serverDigest := sha256sum(body) - - // There may be a client that wants to test that incorrect content digests result in API errors. - // For encrypting the body we have to recalculate the content digest. - // If the client intentionally sends a mismatching content digest, we would take the client request, rewrap it, - // calculate the correct digest for the new body and NOT get an error. - // Thus we have to check incoming requets for matching content digests. - // UNSIGNED-PAYLOAD can be used to disabled payload signing. In that case we don't check the content digest. - if clientDigest != "" && clientDigest != "UNSIGNED-PAYLOAD" && clientDigest != serverDigest { - log.Debugf("PutObject", "error", "x-amz-content-sha256 mismatch") - // The S3 API responds with an XML formatted error message. 
- mismatchErr := NewContentSHA256MismatchError(clientDigest, serverDigest) - marshalled, err := xml.Marshal(mismatchErr) - if err != nil { - log.With(zap.Error(err)).Errorf("PutObject") - http.Error(w, fmt.Sprintf("marshalling error: %s", err.Error()), http.StatusInternalServerError) - return - } - - http.Error(w, string(marshalled), http.StatusBadRequest) - return - } - - metadata := getMetadataHeaders(req.Header) - - raw := req.Header.Get("x-amz-object-lock-retain-until-date") - retentionTime, err := parseRetentionTime(raw) - if err != nil { - log.With(zap.String("data", raw), zap.Error(err)).Errorf("parsing lock retention time") - http.Error(w, fmt.Sprintf("parsing x-amz-object-lock-retain-until-date: %s", err.Error()), http.StatusInternalServerError) - return - } - - err = validateContentMD5(req.Header.Get("content-md5"), body) - if err != nil { - log.With(zap.Error(err)).Errorf("validating content md5") - http.Error(w, fmt.Sprintf("validating content md5: %s", err.Error()), http.StatusBadRequest) - return - } - - obj := object{ - client: client, - key: key, - bucket: bucket, - data: body, - query: req.URL.Query(), - tags: req.Header.Get("x-amz-tagging"), - contentType: req.Header.Get("Content-Type"), - metadata: metadata, - objectLockLegalHoldStatus: req.Header.Get("x-amz-object-lock-legal-hold"), - objectLockMode: req.Header.Get("x-amz-object-lock-mode"), - objectLockRetainUntilDate: retentionTime, - sseCustomerAlgorithm: req.Header.Get("x-amz-server-side-encryption-customer-algorithm"), - sseCustomerKey: req.Header.Get("x-amz-server-side-encryption-customer-key"), - sseCustomerKeyMD5: req.Header.Get("x-amz-server-side-encryption-customer-key-MD5"), - log: log, - } - - put(obj.put)(w, req) - } -} - -func handleForwards(log *logger.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - log.With(zap.String("path", req.URL.Path), zap.String("method", req.Method), zap.String("host", req.Host)).Debugf("forwarding") - - newReq := 
repackage(req) - - httpClient := http.DefaultClient - resp, err := httpClient.Do(&newReq) - if err != nil { - log.With(zap.Error(err)).Errorf("do request") - http.Error(w, fmt.Sprintf("do request: %s", err.Error()), http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - for key := range resp.Header { - w.Header().Set(key, resp.Header.Get(key)) - } - body, err := io.ReadAll(resp.Body) - if err != nil { - log.With(zap.Error(err)).Errorf("ReadAll") - http.Error(w, fmt.Sprintf("reading body: %s", err.Error()), http.StatusInternalServerError) - return - } - w.WriteHeader(resp.StatusCode) - if body == nil { - return - } - - if _, err := w.Write(body); err != nil { - log.With(zap.Error(err)).Errorf("Write") - http.Error(w, fmt.Sprintf("writing body: %s", err.Error()), http.StatusInternalServerError) - return - } - } -} - // byteSliceToByteArray casts a byte slice to a byte array of length 32. // It does a length check to prevent the cast from panic'ing. func byteSliceToByteArray(input []byte) ([32]byte, error) {