ceph-csi/internal/kms/kmip.go
Praveen M e345b26340 cleanup: refactor functions to accept a context parameter
Signed-off-by: Praveen M <m.praveen@ibm.com>
2024-03-12 13:54:19 +00:00

535 lines
14 KiB
Go

/*
Copyright 2022 The Ceph-CSI Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package kms
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"time"
"github.com/ceph/ceph-csi/internal/util/k8s"
kmip "github.com/gemalto/kmip-go"
"github.com/gemalto/kmip-go/kmip14"
"github.com/gemalto/kmip-go/ttlv"
"github.com/google/uuid"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
const (
	// kmsTypeKMIP is the UniqueID under which this provider is registered.
	kmsTypeKMIP = "kmip"

	// kmipDefaulfReadTimeout (sic) is the default read network timeout in
	// seconds, used when READ_TIMEOUT is not configured.
	kmipDefaulfReadTimeout = 10
	// kmipDefaultWriteTimeout is the default write network timeout in
	// seconds, used when WRITE_TIMEOUT is not configured.
	kmipDefaultWriteTimeout = 10

	// KMIP protocol version (major.minor) put in every request header and
	// required from the server during version discovery.
	protocolMajor = 1
	protocolMinor = 4

	// nonceSize is the length in bytes of the nonce generated for every DEK
	// encryption; it is sent as the IV of the AES-CBC Encrypt operation
	// (16 bytes is the AES block size).
	nonceSize = 16

	// kmipDefaultSecretsName is the default name of the Kubernetes Secret
	// that contains the credentials to access the KMIP server. The name of
	// the Secret can be configured by setting the `KMIP_SECRET_NAME`
	// option.
	//
	// #nosec:G101, value not credential, just references token.
	kmipDefaultSecretsName = "ceph-csi-kmip-credentials"

	// The following options are read from the KMS configuration
	// (ProviderInitArgs.Config), not from the Kubernetes Secret.
	kmipEndpoint = "KMIP_ENDPOINT"
	kmipTLSServerName = "TLS_SERVER_NAME"
	kmipReadTimeOut = "READ_TIMEOUT"
	kmipWriteTimeOut = "WRITE_TIMEOUT"
	// kmipSecretNameKey is the configuration option that overrides
	// kmipDefaultSecretsName.
	//
	// #nosec:G101, value not credential, just configuration keys.
	kmipSecretNameKey = "KMIP_SECRET_NAME"

	// The following keys are expected in the Kubernetes Secret; a Secret
	// containing any other key is rejected by getSecrets.
	kmipCACert = "CA_CERT"
	kmipCLientCert = "CLIENT_CERT"
	kmipClientKey = "CLIENT_KEY"
	kmipUniqueIdentifier = "UNIQUE_IDENTIFIER"
)
// Register the KMIP provider under the kmsTypeKMIP identifier, with
// initKMIPKMS as the function that builds its instances.
var _ = RegisterProvider(Provider{
	Initializer: initKMIPKMS,
	UniqueID:    kmsTypeKMIP,
})
// kmipKMS is the EncryptionKMS implementation that delegates DEK
// encryption and decryption to a KMIP server.
type kmipKMS struct {
	// Namespace and name of the Kubernetes Secret holding the TLS
	// material and the key identifier.
	secretName, namespace string

	// endpoint is the address dialed with TCP + TLS.
	endpoint string
	// tlsConfig carries the CA pool, the client certificate and the
	// expected server name.
	tlsConfig *tls.Config
	// uniqueIdentifier is sent as the UniqueIdentifier of Encrypt and
	// Decrypt requests.
	uniqueIdentifier string

	// Network deadlines in seconds; zero leaves the deadline unset.
	readTimeout, writeTimeout uint8
}
func initKMIPKMS(args ProviderInitArgs) (EncryptionKMS, error) {
kms := &kmipKMS{
namespace: args.Namespace,
}
// get secret name if set, else use default.
err := setConfigString(&kms.secretName, args.Config, kmipSecretNameKey)
if errors.Is(err, errConfigOptionInvalid) {
return nil, err
} else if errors.Is(err, errConfigOptionMissing) {
kms.secretName = kmipDefaultSecretsName
}
err = setConfigString(&kms.endpoint, args.Config, kmipEndpoint)
if err != nil {
return nil, err
}
// optional
serverName := ""
err = setConfigString(&serverName, args.Config, kmipTLSServerName)
if errors.Is(err, errConfigOptionInvalid) {
return nil, err
}
// optional
timeout := kmipDefaulfReadTimeout
err = setConfigInt(&timeout, args.Config, kmipReadTimeOut)
if errors.Is(err, errConfigOptionInvalid) {
return nil, err
}
kms.readTimeout = uint8(timeout)
// optional
timeout = kmipDefaultWriteTimeout
err = setConfigInt(&timeout, args.Config, kmipWriteTimeOut)
if errors.Is(err, errConfigOptionInvalid) {
return nil, err
}
kms.writeTimeout = uint8(timeout)
// read the Kubernetes Secret with CA cert, client cert, client key
// & key unique identifier.
secrets, err := kms.getSecrets()
if err != nil {
return nil, fmt.Errorf("failed to get secrets: %w", err)
}
caCert, found := secrets[kmipCACert]
if !found {
return nil, fmt.Errorf("%w: %s", errConfigOptionMissing, kmipCACert)
}
clientCert, found := secrets[kmipCLientCert]
if !found {
return nil, fmt.Errorf("%w: %s", errConfigOptionMissing, kmipCLientCert)
}
clientKey, found := secrets[kmipClientKey]
if !found {
return nil, fmt.Errorf("%w: %s", errConfigOptionMissing, kmipCLientCert)
}
kms.uniqueIdentifier, found = secrets[kmipUniqueIdentifier]
if !found {
return nil, fmt.Errorf("%w: %s", errConfigOptionMissing, kmipUniqueIdentifier)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM([]byte(caCert))
cert, err := tls.X509KeyPair([]byte(clientCert), []byte(clientKey))
if err != nil {
return nil, fmt.Errorf("invalid X509 key pair: %w", err)
}
kms.tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: serverName,
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
}
return kms, nil
}
// EncryptDEK encrypts plainDEK with the KMIP Encrypt operation, using the
// key referenced by kms.uniqueIdentifier with AES in CBC mode and PKCS5
// padding.
//
// The nonce obtained from generateNonce is sent as the IV and stored, next
// to the returned ciphertext, in the JSON-encoded encryptedMetedataDEK that
// this method returns. A new connection is opened for the call and closed
// before returning.
func (kms *kmipKMS) EncryptDEK(ctx context.Context, _, plainDEK string) (string, error) {
	conn, err := kms.connect()
	if err != nil {
		return "", err
	}
	defer conn.Close()

	var metadata encryptedMetedataDEK
	if metadata.Nonce, err = generateNonce(nonceSize); err != nil {
		return "", fmt.Errorf("failed to generated nonce: %w", err)
	}

	request := EncryptRequestPayload{
		UniqueIdentifier: kms.uniqueIdentifier,
		Data:             []byte(plainDEK),
		CryptographicParameters: kmip.CryptographicParameters{
			PaddingMethod:          kmip14.PaddingMethodPKCS5,
			CryptographicAlgorithm: kmip14.CryptographicAlgorithmAES,
			BlockCipherMode:        kmip14.BlockCipherModeCBC,
		},
		IVCounterNonce: metadata.Nonce,
	}
	respMsg, decoder, batchItemID, err := kms.send(conn, kmip14.OperationEncrypt, request)
	if err != nil {
		return "", err
	}

	item, err := kms.verifyResponse(respMsg, kmip14.OperationEncrypt, batchItemID)
	if err != nil {
		return "", err
	}

	payload, ok := item.ResponsePayload.(ttlv.TTLV)
	if !ok {
		return "", errors.New("failed to parse responsePayload")
	}

	var response EncryptResponsePayload
	if err = decoder.DecodeValue(&response, payload); err != nil {
		return "", err
	}
	metadata.DEK = response.Data

	encoded, err := json.Marshal(&metadata)
	if err != nil {
		return "", fmt.Errorf("failed to convert "+
			"encryptedMetedataDEK to JSON: %w", err)
	}

	return string(encoded), nil
}
// DecryptDEK decrypts a DEK with the KMIP Decrypt operation.
//
// encryptedDEK is the JSON encoding of an encryptedMetedataDEK: its DEK
// field holds the ciphertext and its Nonce field the IV used at encryption
// time. The same key (kms.uniqueIdentifier) and the same cryptographic
// parameters as the Encrypt request (AES, CBC, PKCS5 padding) are used.
// A new connection is opened for the call and closed before returning.
func (kms *kmipKMS) DecryptDEK(ctx context.Context, _, encryptedDEK string) (string, error) {
	conn, err := kms.connect()
	if err != nil {
		return "", err
	}
	defer conn.Close()

	emd := encryptedMetedataDEK{}
	err = json.Unmarshal([]byte(encryptedDEK), &emd)
	if err != nil {
		return "", fmt.Errorf("failed to convert data to "+
			"encryptedMetedataDEK: %w", err)
	}

	respMsg, decoder, uniqueBatchItemID, err := kms.send(conn,
		kmip14.OperationDecrypt,
		DecryptRequestPayload{
			UniqueIdentifier: kms.uniqueIdentifier,
			Data:             emd.DEK,
			IVCounterNonce:   emd.Nonce,
			CryptographicParameters: kmip.CryptographicParameters{
				PaddingMethod:          kmip14.PaddingMethodPKCS5,
				CryptographicAlgorithm: kmip14.CryptographicAlgorithmAES,
				BlockCipherMode:        kmip14.BlockCipherModeCBC,
			},
		})
	if err != nil {
		return "", err
	}

	batchItem, err := kms.verifyResponse(respMsg, kmip14.OperationDecrypt, uniqueBatchItemID)
	if err != nil {
		return "", err
	}

	ttlvPayload, ok := batchItem.ResponsePayload.(ttlv.TTLV)
	if !ok {
		return "", errors.New("failed to parse responsePayload")
	}

	// The server answers with a Decrypt *response* payload; decode it into
	// the matching type rather than the request payload type.
	var decryptRespPayload DecryptResponsePayload
	err = decoder.DecodeValue(&decryptRespPayload, ttlvPayload)
	if err != nil {
		return "", err
	}

	return string(decryptRespPayload.Data), nil
}
// Destroy releases resources held by the KMS. Nothing is kept open between
// calls (EncryptDEK and DecryptDEK each open and close their own
// connection), so there is nothing to release.
func (kms *kmipKMS) Destroy() {
	// Nothing to do.
}

// RequiresDEKStore reports which DEK store this provider needs: it always
// returns DEKStoreMetadata. NOTE(review): presumably the caller then keeps
// the JSON returned by EncryptDEK in volume metadata — confirm against the
// DEKStoreType users.
func (kms *kmipKMS) RequiresDEKStore() DEKStoreType {
	return DEKStoreMetadata
}
// getSecrets reads the Kubernetes Secret kms.namespace/kms.secretName and
// returns its entries converted to strings.
//
// Only CLIENT_KEY, CLIENT_CERT, CA_CERT and UNIQUE_IDENTIFIER are accepted;
// a Secret containing any other key is rejected with an error. Presence of
// the accepted keys is not checked here.
func (kms *kmipKMS) getSecrets() (map[string]string, error) {
	client, err := k8s.NewK8sClient()
	if err != nil {
		return nil, fmt.Errorf("failed to connect to Kubernetes to "+
			"get Secret %s/%s: %w", kms.namespace, kms.secretName, err)
	}

	secretsAPI := client.CoreV1().Secrets(kms.namespace)
	secret, err := secretsAPI.Get(context.TODO(), kms.secretName, metav1.GetOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to get Secret %s/%s: %w",
			kms.namespace, kms.secretName, err)
	}

	options := make(map[string]string)
	for key, value := range secret.Data {
		switch key {
		case kmipClientKey, kmipCLientCert, kmipCACert, kmipUniqueIdentifier:
			// supported key, copied below
		default:
			return nil, fmt.Errorf("unsupported option for KMS "+
				"provider %q: %s", kmsTypeKMIP, key)
		}
		options[key] = string(value)
	}

	return options, nil
}
// connect dials kms.endpoint over TLS, applies the configured network
// deadlines, and performs the KMIP DiscoverVersions exchange. That exchange
// checks that the server speaks protocol version protocolMajor.protocolMinor.
//
// The deadlines are absolute (now + timeout). They therefore bound the
// version discovery and every operation later performed on the returned
// connection. A zero timeout leaves the corresponding deadline unset.
//
// On success the caller owns the connection and must Close it. On any error
// the connection is closed before returning.
func (kms *kmipKMS) connect() (_ *tls.Conn, err error) {
	conn, err := tls.Dial("tcp", kms.endpoint, kms.tlsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial kmip connection endpoint: %w", err)
	}

	// Registered right after the dial so that every failure below,
	// including deadline setup, closes the connection. The named result
	// err makes this robust against shadowed error variables.
	defer func() {
		if err != nil {
			// Best effort: the error being returned is the useful one.
			_ = conn.Close()
		}
	}()

	if kms.readTimeout != 0 {
		err = conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(kms.readTimeout)))
		if err != nil {
			return nil, fmt.Errorf("failed to set read deadline: %w", err)
		}
	}

	if kms.writeTimeout != 0 {
		err = conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(kms.writeTimeout)))
		if err != nil {
			return nil, fmt.Errorf("failed to set write deadline: %w", err)
		}
	}

	// tls.Dial has already run the handshake. This call then only reports
	// its outcome, but it keeps the intent explicit.
	err = conn.Handshake()
	if err != nil {
		return nil, fmt.Errorf("failed to perform connection handshake: %w", err)
	}

	err = kms.discover(conn)
	if err != nil {
		return nil, err
	}

	return conn, nil
}
// discover performs the KMIP DiscoverVersions operation on conn, see
// https://docs.oasis-open.org/kmip/spec/v1.4/kmip-spec-v1.4.html
// chapter 4.26.
//
// Only protocolMajor.protocolMinor is offered, and the server must answer
// with exactly that single version; anything else is reported as an error.
func (kms *kmipKMS) discover(conn io.ReadWriter) error {
	offered := []kmip.ProtocolVersion{{
		ProtocolVersionMajor: protocolMajor,
		ProtocolVersionMinor: protocolMinor,
	}}
	respMsg, decoder, batchItemID, err := kms.send(conn,
		kmip14.OperationDiscoverVersions,
		kmip.DiscoverVersionsRequestPayload{ProtocolVersion: offered})
	if err != nil {
		return err
	}

	item, err := kms.verifyResponse(respMsg, kmip14.OperationDiscoverVersions, batchItemID)
	if err != nil {
		return err
	}

	payload, ok := item.ResponsePayload.(ttlv.TTLV)
	if !ok {
		return errors.New("failed to parse responsePayload")
	}

	var discovered kmip.DiscoverVersionsResponsePayload
	if err = decoder.DecodeValue(&discovered, payload); err != nil {
		return err
	}

	versions := discovered.ProtocolVersion
	if len(versions) != 1 {
		return fmt.Errorf("invalid len of discovered protocol versions %v expected 1",
			len(versions))
	}

	if got := versions[0]; got.ProtocolVersionMajor != protocolMajor ||
		got.ProtocolVersionMinor != protocolMinor {
		return fmt.Errorf("invalid discovered protocol version %v.%v expected %v.%v",
			got.ProtocolVersionMajor, got.ProtocolVersionMinor, protocolMajor, protocolMinor)
	}

	return nil
}
// send builds a KMIP request message with a single batch item for the
// given operation and payload. It writes the TTLV encoding to conn and reads
// back and decodes one TTLV response message.
//
// It returns:
//   - the decoded response message,
//   - the decoder, which callers reuse to decode the operation-specific
//     response payload,
//   - the freshly generated unique batch item ID of the request, to be
//     checked with verifyResponse,
//   - an error, in which case the other results are nil.
func (kms *kmipKMS) send(
	conn io.ReadWriter,
	operation kmip14.Operation,
	payload interface{},
) (*kmip.ResponseMessage, *ttlv.Decoder, []byte, error) {
	id := uuid.New()
	batchItemID := id[:]

	request, err := ttlv.Marshal(kmip.RequestMessage{
		RequestHeader: kmip.RequestHeader{
			ProtocolVersion: kmip.ProtocolVersion{
				ProtocolVersionMajor: protocolMajor,
				ProtocolVersionMinor: protocolMinor,
			},
			BatchCount: 1,
		},
		BatchItem: []kmip.RequestBatchItem{{
			UniqueBatchItemID: batchItemID,
			Operation:         operation,
			RequestPayload:    payload,
		}},
	})
	if err != nil {
		return nil, nil, nil,
			fmt.Errorf("failed to ttlv marshal message: %w", err)
	}

	if _, err = conn.Write(request); err != nil {
		return nil, nil, nil,
			fmt.Errorf("failed to write request onto connection: %w", err)
	}

	decoder := ttlv.NewDecoder(bufio.NewReader(conn))
	raw, err := decoder.NextTTLV()
	if err != nil {
		return nil, nil, nil,
			fmt.Errorf("failed to read ttlv KMIP value: %w", err)
	}

	var response kmip.ResponseMessage
	if err = decoder.DecodeValue(&response, raw); err != nil {
		return nil, nil, nil,
			fmt.Errorf("failed to decode response value: %w", err)
	}

	return &response, decoder, batchItemID, nil
}
// verifyResponse checks that respMsg is the successful answer to a
// single-item request, as built by send. It requires:
//   - exactly one batch item (both in the header count and in the list),
//   - the expected operation,
//   - the expected uniqueBatchItemID,
//   - a Success result status.
//
// It returns a pointer to a copy of that batch item, or an error that
// describes the first mismatch found.
func (kms *kmipKMS) verifyResponse(
	respMsg *kmip.ResponseMessage,
	operation kmip14.Operation,
	uniqueBatchItemID []byte,
) (*kmip.ResponseBatchItem, error) {
	// Counts are integers: %d, not %q (which would print a rune literal).
	if respMsg.ResponseHeader.BatchCount != 1 {
		return nil, fmt.Errorf("batch count %d should be 1",
			respMsg.ResponseHeader.BatchCount)
	}

	if len(respMsg.BatchItem) != 1 {
		return nil, fmt.Errorf("batch items list len %d should be 1",
			len(respMsg.BatchItem))
	}

	batchItem := respMsg.BatchItem[0]
	if operation != batchItem.Operation {
		return nil, fmt.Errorf("unexpected operation, real %q expected %q",
			batchItem.Operation, operation)
	}

	// The IDs are raw UUID bytes; print them as hex to keep them readable.
	if !bytes.Equal(uniqueBatchItemID, batchItem.UniqueBatchItemID) {
		return nil, fmt.Errorf("unexpected uniqueBatchItemID, real %x expected %x",
			batchItem.UniqueBatchItemID, uniqueBatchItemID)
	}

	if kmip14.ResultStatusSuccess != batchItem.ResultStatus {
		return nil, fmt.Errorf("unexpected result status %q expected success %q, "+
			"result reason %q, result message %q",
			batchItem.ResultStatus, kmip14.ResultStatusSuccess,
			batchItem.ResultReason, batchItem.ResultMessage)
	}

	return &batchItem, nil
}
// GetSecret is not supported by the KMIP provider: it ignores its arguments
// and always returns ErrGetSecretUnsupported.
func (kms *kmipKMS) GetSecret(ctx context.Context, volumeID string) (string, error) {
	return "", ErrGetSecretUnsupported
}
// TODO: use the following structs from https://github.com/gemalto/kmip-go
// when https://github.com/ThalesGroup/kmip-go/issues/21 is resolved.
// refer: https://docs.oasis-open.org/kmip/spec/v1.4/kmip-spec-v1.4.html.
//
// NOTE(review): keep the field order of the payload structs below as is.
// KMIP structures have a field order fixed by the specification, and the
// ttlv encoder presumably emits struct fields in declaration order.

// EncryptRequestPayload is the request payload of the KMIP Encrypt
// operation.
type EncryptRequestPayload struct {
	// UniqueIdentifier references the key on the KMIP server.
	UniqueIdentifier string
	// CryptographicParameters selects algorithm, block mode and padding.
	CryptographicParameters kmip.CryptographicParameters
	// Data is the plaintext to encrypt.
	Data []byte
	// IVCounterNonce is the IV/nonce to use for the operation.
	IVCounterNonce []byte
}
// EncryptResponsePayload is the response payload of the KMIP Encrypt
// operation. Field order follows the KMIP specification; do not reorder.
type EncryptResponsePayload struct {
	// UniqueIdentifier of the key that was used.
	UniqueIdentifier string
	// Data is the resulting ciphertext.
	Data []byte
	// IVCounterNonce as reported by the server, if any.
	IVCounterNonce []byte
}
// DecryptRequestPayload is the request payload of the KMIP Decrypt
// operation. Field order follows the KMIP specification; do not reorder.
type DecryptRequestPayload struct {
	// UniqueIdentifier references the key on the KMIP server.
	UniqueIdentifier string
	// CryptographicParameters must match those used for encryption.
	CryptographicParameters kmip.CryptographicParameters
	// Data is the ciphertext to decrypt.
	Data []byte
	// IVCounterNonce is the IV/nonce that was used for encryption.
	IVCounterNonce []byte
}
// DecryptResponsePayload is the response payload of the KMIP Decrypt
// operation. Field order follows the KMIP specification; do not reorder.
type DecryptResponsePayload struct {
	// UniqueIdentifier of the key that was used.
	UniqueIdentifier string
	// Data is the resulting plaintext.
	Data []byte
	// IVCounterNonce as reported by the server, if any.
	IVCounterNonce []byte
}