diff --git a/internal/rbd/rbd_journal.go b/internal/rbd/rbd_journal.go index f512cc340..d57203bf5 100644 --- a/internal/rbd/rbd_journal.go +++ b/internal/rbd/rbd_journal.go @@ -87,6 +87,17 @@ func validateRbdVol(rbdVol *rbdVolume) error { return err } +func getEncryptionConfig(rbdVol *rbdVolume) (string, util.EncryptionType) { + switch { + case rbdVol.isBlockEncrypted(): + return rbdVol.blockEncryption.GetID(), util.EncryptionTypeBlock + case rbdVol.isFileEncrypted(): + return rbdVol.fileEncryption.GetID(), util.EncryptionTypeFile + default: + return "", util.EncryptionTypeInvalid + } +} + /* checkSnapCloneExists, and its counterpart checkVolExists, function checks if the passed in rbdSnapshot or rbdVolume exists on the backend. @@ -130,7 +141,7 @@ func checkSnapCloneExists( defer j.Destroy() snapData, err := j.CheckReservation(ctx, rbdSnap.JournalPool, - rbdSnap.RequestName, rbdSnap.NamePrefix, rbdSnap.RbdImageName, "") + rbdSnap.RequestName, rbdSnap.NamePrefix, rbdSnap.RbdImageName, "", util.EncryptionTypeInvalid) if err != nil { return false, err } @@ -245,10 +256,7 @@ func (rv *rbdVolume) Exists(ctx context.Context, parentVol *rbdVolume) (bool, er return false, err } - kmsID := "" - if rv.isBlockEncrypted() { - kmsID = rv.blockEncryption.GetID() - } + kmsID, encryptionType := getEncryptionConfig(rv) j, err := volJournal.Connect(rv.Monitors, rv.RadosNamespace, rv.conn.Creds) if err != nil { @@ -257,7 +265,7 @@ func (rv *rbdVolume) Exists(ctx context.Context, parentVol *rbdVolume) (bool, er defer j.Destroy() imageData, err := j.CheckReservation( - ctx, rv.JournalPool, rv.RequestName, rv.NamePrefix, "", kmsID) + ctx, rv.JournalPool, rv.RequestName, rv.NamePrefix, "", kmsID, encryptionType) if err != nil { return false, err } @@ -386,14 +394,12 @@ func reserveSnap(ctx context.Context, rbdSnap *rbdSnapshot, rbdVol *rbdVolume, c } defer j.Destroy() - kmsID := "" - if rbdVol.isBlockEncrypted() { - kmsID = rbdVol.blockEncryption.GetID() - } + kmsID, encryptionType := getEncryptionConfig(rbdVol) rbdSnap.ReservedID, rbdSnap.RbdSnapName, err = j.ReserveName( ctx, rbdSnap.JournalPool, journalPoolID, rbdSnap.Pool, imagePoolID, - rbdSnap.RequestName, rbdSnap.NamePrefix, rbdVol.RbdImageName, kmsID, rbdSnap.ReservedID, rbdVol.Owner, "") + rbdSnap.RequestName, rbdSnap.NamePrefix, rbdVol.RbdImageName, kmsID, rbdSnap.ReservedID, rbdVol.Owner, + "", encryptionType) if err != nil { return err } @@ -460,10 +466,7 @@ func reserveVol(ctx context.Context, rbdVol *rbdVolume, rbdSnap *rbdSnapshot, cr return err } - kmsID := "" - if rbdVol.isBlockEncrypted() { - kmsID = rbdVol.blockEncryption.GetID() - } + kmsID, encryptionType := getEncryptionConfig(rbdVol) j, err := volJournal.Connect(rbdVol.Monitors, rbdVol.RadosNamespace, cr) if err != nil { @@ -473,7 +476,7 @@ func reserveVol(ctx context.Context, rbdVol *rbdVolume, rbdSnap *rbdSnapshot, cr rbdVol.ReservedID, rbdVol.RbdImageName, err = j.ReserveName( ctx, rbdVol.JournalPool, journalPoolID, rbdVol.Pool, imagePoolID, - rbdVol.RequestName, rbdVol.NamePrefix, "", kmsID, rbdVol.ReservedID, rbdVol.Owner, "") + rbdVol.RequestName, rbdVol.NamePrefix, "", kmsID, rbdVol.ReservedID, rbdVol.Owner, "", encryptionType) if err != nil { return err } @@ -548,11 +551,12 @@ func RegenerateJournal( ) (string, error) { ctx := context.Background() var ( - vi util.CSIIdentifier - rbdVol *rbdVolume - kmsID string - err error - ok bool + vi util.CSIIdentifier + rbdVol *rbdVolume + kmsID string + encryptionType util.EncryptionType + err error + ok bool ) rbdVol = &rbdVolume{} @@ -605,7 +609,7 @@ func RegenerateJournal( rbdVol.NamePrefix = volumeAttributes["volumeNamePrefix"] imageData, err := j.CheckReservation( - ctx, rbdVol.JournalPool, rbdVol.RequestName, rbdVol.NamePrefix, "", kmsID) + ctx, rbdVol.JournalPool, rbdVol.RequestName, rbdVol.NamePrefix, "", kmsID, encryptionType) if err != nil { return "", err } @@ -639,7 +643,7 @@ func RegenerateJournal( rbdVol.ReservedID, rbdVol.RbdImageName, err = j.ReserveName( ctx, rbdVol.JournalPool, journalPoolID, rbdVol.Pool, imagePoolID, - rbdVol.RequestName, rbdVol.NamePrefix, "", kmsID, vi.ObjectUUID, rbdVol.Owner, "") + rbdVol.RequestName, rbdVol.NamePrefix, "", kmsID, vi.ObjectUUID, rbdVol.Owner, "", encryptionType) if err != nil { return "", err }