Merge pull request #290 from red-hat-storage/sync_us--devel

Syncing latest changes from upstream devel for ceph-csi
This commit is contained in:
openshift-merge-bot[bot] 2024-04-11 08:23:51 +00:00 committed by GitHub
commit 8e62c9face
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
102 changed files with 1560 additions and 707 deletions

View File

@ -31,6 +31,7 @@ rules:
- - build - - build
- cephfs - cephfs
- ci - ci
- csi-addons
- cleanup - cleanup
- deploy - deploy
- doc - doc

View File

@ -26,7 +26,7 @@ GO111MODULE=on
COMMITLINT_VERSION=latest COMMITLINT_VERSION=latest
# static checks and linters # static checks and linters
GOLANGCI_VERSION=v1.54.1 GOLANGCI_VERSION=v1.57.2
# external snapshotter version # external snapshotter version
# Refer: https://github.com/kubernetes-csi/external-snapshotter/releases # Refer: https://github.com/kubernetes-csi/external-snapshotter/releases
@ -41,7 +41,7 @@ SNAPSHOT_VERSION=v7.0.1
HELM_SCRIPT=https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 HELM_SCRIPT=https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3
# helm chart generation, testing and publishing # helm chart generation, testing and publishing
HELM_VERSION=v3.14.1 HELM_VERSION=v3.14.3
# minikube settings # minikube settings
MINIKUBE_VERSION=v1.32.0 MINIKUBE_VERSION=v1.32.0

View File

@ -23,7 +23,7 @@ import (
"k8s.io/kubernetes/test/e2e/framework" "k8s.io/kubernetes/test/e2e/framework"
) )
// #nosec because of the word `Secret` //nolint:gosec // secret for test
const ( const (
// ceph user names. // ceph user names.
keyringRBDProvisionerUsername = "cephcsi-rbd-provisioner" keyringRBDProvisionerUsername = "cephcsi-rbd-provisioner"
@ -110,7 +110,7 @@ func createCephUser(f *framework.Framework, user string, caps []string) (string,
} }
func deleteCephUser(f *framework.Framework, user string) error { func deleteCephUser(f *framework.Framework, user string) error {
cmd := fmt.Sprintf("ceph auth del client.%s", user) cmd := "ceph auth del client." + user
_, _, err := execCommandInToolBoxPod(f, cmd, rookNamespace) _, _, err := execCommandInToolBoxPod(f, cmd, rookNamespace)
return err return err

View File

@ -23,7 +23,7 @@ import (
"sync" "sync"
snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1" snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions . "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -735,7 +735,7 @@ var _ = Describe(cephfsType, func() {
framework.Failf("failed to create PVC and Deployment: %v", err) framework.Failf("failed to create PVC and Deployment: %v", err)
} }
deplPods, err := listPods(f, depl.Namespace, &metav1.ListOptions{ deplPods, err := listPods(f, depl.Namespace, &metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", depl.Labels["app"]), LabelSelector: "app=" + depl.Labels["app"],
}) })
if err != nil { if err != nil {
framework.Failf("failed to list pods for Deployment: %v", err) framework.Failf("failed to list pods for Deployment: %v", err)
@ -744,7 +744,7 @@ var _ = Describe(cephfsType, func() {
doStat := func(podName string) (string, error) { doStat := func(podName string) (string, error) {
_, stdErr, execErr := execCommandInContainerByPodName( _, stdErr, execErr := execCommandInContainerByPodName(
f, f,
fmt.Sprintf("stat %s", depl.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath), "stat "+depl.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath,
depl.Namespace, depl.Namespace,
podName, podName,
depl.Spec.Template.Spec.Containers[0].Name, depl.Spec.Template.Spec.Containers[0].Name,
@ -808,7 +808,7 @@ var _ = Describe(cephfsType, func() {
} }
// List Deployment's pods again to get name of the new pod. // List Deployment's pods again to get name of the new pod.
deplPods, err = listPods(f, depl.Namespace, &metav1.ListOptions{ deplPods, err = listPods(f, depl.Namespace, &metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", depl.Labels["app"]), LabelSelector: "app=" + depl.Labels["app"],
}) })
if err != nil { if err != nil {
framework.Failf("failed to list pods for Deployment: %v", err) framework.Failf("failed to list pods for Deployment: %v", err)
@ -1074,13 +1074,13 @@ var _ = Describe(cephfsType, func() {
} }
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name), LabelSelector: "app=" + app.Name,
} }
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail( _, stdErr := execCommandInPodAndAllowFail(
f, f,
fmt.Sprintf("echo 'Hello World' > %s", filePath), "echo 'Hello World' >"+filePath,
app.Namespace, app.Namespace,
&opt) &opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath) readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)
@ -2407,13 +2407,13 @@ var _ = Describe(cephfsType, func() {
} }
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name), LabelSelector: "app=" + app.Name,
} }
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail( _, stdErr := execCommandInPodAndAllowFail(
f, f,
fmt.Sprintf("echo 'Hello World' > %s", filePath), "echo 'Hello World' > "+filePath,
app.Namespace, app.Namespace,
&opt) &opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath) readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)

View File

@ -338,7 +338,7 @@ func getSnapName(snapNamespace, snapName string) (string, error) {
} }
snapIDRegex := regexp.MustCompile(`(\w+\-?){5}$`) snapIDRegex := regexp.MustCompile(`(\w+\-?){5}$`)
snapID := snapIDRegex.FindString(*sc.Status.SnapshotHandle) snapID := snapIDRegex.FindString(*sc.Status.SnapshotHandle)
snapshotName := fmt.Sprintf("csi-snap-%s", snapID) snapshotName := "csi-snap-" + snapID
framework.Logf("snapshotName= %s", snapshotName) framework.Logf("snapshotName= %s", snapshotName)
return snapshotName, nil return snapshotName, nil
@ -392,10 +392,10 @@ func validateEncryptedCephfs(f *framework.Framework, pvName, appName string) err
LabelSelector: selector, LabelSelector: selector,
} }
cmd := fmt.Sprintf("getfattr --name=ceph.fscrypt.auth --only-values %s", volumeMountPath) cmd := "getfattr --name=ceph.fscrypt.auth --only-values " + volumeMountPath
_, _, err = execCommandInContainer(f, cmd, cephCSINamespace, "csi-cephfsplugin", &opt) _, _, err = execCommandInContainer(f, cmd, cephCSINamespace, "csi-cephfsplugin", &opt)
if err != nil { if err != nil {
cmd = fmt.Sprintf("getfattr --recursive --dump %s", volumeMountPath) cmd = "getfattr --recursive --dump " + volumeMountPath
stdOut, stdErr, listErr := execCommandInContainer(f, cmd, cephCSINamespace, "csi-cephfsplugin", &opt) stdOut, stdErr, listErr := execCommandInContainer(f, cmd, cephCSINamespace, "csi-cephfsplugin", &opt)
if listErr == nil { if listErr == nil {
return fmt.Errorf("error checking for cephfs fscrypt xattr on %q. listing: %s %s", return fmt.Errorf("error checking for cephfs fscrypt xattr on %q. listing: %s %s",

View File

@ -21,7 +21,7 @@ import (
"fmt" "fmt"
"strings" "strings"
. "github.com/onsi/gomega" //nolint:golint // e2e uses Expect() and other Gomega functions . "github.com/onsi/gomega"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
"k8s.io/kubernetes/test/e2e/framework" "k8s.io/kubernetes/test/e2e/framework"

View File

@ -95,8 +95,8 @@ func replaceNamespaceInTemplate(filePath string) (string, error) {
} }
// template can contain "default" as namespace, with or without ". // template can contain "default" as namespace, with or without ".
templ := strings.ReplaceAll(string(read), "namespace: default", fmt.Sprintf("namespace: %s", cephCSINamespace)) templ := strings.ReplaceAll(string(read), "namespace: default", "namespace: "+cephCSINamespace)
templ = strings.ReplaceAll(templ, "namespace: \"default\"", fmt.Sprintf("namespace: %s", cephCSINamespace)) templ = strings.ReplaceAll(templ, "namespace: \"default\"", "namespace: "+cephCSINamespace)
return templ, nil return templ, nil
} }

View File

@ -24,7 +24,7 @@ import (
"time" "time"
snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1" snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions . "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
apierrs "k8s.io/apimachinery/pkg/api/errors" apierrs "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -603,13 +603,13 @@ var _ = Describe("nfs", func() {
} }
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name), LabelSelector: "app=" + app.Name,
} }
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail( _, stdErr := execCommandInPodAndAllowFail(
f, f,
fmt.Sprintf("echo 'Hello World' > %s", filePath), "echo 'Hello World' > "+filePath,
app.Namespace, app.Namespace,
&opt) &opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath) readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)

View File

@ -382,7 +382,7 @@ func waitForPodInRunningState(name, ns string, c kubernetes.Interface, t int, ex
case v1.PodPending: case v1.PodPending:
if expectedError != "" { if expectedError != "" {
events, err := c.CoreV1().Events(ns).List(ctx, metav1.ListOptions{ events, err := c.CoreV1().Events(ns).List(ctx, metav1.ListOptions{
FieldSelector: fmt.Sprintf("involvedObject.name=%s", name), FieldSelector: "involvedObject.name=" + name,
}) })
if err != nil { if err != nil {
return false, err return false, err
@ -452,7 +452,7 @@ func deletePodWithLabel(label, ns string, skipNotFound bool) error {
// calculateSHA512sum returns the sha512sum of a file inside a pod. // calculateSHA512sum returns the sha512sum of a file inside a pod.
func calculateSHA512sum(f *framework.Framework, app *v1.Pod, filePath string, opt *metav1.ListOptions) (string, error) { func calculateSHA512sum(f *framework.Framework, app *v1.Pod, filePath string, opt *metav1.ListOptions) (string, error) {
cmd := fmt.Sprintf("sha512sum %s", filePath) cmd := "sha512sum " + filePath
sha512sumOut, stdErr, err := execCommandInPod(f, cmd, app.Namespace, opt) sha512sumOut, stdErr, err := execCommandInPod(f, cmd, app.Namespace, opt)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -19,12 +19,13 @@ package e2e
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"strings" "strings"
"time" "time"
"github.com/ceph/ceph-csi/internal/util" "github.com/ceph/ceph-csi/internal/util"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions . "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
@ -536,7 +537,7 @@ var _ = Describe("RBD", func() {
}) })
By("reattach the old PV to a new PVC and check if PVC metadata is updated on RBD image", func() { By("reattach the old PV to a new PVC and check if PVC metadata is updated on RBD image", func() {
reattachPVCNamespace := fmt.Sprintf("%s-2", f.Namespace.Name) reattachPVCNamespace := f.Namespace.Name + "-2"
pvc, err := loadPVC(pvcPath) pvc, err := loadPVC(pvcPath)
if err != nil { if err != nil {
framework.Failf("failed to load PVC: %v", err) framework.Failf("failed to load PVC: %v", err)
@ -1512,7 +1513,7 @@ var _ = Describe("RBD", func() {
cmd := fmt.Sprintf("dd if=/dev/zero of=%s bs=1M count=10", devPath) cmd := fmt.Sprintf("dd if=/dev/zero of=%s bs=1M count=10", devPath)
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name), LabelSelector: "app=" + app.Name,
} }
podList, err := e2epod.PodClientNS(f, app.Namespace).List(context.TODO(), opt) podList, err := e2epod.PodClientNS(f, app.Namespace).List(context.TODO(), opt)
if err != nil { if err != nil {
@ -1630,10 +1631,10 @@ var _ = Describe("RBD", func() {
validateOmapCount(f, 2, rbdType, defaultRBDPool, volumesType) validateOmapCount(f, 2, rbdType, defaultRBDPool, volumesType)
filePath := appClone.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := appClone.Spec.Template.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
cmd := fmt.Sprintf("echo 'Hello World' > %s", filePath) cmd := "echo 'Hello World' > " + filePath
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", appClone.Name), LabelSelector: "app=" + appClone.Name,
} }
podList, err := e2epod.PodClientNS(f, appClone.Namespace).List(context.TODO(), opt) podList, err := e2epod.PodClientNS(f, appClone.Namespace).List(context.TODO(), opt)
if err != nil { if err != nil {
@ -1766,7 +1767,7 @@ var _ = Describe("RBD", func() {
cmd := fmt.Sprintf("dd if=/dev/zero of=%s bs=1M count=10", devPath) cmd := fmt.Sprintf("dd if=/dev/zero of=%s bs=1M count=10", devPath)
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", appClone.Name), LabelSelector: "app=" + appClone.Name,
} }
podList, err := e2epod.PodClientNS(f, appClone.Namespace).List(context.TODO(), opt) podList, err := e2epod.PodClientNS(f, appClone.Namespace).List(context.TODO(), opt)
if err != nil { if err != nil {
@ -1862,14 +1863,14 @@ var _ = Describe("RBD", func() {
} }
appOpt := metav1.ListOptions{ appOpt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name), LabelSelector: "app=" + app.Name,
} }
// TODO: Remove this once we ensure that rbd-nbd can sync data // TODO: Remove this once we ensure that rbd-nbd can sync data
// from Filesystem layer to backend rbd image as part of its // from Filesystem layer to backend rbd image as part of its
// detach or SIGTERM signal handler // detach or SIGTERM signal handler
_, stdErr, err := execCommandInPod( _, stdErr, err := execCommandInPod(
f, f,
fmt.Sprintf("sync %s", app.Spec.Containers[0].VolumeMounts[0].MountPath), "sync "+app.Spec.Containers[0].VolumeMounts[0].MountPath,
app.Namespace, app.Namespace,
&appOpt) &appOpt)
if err != nil || stdErr != "" { if err != nil || stdErr != "" {
@ -1956,7 +1957,7 @@ var _ = Describe("RBD", func() {
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr, err = execCommandInPod( _, stdErr, err = execCommandInPod(
f, f,
fmt.Sprintf("echo 'Hello World' > %s", filePath), "echo 'Hello World' > "+filePath,
app.Namespace, app.Namespace,
&appOpt) &appOpt)
if err != nil || stdErr != "" { if err != nil || stdErr != "" {
@ -3331,13 +3332,13 @@ var _ = Describe("RBD", func() {
for i := 0; i < totalCount; i++ { for i := 0; i < totalCount; i++ {
name := fmt.Sprintf("%s%d", f.UniqueName, i) name := fmt.Sprintf("%s%d", f.UniqueName, i)
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", name), LabelSelector: "app=" + name,
} }
filePath := appClone.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := appClone.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail( _, stdErr := execCommandInPodAndAllowFail(
f, f,
fmt.Sprintf("echo 'Hello World' > %s", filePath), "echo 'Hello World' > "+filePath,
appClone.Namespace, appClone.Namespace,
&opt) &opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath) readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)
@ -3961,13 +3962,13 @@ var _ = Describe("RBD", func() {
validateOmapCount(f, 1, rbdType, defaultRBDPool, volumesType) validateOmapCount(f, 1, rbdType, defaultRBDPool, volumesType)
opt := metav1.ListOptions{ opt := metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", app.Name), LabelSelector: "app=" + app.Name,
} }
filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test" filePath := app.Spec.Containers[0].VolumeMounts[0].MountPath + "/test"
_, stdErr := execCommandInPodAndAllowFail( _, stdErr := execCommandInPodAndAllowFail(
f, f,
fmt.Sprintf("echo 'Hello World' > %s", filePath), "echo 'Hello World' > "+filePath,
app.Namespace, app.Namespace,
&opt) &opt)
readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath) readOnlyErr := fmt.Sprintf("cannot create %s: Read-only file system", filePath)
@ -4350,9 +4351,9 @@ var _ = Describe("RBD", func() {
defaultSCName, defaultSCName,
nil, nil,
map[string]string{ map[string]string{
"stripeUnit": fmt.Sprintf("%d", stripeUnit), "stripeUnit": strconv.Itoa(stripeUnit),
"stripeCount": fmt.Sprintf("%d", stripeCount), "stripeCount": strconv.Itoa(stripeCount),
"objectSize": fmt.Sprintf("%d", objectSize), "objectSize": strconv.Itoa(objectSize),
}, },
deletePolicy) deletePolicy)
if err != nil { if err != nil {

View File

@ -118,7 +118,7 @@ func createRBDStorageClass(
scOptions, parameters map[string]string, scOptions, parameters map[string]string,
policy v1.PersistentVolumeReclaimPolicy, policy v1.PersistentVolumeReclaimPolicy,
) error { ) error {
scPath := fmt.Sprintf("%s/%s", rbdExamplePath, "storageclass.yaml") scPath := rbdExamplePath + "/" + "storageclass.yaml"
sc, err := getStorageClass(scPath) sc, err := getStorageClass(scPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to get sc: %w", err) return fmt.Errorf("failed to get sc: %w", err)
@ -184,7 +184,7 @@ func createRBDStorageClass(
func createRadosNamespace(f *framework.Framework) error { func createRadosNamespace(f *framework.Framework) error {
stdOut, stdErr, err := execCommandInToolBoxPod(f, stdOut, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace ls --pool=%s", defaultRBDPool), rookNamespace) "rbd namespace ls --pool="+defaultRBDPool, rookNamespace)
if err != nil { if err != nil {
return err return err
} }
@ -193,7 +193,7 @@ func createRadosNamespace(f *framework.Framework) error {
} }
if !strings.Contains(stdOut, radosNamespace) { if !strings.Contains(stdOut, radosNamespace) {
_, stdErr, err = execCommandInToolBoxPod(f, _, stdErr, err = execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace create %s", rbdOptions(defaultRBDPool)), rookNamespace) "rbd namespace create "+rbdOptions(defaultRBDPool), rookNamespace)
if err != nil { if err != nil {
return err return err
} }
@ -202,7 +202,7 @@ func createRadosNamespace(f *framework.Framework) error {
} }
} }
stdOut, stdErr, err = execCommandInToolBoxPod(f, stdOut, stdErr, err = execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace ls --pool=%s", rbdTopologyPool), rookNamespace) "rbd namespace ls --pool="+rbdTopologyPool, rookNamespace)
if err != nil { if err != nil {
return err return err
} }
@ -212,7 +212,7 @@ func createRadosNamespace(f *framework.Framework) error {
if !strings.Contains(stdOut, radosNamespace) { if !strings.Contains(stdOut, radosNamespace) {
_, stdErr, err = execCommandInToolBoxPod(f, _, stdErr, err = execCommandInToolBoxPod(f,
fmt.Sprintf("rbd namespace create %s", rbdOptions(rbdTopologyPool)), rookNamespace) "rbd namespace create "+rbdOptions(rbdTopologyPool), rookNamespace)
if err != nil { if err != nil {
return err return err
} }
@ -269,7 +269,7 @@ func getImageInfoFromPVC(pvcNamespace, pvcName string, f *framework.Framework) (
imageData = imageInfoFromPVC{ imageData = imageInfoFromPVC{
imageID: imageID, imageID: imageID,
imageName: fmt.Sprintf("csi-vol-%s", imageID), imageName: "csi-vol-" + imageID,
csiVolumeHandle: pv.Spec.CSI.VolumeHandle, csiVolumeHandle: pv.Spec.CSI.VolumeHandle,
pvName: pv.Name, pvName: pv.Name,
} }
@ -671,7 +671,7 @@ func validateEncryptedFilesystem(f *framework.Framework, rbdImageSpec, pvName, a
cmd := fmt.Sprintf("lsattr -la %s | grep -E '%s/.\\s+Encrypted'", volumeMountPath, volumeMountPath) cmd := fmt.Sprintf("lsattr -la %s | grep -E '%s/.\\s+Encrypted'", volumeMountPath, volumeMountPath)
_, _, err = execCommandInContainer(f, cmd, cephCSINamespace, "csi-rbdplugin", &opt) _, _, err = execCommandInContainer(f, cmd, cephCSINamespace, "csi-rbdplugin", &opt)
if err != nil { if err != nil {
cmd = fmt.Sprintf("lsattr -lRa %s", volumeMountPath) cmd = "lsattr -lRa " + volumeMountPath
stdOut, stdErr, listErr := execCommandInContainer(f, cmd, cephCSINamespace, "csi-rbdplugin", &opt) stdOut, stdErr, listErr := execCommandInContainer(f, cmd, cephCSINamespace, "csi-rbdplugin", &opt)
if listErr == nil { if listErr == nil {
return fmt.Errorf("error checking file encrypted attribute of %q. listing filesystem+attrs: %s %s", return fmt.Errorf("error checking file encrypted attribute of %q. listing filesystem+attrs: %s %s",
@ -697,7 +697,7 @@ func listRBDImages(f *framework.Framework, pool string) ([]string, error) {
var imgInfos []string var imgInfos []string
stdout, stdErr, err := execCommandInToolBoxPod(f, stdout, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd ls --format=json %s", rbdOptions(pool)), rookNamespace) "rbd ls --format=json "+rbdOptions(pool), rookNamespace)
if err != nil { if err != nil {
return imgInfos, err return imgInfos, err
} }
@ -744,7 +744,7 @@ type rbdDuImageList struct {
// getRbdDu runs 'rbd du' on the RBD image and returns a rbdDuImage struct with // getRbdDu runs 'rbd du' on the RBD image and returns a rbdDuImage struct with
// the result. // the result.
// //
//nolint:deadcode,unused // required for reclaimspace e2e. //nolint:unused // Unused code will be used in future.
func getRbdDu(f *framework.Framework, pvc *v1.PersistentVolumeClaim) (*rbdDuImage, error) { func getRbdDu(f *framework.Framework, pvc *v1.PersistentVolumeClaim) (*rbdDuImage, error) {
rdil := rbdDuImageList{} rdil := rbdDuImageList{}
@ -778,7 +778,7 @@ func getRbdDu(f *framework.Framework, pvc *v1.PersistentVolumeClaim) (*rbdDuImag
// take up any space anymore. This can be used to verify that an empty, but // take up any space anymore. This can be used to verify that an empty, but
// allocated (with zerofill) extents have been released. // allocated (with zerofill) extents have been released.
// //
//nolint:deadcode,unused // required for reclaimspace e2e. //nolint:unused // Unused code will be used in future.
func sparsifyBackingRBDImage(f *framework.Framework, pvc *v1.PersistentVolumeClaim) error { func sparsifyBackingRBDImage(f *framework.Framework, pvc *v1.PersistentVolumeClaim) error {
imageData, err := getImageInfoFromPVC(pvc.Namespace, pvc.Name, f) imageData, err := getImageInfoFromPVC(pvc.Namespace, pvc.Name, f)
if err != nil { if err != nil {
@ -802,7 +802,7 @@ func deletePool(name string, cephFS bool, f *framework.Framework) error {
// --yes-i-really-mean-it // --yes-i-really-mean-it
// ceph osd pool delete myfs-replicated myfs-replicated // ceph osd pool delete myfs-replicated myfs-replicated
// --yes-i-really-mean-it // --yes-i-really-mean-it
cmds = append(cmds, fmt.Sprintf("ceph fs fail %s", name), cmds = append(cmds, "ceph fs fail "+name,
fmt.Sprintf("ceph fs rm %s --yes-i-really-mean-it", name), fmt.Sprintf("ceph fs rm %s --yes-i-really-mean-it", name),
fmt.Sprintf("ceph osd pool delete %s-metadata %s-metadata --yes-i-really-really-mean-it", name, name), fmt.Sprintf("ceph osd pool delete %s-metadata %s-metadata --yes-i-really-really-mean-it", name, name),
fmt.Sprintf("ceph osd pool delete %s-replicated %s-replicated --yes-i-really-really-mean-it", name, name)) fmt.Sprintf("ceph osd pool delete %s-replicated %s-replicated --yes-i-really-really-mean-it", name, name))
@ -850,7 +850,7 @@ func getPVCImageInfoInPool(f *framework.Framework, pvc *v1.PersistentVolumeClaim
} }
stdOut, stdErr, err := execCommandInToolBoxPod(f, stdOut, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd info %s", imageSpec(pool, imageData.imageName)), rookNamespace) "rbd info "+imageSpec(pool, imageData.imageName), rookNamespace)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -1021,7 +1021,7 @@ func listRBDImagesInTrash(f *framework.Framework, poolName string) ([]trashInfo,
var trashInfos []trashInfo var trashInfos []trashInfo
stdout, stdErr, err := execCommandInToolBoxPod(f, stdout, stdErr, err := execCommandInToolBoxPod(f,
fmt.Sprintf("rbd trash ls --format=json %s", rbdOptions(poolName)), rookNamespace) "rbd trash ls --format=json "+rbdOptions(poolName), rookNamespace)
if err != nil { if err != nil {
return trashInfos, err return trashInfos, err
} }

View File

@ -22,7 +22,7 @@ import (
"strings" "strings"
"time" "time"
. "github.com/onsi/gomega" //nolint:golint // e2e uses Expect() and other Gomega functions . "github.com/onsi/gomega"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -179,7 +179,7 @@ func getDirSizeCheckCmd(dirPath string) string {
} }
func getDeviceSizeCheckCmd(devPath string) string { func getDeviceSizeCheckCmd(devPath string) string {
return fmt.Sprintf("blockdev --getsize64 %s", devPath) return "blockdev --getsize64 " + devPath
} }
func checkAppMntSize(f *framework.Framework, opt *metav1.ListOptions, size, cmd, ns string, t int) error { func checkAppMntSize(f *framework.Framework, opt *metav1.ListOptions, size, cmd, ns string, t int) error {

View File

@ -23,7 +23,7 @@ import (
snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1" snapapi "github.com/kubernetes-csi/external-snapshotter/client/v7/apis/volumesnapshot/v1"
snapclient "github.com/kubernetes-csi/external-snapshotter/client/v7/clientset/versioned/typed/volumesnapshot/v1" snapclient "github.com/kubernetes-csi/external-snapshotter/client/v7/clientset/versioned/typed/volumesnapshot/v1"
. "github.com/onsi/gomega" //nolint:golint // e2e uses Expect() and other Gomega functions . "github.com/onsi/gomega"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
apierrs "k8s.io/apimachinery/pkg/api/errors" apierrs "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"

View File

@ -23,7 +23,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions . "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -234,7 +234,7 @@ var _ = Describe("CephFS Upgrade Testing", func() {
} }
// force an immediate write of all cached data to disk. // force an immediate write of all cached data to disk.
_, stdErr = execCommandInPodAndAllowFail(f, fmt.Sprintf("sync %s", filePath), app.Namespace, &opt) _, stdErr = execCommandInPodAndAllowFail(f, "sync "+filePath, app.Namespace, &opt)
if stdErr != "" { if stdErr != "" {
framework.Failf("failed to sync data to a disk %s", stdErr) framework.Failf("failed to sync data to a disk %s", stdErr)
} }

View File

@ -23,7 +23,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
. "github.com/onsi/ginkgo/v2" //nolint:golint // e2e uses By() and other Ginkgo functions . "github.com/onsi/ginkgo/v2"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -239,13 +239,13 @@ var _ = Describe("RBD Upgrade Testing", func() {
} }
// force an immediate write of all cached data to disk. // force an immediate write of all cached data to disk.
_, stdErr = execCommandInPodAndAllowFail(f, fmt.Sprintf("sync %s", filePath), app.Namespace, &opt) _, stdErr = execCommandInPodAndAllowFail(f, "sync "+filePath, app.Namespace, &opt)
if stdErr != "" { if stdErr != "" {
framework.Failf("failed to sync data to a disk %s", stdErr) framework.Failf("failed to sync data to a disk %s", stdErr)
} }
opt = metav1.ListOptions{ opt = metav1.ListOptions{
LabelSelector: fmt.Sprintf("app=%s", appLabel), LabelSelector: "app=" + appLabel,
} }
framework.Logf("Calculating checksum of %s", filePath) framework.Logf("Calculating checksum of %s", filePath)
checkSum, err = calculateSHA512sum(f, app, filePath, &opt) checkSum, err = calculateSHA512sum(f, app, filePath, &opt)

View File

@ -183,9 +183,9 @@ func validateOmapCount(f *framework.Framework, count int, driver, pool, mode str
{ {
volumeMode: volumesType, volumeMode: volumesType,
driverType: rbdType, driverType: rbdType,
radosLsCmd: fmt.Sprintf("rados ls %s", rbdOptions(pool)), radosLsCmd: "rados ls " + rbdOptions(pool),
radosLsCmdFilter: fmt.Sprintf("rados ls %s | grep -v default | grep -c ^csi.volume.", rbdOptions(pool)), radosLsCmdFilter: fmt.Sprintf("rados ls %s | grep -v default | grep -c ^csi.volume.", rbdOptions(pool)),
radosLsKeysCmd: fmt.Sprintf("rados listomapkeys csi.volumes.default %s", rbdOptions(pool)), radosLsKeysCmd: "rados listomapkeys csi.volumes.default " + rbdOptions(pool),
radosLsKeysCmdFilter: fmt.Sprintf("rados listomapkeys csi.volumes.default %s | wc -l", rbdOptions(pool)), radosLsKeysCmdFilter: fmt.Sprintf("rados listomapkeys csi.volumes.default %s | wc -l", rbdOptions(pool)),
}, },
{ {
@ -201,9 +201,9 @@ func validateOmapCount(f *framework.Framework, count int, driver, pool, mode str
{ {
volumeMode: snapsType, volumeMode: snapsType,
driverType: rbdType, driverType: rbdType,
radosLsCmd: fmt.Sprintf("rados ls %s", rbdOptions(pool)), radosLsCmd: "rados ls " + rbdOptions(pool),
radosLsCmdFilter: fmt.Sprintf("rados ls %s | grep -v default | grep -c ^csi.snap.", rbdOptions(pool)), radosLsCmdFilter: fmt.Sprintf("rados ls %s | grep -v default | grep -c ^csi.snap.", rbdOptions(pool)),
radosLsKeysCmd: fmt.Sprintf("rados listomapkeys csi.snaps.default %s", rbdOptions(pool)), radosLsKeysCmd: "rados listomapkeys csi.snaps.default " + rbdOptions(pool),
radosLsKeysCmdFilter: fmt.Sprintf("rados listomapkeys csi.snaps.default %s | wc -l", rbdOptions(pool)), radosLsKeysCmdFilter: fmt.Sprintf("rados listomapkeys csi.snaps.default %s | wc -l", rbdOptions(pool)),
}, },
} }
@ -716,7 +716,7 @@ func checkDataPersist(pvcPath, appPath string, f *framework.Framework) error {
if err != nil { if err != nil {
return err return err
} }
persistData, stdErr, err := execCommandInPod(f, fmt.Sprintf("cat %s", filePath), app.Namespace, &opt) persistData, stdErr, err := execCommandInPod(f, "cat "+filePath, app.Namespace, &opt)
if err != nil { if err != nil {
return err return err
} }
@ -793,7 +793,7 @@ func checkMountOptions(pvcPath, appPath string, f *framework.Framework, mountFla
LabelSelector: "app=validate-mount-opt", LabelSelector: "app=validate-mount-opt",
} }
cmd := fmt.Sprintf("mount |grep %s", app.Spec.Containers[0].VolumeMounts[0].MountPath) cmd := "mount |grep " + app.Spec.Containers[0].VolumeMounts[0].MountPath
data, stdErr, err := execCommandInPod(f, cmd, app.Namespace, &opt) data, stdErr, err := execCommandInPod(f, cmd, app.Namespace, &opt)
if err != nil { if err != nil {
return err return err
@ -1545,7 +1545,7 @@ func validateController(
// If fetching the ServerVersion of the Kubernetes cluster fails, the calling // If fetching the ServerVersion of the Kubernetes cluster fails, the calling
// test case is marked as `FAILED` and gets aborted. // test case is marked as `FAILED` and gets aborted.
// //
//nolint:deadcode,unused // Unused code will be used in future. //nolint:unused // Unused code will be used in future.
func k8sVersionGreaterEquals(c kubernetes.Interface, major, minor int) bool { func k8sVersionGreaterEquals(c kubernetes.Interface, major, minor int) bool {
v, err := c.Discovery().ServerVersion() v, err := c.Discovery().ServerVersion()
if err != nil { if err != nil {
@ -1555,8 +1555,8 @@ func k8sVersionGreaterEquals(c kubernetes.Interface, major, minor int) bool {
// return value. // return value.
} }
maj := fmt.Sprintf("%d", major) maj := strconv.Itoa(major)
min := fmt.Sprintf("%d", minor) min := strconv.Itoa(minor)
return (v.Major > maj) || (v.Major == maj && v.Minor >= min) return (v.Major > maj) || (v.Major == maj && v.Minor >= min)
} }

12
go.mod
View File

@ -27,9 +27,9 @@ require (
github.com/pkg/xattr v0.4.9 github.com/pkg/xattr v0.4.9
github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_golang v1.18.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.22.0
golang.org/x/net v0.22.0 golang.org/x/net v0.24.0
golang.org/x/sys v0.18.0 golang.org/x/sys v0.19.0
google.golang.org/grpc v1.62.1 google.golang.org/grpc v1.62.1
google.golang.org/protobuf v1.33.0 google.golang.org/protobuf v1.33.0
// //
@ -44,7 +44,7 @@ require (
k8s.io/mount-utils v0.29.3 k8s.io/mount-utils v0.29.3
k8s.io/pod-security-admission v0.29.3 k8s.io/pod-security-admission v0.29.3
k8s.io/utils v0.0.0-20230726121419-3b25d923346b k8s.io/utils v0.0.0-20230726121419-3b25d923346b
sigs.k8s.io/controller-runtime v0.17.2 sigs.k8s.io/controller-runtime v0.17.3
) )
require ( require (
@ -163,7 +163,7 @@ require (
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/oauth2 v0.16.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect
golang.org/x/sync v0.6.0 // indirect golang.org/x/sync v0.6.0 // indirect
golang.org/x/term v0.18.0 // indirect golang.org/x/term v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.16.1 // indirect golang.org/x/tools v0.16.1 // indirect
@ -176,7 +176,7 @@ require (
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apiextensions-apiserver v0.29.0 // indirect k8s.io/apiextensions-apiserver v0.29.2 // indirect
k8s.io/apiserver v0.29.3 // indirect k8s.io/apiserver v0.29.3 // indirect
k8s.io/component-base v0.29.3 // indirect k8s.io/component-base v0.29.3 // indirect
k8s.io/component-helpers v0.29.3 // indirect k8s.io/component-helpers v0.29.3 // indirect

20
go.sum
View File

@ -1761,8 +1761,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1908,8 +1908,8 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -2078,8 +2078,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -2098,8 +2098,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -2702,8 +2702,8 @@ rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0 h1:TgtAeesdhpm2SGwkQasmbeqDo8th5wOBA5h/AjTKA4I= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0 h1:TgtAeesdhpm2SGwkQasmbeqDo8th5wOBA5h/AjTKA4I=
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0/go.mod h1:VHVDI/KrK4fjnV61bE2g3sA7tiETLn8sooImelsCx3Y= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.28.0/go.mod h1:VHVDI/KrK4fjnV61bE2g3sA7tiETLn8sooImelsCx3Y=
sigs.k8s.io/controller-runtime v0.2.2/go.mod h1:9dyohw3ZtoXQuV1e766PHUn+cmrRCIcBh6XIMFNMZ+I= sigs.k8s.io/controller-runtime v0.2.2/go.mod h1:9dyohw3ZtoXQuV1e766PHUn+cmrRCIcBh6XIMFNMZ+I=
sigs.k8s.io/controller-runtime v0.17.2 h1:FwHwD1CTUemg0pW2otk7/U5/i5m2ymzvOXdbeGOUvw0= sigs.k8s.io/controller-runtime v0.17.3 h1:65QmN7r3FWgTxDMz9fvGnO1kbf2nu+acg9p2R9oYYYk=
sigs.k8s.io/controller-runtime v0.17.2/go.mod h1:+MngTvIQQQhfXtwfdGw/UOQ/aIaqsYywfCINOtwMO/s= sigs.k8s.io/controller-runtime v0.17.3/go.mod h1:N0jpP5Lo7lMTF9aL56Z/B2oWBJjey6StQM0jRbKQXtY=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E=

View File

@ -183,13 +183,13 @@ func (cs *ControllerServer) checkContentSource(
req *csi.CreateVolumeRequest, req *csi.CreateVolumeRequest,
cr *util.Credentials, cr *util.Credentials,
) (*store.VolumeOptions, *store.VolumeIdentifier, *store.SnapshotIdentifier, error) { ) (*store.VolumeOptions, *store.VolumeIdentifier, *store.SnapshotIdentifier, error) {
if req.VolumeContentSource == nil { if req.GetVolumeContentSource() == nil {
return nil, nil, nil, nil return nil, nil, nil, nil
} }
volumeSource := req.VolumeContentSource volumeSource := req.GetVolumeContentSource()
switch volumeSource.Type.(type) { switch volumeSource.GetType().(type) {
case *csi.VolumeContentSource_Snapshot: case *csi.VolumeContentSource_Snapshot:
snapshotID := req.VolumeContentSource.GetSnapshot().GetSnapshotId() snapshotID := req.GetVolumeContentSource().GetSnapshot().GetSnapshotId()
volOpt, _, sid, err := store.NewSnapshotOptionsFromID(ctx, snapshotID, cr, volOpt, _, sid, err := store.NewSnapshotOptionsFromID(ctx, snapshotID, cr,
req.GetSecrets(), cs.ClusterName, cs.SetMetadata) req.GetSecrets(), cs.ClusterName, cs.SetMetadata)
if err != nil { if err != nil {
@ -203,9 +203,9 @@ func (cs *ControllerServer) checkContentSource(
return volOpt, nil, sid, nil return volOpt, nil, sid, nil
case *csi.VolumeContentSource_Volume: case *csi.VolumeContentSource_Volume:
// Find the volume using the provided VolumeID // Find the volume using the provided VolumeID
volID := req.VolumeContentSource.GetVolume().GetVolumeId() volID := req.GetVolumeContentSource().GetVolume().GetVolumeId()
parentVol, pvID, err := store.NewVolumeOptionsFromVolID(ctx, parentVol, pvID, err := store.NewVolumeOptionsFromVolID(ctx,
volID, nil, req.Secrets, cs.ClusterName, cs.SetMetadata) volID, nil, req.GetSecrets(), cs.ClusterName, cs.SetMetadata)
if err != nil { if err != nil {
if !errors.Is(err, cerrors.ErrVolumeNotFound) { if !errors.Is(err, cerrors.ErrVolumeNotFound) {
return nil, nil, nil, status.Error(codes.NotFound, err.Error()) return nil, nil, nil, status.Error(codes.NotFound, err.Error())
@ -342,7 +342,7 @@ func (cs *ControllerServer) CreateVolume(
// As we are trying to create RWX volume from backing snapshot, we need to // As we are trying to create RWX volume from backing snapshot, we need to
// retrieve the snapshot details from the backing snapshot and create a // retrieve the snapshot details from the backing snapshot and create a
// subvolume clone from the snapshot. // subvolume clone from the snapshot.
if parentVol != nil && parentVol.BackingSnapshot && !store.IsVolumeCreateRO(req.VolumeCapabilities) { if parentVol != nil && parentVol.BackingSnapshot && !store.IsVolumeCreateRO(req.GetVolumeCapabilities()) {
// unset pvID as we dont have real subvolume for the parent volumeID as its a backing snapshot // unset pvID as we dont have real subvolume for the parent volumeID as its a backing snapshot
pvID = nil pvID = nil
parentVol, _, sID, err = store.NewSnapshotOptionsFromID(ctx, parentVol.BackingSnapshotID, cr, parentVol, _, sID, err = store.NewSnapshotOptionsFromID(ctx, parentVol.BackingSnapshotID, cr,
@ -674,7 +674,7 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
req *csi.ValidateVolumeCapabilitiesRequest, req *csi.ValidateVolumeCapabilitiesRequest,
) (*csi.ValidateVolumeCapabilitiesResponse, error) { ) (*csi.ValidateVolumeCapabilitiesResponse, error) {
// Cephfs doesn't support Block volume // Cephfs doesn't support Block volume
for _, capability := range req.VolumeCapabilities { for _, capability := range req.GetVolumeCapabilities() {
if capability.GetBlock() != nil { if capability.GetBlock() != nil {
return &csi.ValidateVolumeCapabilitiesResponse{Message: ""}, nil return &csi.ValidateVolumeCapabilitiesResponse{Message: ""}, nil
} }
@ -682,7 +682,7 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
return &csi.ValidateVolumeCapabilitiesResponse{ return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{ Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{
VolumeCapabilities: req.VolumeCapabilities, VolumeCapabilities: req.GetVolumeCapabilities(),
}, },
}, nil }, nil
} }
@ -970,10 +970,10 @@ func (cs *ControllerServer) validateSnapshotReq(ctx context.Context, req *csi.Cr
} }
// Check sanity of request Snapshot Name, Source Volume Id // Check sanity of request Snapshot Name, Source Volume Id
if req.Name == "" { if req.GetName() == "" {
return status.Error(codes.NotFound, "snapshot Name cannot be empty") return status.Error(codes.NotFound, "snapshot Name cannot be empty")
} }
if req.SourceVolumeId == "" { if req.GetSourceVolumeId() == "" {
return status.Error(codes.NotFound, "source Volume ID cannot be empty") return status.Error(codes.NotFound, "source Volume ID cannot be empty")
} }

View File

@ -17,13 +17,12 @@ limitations under the License.
package core package core
import ( import (
"errors"
"testing" "testing"
cerrors "github.com/ceph/ceph-csi/internal/cephfs/errors" cerrors "github.com/ceph/ceph-csi/internal/cephfs/errors"
fsa "github.com/ceph/go-ceph/cephfs/admin" fsa "github.com/ceph/go-ceph/cephfs/admin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestCloneStateToError(t *testing.T) { func TestCloneStateToError(t *testing.T) {
@ -36,6 +35,6 @@ func TestCloneStateToError(t *testing.T) {
errorState[cephFSCloneState{fsa.CloneFailed, "", ""}] = cerrors.ErrCloneFailed errorState[cephFSCloneState{fsa.CloneFailed, "", ""}] = cerrors.ErrCloneFailed
for state, err := range errorState { for state, err := range errorState {
assert.True(t, errors.Is(state.ToError(), err)) require.ErrorIs(t, state.ToError(), err)
} }
} }

View File

@ -29,11 +29,11 @@ import (
// that interacts with CephFS filesystem API's. // that interacts with CephFS filesystem API's.
type FileSystem interface { type FileSystem interface {
// GetFscID returns the ID of the filesystem with the given name. // GetFscID returns the ID of the filesystem with the given name.
GetFscID(context.Context, string) (int64, error) GetFscID(ctx context.Context, fsName string) (int64, error)
// GetMetadataPool returns the metadata pool name of the filesystem with the given name. // GetMetadataPool returns the metadata pool name of the filesystem with the given name.
GetMetadataPool(context.Context, string) (string, error) GetMetadataPool(ctx context.Context, fsName string) (string, error)
// GetFsName returns the name of the filesystem with the given ID. // GetFsName returns the name of the filesystem with the given ID.
GetFsName(context.Context, int64) (string, error) GetFsName(ctx context.Context, fsID int64) (string, error)
} }
// fileSystem is the implementation of FileSystem interface. // fileSystem is the implementation of FileSystem interface.

View File

@ -20,7 +20,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ceph/ceph-csi/internal/util" "github.com/ceph/ceph-csi/internal/util"
@ -44,7 +43,7 @@ func TestSetupCSIAddonsServer(t *testing.T) {
// verify the socket file has been created // verify the socket file has been created
_, err = os.Stat(tmpDir + "/csi-addons.sock") _, err = os.Stat(tmpDir + "/csi-addons.sock")
assert.NoError(t, err) require.NoError(t, err)
// stop the gRPC server // stop the gRPC server
drv.cas.Stop() drv.cas.Stop()

View File

@ -180,7 +180,7 @@ func (cs *ControllerServer) CreateVolumeGroupSnapshot(
for _, r := range *resp { for _, r := range *resp {
r.Snapshot.GroupSnapshotId = vgs.VolumeGroupSnapshotID r.Snapshot.GroupSnapshotId = vgs.VolumeGroupSnapshotID
response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.Snapshot) response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.GetSnapshot())
} }
return response, nil return response, nil
@ -293,7 +293,7 @@ func (cs *ControllerServer) releaseQuiesceAndGetVolumeGroupSnapshotResponse(
for _, r := range snapshotResponses { for _, r := range snapshotResponses {
r.Snapshot.GroupSnapshotId = vgs.VolumeGroupSnapshotID r.Snapshot.GroupSnapshotId = vgs.VolumeGroupSnapshotID
response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.Snapshot) response.GroupSnapshot.Snapshots = append(response.GroupSnapshot.Snapshots, r.GetSnapshot())
} }
return response, nil return response, nil
@ -703,7 +703,7 @@ func (cs *ControllerServer) DeleteVolumeGroupSnapshot(ctx context.Context,
return nil, err return nil, err
} }
groupSnapshotID := req.GroupSnapshotId groupSnapshotID := req.GetGroupSnapshotId()
// Existence and conflict checks // Existence and conflict checks
if acquired := cs.VolumeGroupLocks.TryAcquire(groupSnapshotID); !acquired { if acquired := cs.VolumeGroupLocks.TryAcquire(groupSnapshotID); !acquired {
log.ErrorLog(ctx, util.VolumeOperationAlreadyExistsFmt, groupSnapshotID) log.ErrorLog(ctx, util.VolumeOperationAlreadyExistsFmt, groupSnapshotID)
@ -718,7 +718,7 @@ func (cs *ControllerServer) DeleteVolumeGroupSnapshot(ctx context.Context,
} }
defer cr.DeleteCredentials() defer cr.DeleteCredentials()
vgo, vgsi, err := store.NewVolumeGroupOptionsFromID(ctx, req.GroupSnapshotId, cr) vgo, vgsi, err := store.NewVolumeGroupOptionsFromID(ctx, req.GetGroupSnapshotId(), cr)
if err != nil { if err != nil {
log.ErrorLog(ctx, "failed to get volume group options: %v", err) log.ErrorLog(ctx, "failed to get volume group options: %v", err)
err = extractDeleteVolumeGroupError(err) err = extractDeleteVolumeGroupError(err)

View File

@ -81,7 +81,7 @@ func (m *kernelMounter) mountKernel(
optionsStr := fmt.Sprintf("name=%s,secretfile=%s", cr.ID, cr.KeyFile) optionsStr := fmt.Sprintf("name=%s,secretfile=%s", cr.ID, cr.KeyFile)
mdsNamespace := "" mdsNamespace := ""
if volOptions.FsName != "" { if volOptions.FsName != "" {
mdsNamespace = fmt.Sprintf("mds_namespace=%s", volOptions.FsName) mdsNamespace = "mds_namespace=" + volOptions.FsName
} }
optionsStr = util.MountOptionsAdd(optionsStr, mdsNamespace, volOptions.KernelMountOptions, netDev) optionsStr = util.MountOptionsAdd(optionsStr, mdsNamespace, volOptions.KernelMountOptions, netDev)

View File

@ -19,7 +19,7 @@ package mounter
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestFilesystemSupported(t *testing.T) { func TestFilesystemSupported(t *testing.T) {
@ -31,8 +31,8 @@ func TestFilesystemSupported(t *testing.T) {
// "proc" is always a supported filesystem, we detect supported // "proc" is always a supported filesystem, we detect supported
// filesystems by reading from it // filesystems by reading from it
assert.True(t, filesystemSupported("proc")) require.True(t, filesystemSupported("proc"))
// "nonefs" is a made-up name, and does not exist // "nonefs" is a made-up name, and does not exist
assert.False(t, filesystemSupported("nonefs")) require.False(t, filesystemSupported("nonefs"))
} }

View File

@ -110,8 +110,7 @@ func (ns *NodeServer) getVolumeOptions(
func validateSnapshotBackedVolCapability(volCap *csi.VolumeCapability) error { func validateSnapshotBackedVolCapability(volCap *csi.VolumeCapability) error {
// Snapshot-backed volumes may be used with read-only volume access modes only. // Snapshot-backed volumes may be used with read-only volume access modes only.
mode := volCap.AccessMode.Mode mode := volCap.GetAccessMode().GetMode()
if mode != csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY && if mode != csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY &&
mode != csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY { mode != csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
return status.Error(codes.InvalidArgument, return status.Error(codes.InvalidArgument,
@ -352,7 +351,6 @@ func (ns *NodeServer) mount(
true, true,
[]string{"bind", "_netdev"}, []string{"bind", "_netdev"},
) )
if err != nil { if err != nil {
log.ErrorLog(ctx, log.ErrorLog(ctx,
"failed to bind mount snapshot root %s: %v", absoluteSnapshotRoot, err) "failed to bind mount snapshot root %s: %v", absoluteSnapshotRoot, err)
@ -813,9 +811,9 @@ func (ns *NodeServer) setMountOptions(
} }
const readOnly = "ro" const readOnly = "ro"
mode := volCap.GetAccessMode().GetMode()
if volCap.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY || if mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
volCap.AccessMode.Mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY { mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
switch mnt.(type) { switch mnt.(type) {
case *mounter.FuseMounter: case *mounter.FuseMounter:
if !csicommon.MountOptionContains(strings.Split(volOptions.FuseMountOptions, ","), readOnly) { if !csicommon.MountOptionContains(strings.Split(volOptions.FuseMountOptions, ","), readOnly) {

View File

@ -18,7 +18,6 @@ package store
import ( import (
"context" "context"
"fmt"
fsutil "github.com/ceph/ceph-csi/internal/cephfs/util" fsutil "github.com/ceph/ceph-csi/internal/cephfs/util"
"github.com/ceph/ceph-csi/internal/util/log" "github.com/ceph/ceph-csi/internal/util/log"
@ -28,7 +27,7 @@ import (
) )
func fmtBackingSnapshotReftrackerName(backingSnapID string) string { func fmtBackingSnapshotReftrackerName(backingSnapID string) string {
return fmt.Sprintf("rt-backingsnapshot-%s", backingSnapID) return "rt-backingsnapshot-" + backingSnapID
} }
func AddSnapshotBackedVolumeRef( func AddSnapshotBackedVolumeRef(

View File

@ -168,7 +168,7 @@ func extractMounter(dest *string, options map[string]string) error {
func GetClusterInformation(options map[string]string) (*cephcsi.ClusterInfo, error) { func GetClusterInformation(options map[string]string) (*cephcsi.ClusterInfo, error) {
clusterID, ok := options["clusterID"] clusterID, ok := options["clusterID"]
if !ok { if !ok {
err := fmt.Errorf("clusterID must be set") err := errors.New("clusterID must be set")
return nil, err return nil, err
} }
@ -344,15 +344,15 @@ func NewVolumeOptions(
// IsShallowVolumeSupported returns true only for ReadOnly volume requests // IsShallowVolumeSupported returns true only for ReadOnly volume requests
// with datasource as snapshot. // with datasource as snapshot.
func IsShallowVolumeSupported(req *csi.CreateVolumeRequest) bool { func IsShallowVolumeSupported(req *csi.CreateVolumeRequest) bool {
isRO := IsVolumeCreateRO(req.VolumeCapabilities) isRO := IsVolumeCreateRO(req.GetVolumeCapabilities())
return isRO && (req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetSnapshot() != nil) return isRO && (req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetSnapshot() != nil)
} }
func IsVolumeCreateRO(caps []*csi.VolumeCapability) bool { func IsVolumeCreateRO(caps []*csi.VolumeCapability) bool {
for _, cap := range caps { for _, cap := range caps {
if cap.AccessMode != nil { if cap.GetAccessMode() != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY, case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY,
csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY: csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY:
return true return true
@ -612,7 +612,7 @@ func NewVolumeOptionsFromMonitorList(
// check if there are mon values in secret and if so override option retrieved monitors from // check if there are mon values in secret and if so override option retrieved monitors from
// monitors in the secret // monitors in the secret
mon, err := util.GetMonValFromSecret(secrets) mon, err := util.GetMonValFromSecret(secrets)
if err == nil && len(mon) > 0 { if err == nil && mon != "" {
opts.Monitors = mon opts.Monitors = mon
} }

View File

@ -54,11 +54,11 @@ func (cs *ControllerServer) validateCreateVolumeRequest(req *csi.CreateVolumeReq
return err return err
} }
if req.VolumeContentSource != nil { if req.GetVolumeContentSource() != nil {
volumeSource := req.VolumeContentSource volumeSource := req.GetVolumeContentSource()
switch volumeSource.Type.(type) { switch volumeSource.GetType().(type) {
case *csi.VolumeContentSource_Snapshot: case *csi.VolumeContentSource_Snapshot:
snapshot := req.VolumeContentSource.GetSnapshot() snapshot := req.GetVolumeContentSource().GetSnapshot()
// CSI spec requires returning NOT_FOUND when the volumeSource is missing/incorrect. // CSI spec requires returning NOT_FOUND when the volumeSource is missing/incorrect.
if snapshot == nil { if snapshot == nil {
return status.Error(codes.NotFound, "volume Snapshot cannot be empty") return status.Error(codes.NotFound, "volume Snapshot cannot be empty")
@ -68,7 +68,7 @@ func (cs *ControllerServer) validateCreateVolumeRequest(req *csi.CreateVolumeReq
} }
case *csi.VolumeContentSource_Volume: case *csi.VolumeContentSource_Volume:
// CSI spec requires returning NOT_FOUND when the volumeSource is missing/incorrect. // CSI spec requires returning NOT_FOUND when the volumeSource is missing/incorrect.
vol := req.VolumeContentSource.GetVolume() vol := req.GetVolumeContentSource().GetVolume()
if vol == nil { if vol == nil {
return status.Error(codes.NotFound, "volume cannot be empty") return status.Error(codes.NotFound, "volume cannot be empty")
} }

View File

@ -31,7 +31,7 @@ import (
// The New controllers which gets added, as to implement Add function to get // The New controllers which gets added, as to implement Add function to get
// started by the manager. // started by the manager.
type Manager interface { type Manager interface {
Add(manager.Manager, Config) error Add(mgr manager.Manager, cfg Config) error
} }
// Config holds the drivername and namespace name. // Config holds the drivername and namespace name.

View File

@ -66,7 +66,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
ctx context.Context, ctx context.Context,
req *fence.FenceClusterNetworkRequest, req *fence.FenceClusterNetworkRequest,
) (*fence.FenceClusterNetworkResponse, error) { ) (*fence.FenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters) err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }
@ -77,7 +77,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
} }
defer cr.DeleteCredentials() defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters()) nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }
@ -95,7 +95,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
ctx context.Context, ctx context.Context,
req *fence.UnfenceClusterNetworkRequest, req *fence.UnfenceClusterNetworkRequest,
) (*fence.UnfenceClusterNetworkResponse, error) { ) (*fence.UnfenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters) err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }
@ -106,7 +106,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
} }
defer cr.DeleteCredentials() defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters()) nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }

View File

@ -21,7 +21,7 @@ import (
"testing" "testing"
"github.com/csi-addons/spec/lib/go/fence" "github.com/csi-addons/spec/lib/go/fence"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
// TestFenceClusterNetwork is a minimal test for the FenceClusterNetwork() // TestFenceClusterNetwork is a minimal test for the FenceClusterNetwork()
@ -39,7 +39,7 @@ func TestFenceClusterNetwork(t *testing.T) {
} }
_, err := controller.FenceClusterNetwork(context.TODO(), req) _, err := controller.FenceClusterNetwork(context.TODO(), req)
assert.Error(t, err) require.Error(t, err)
} }
// TestUnfenceClusterNetwork is a minimal test for the UnfenceClusterNetwork() // TestUnfenceClusterNetwork is a minimal test for the UnfenceClusterNetwork()
@ -55,5 +55,5 @@ func TestUnfenceClusterNetwork(t *testing.T) {
Cidrs: nil, Cidrs: nil,
} }
_, err := controller.UnfenceClusterNetwork(context.TODO(), req) _, err := controller.UnfenceClusterNetwork(context.TODO(), req)
assert.Error(t, err) require.Error(t, err)
} }

View File

@ -348,7 +348,7 @@ type Cidrs []*fence.CIDR
func GetCIDR(cidrs Cidrs) ([]string, error) { func GetCIDR(cidrs Cidrs) ([]string, error) {
var cidrList []string var cidrList []string
for _, cidr := range cidrs { for _, cidr := range cidrs {
cidrList = append(cidrList, cidr.Cidr) cidrList = append(cidrList, cidr.GetCidr())
} }
if len(cidrList) < 1 { if len(cidrList) < 1 {
return nil, errors.New("the CIDR cannot be empty") return nil, errors.New("the CIDR cannot be empty")

View File

@ -19,7 +19,7 @@ package networkfence
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestGetIPRange(t *testing.T) { func TestGetIPRange(t *testing.T) {
@ -47,10 +47,10 @@ func TestGetIPRange(t *testing.T) {
t.Run(ts.cidr, func(t *testing.T) { t.Run(ts.cidr, func(t *testing.T) {
t.Parallel() t.Parallel()
got, err := getIPRange(ts.cidr) got, err := getIPRange(ts.cidr)
assert.NoError(t, err) require.NoError(t, err)
// validate if number of IPs in the range is same as expected, if not, fail. // validate if number of IPs in the range is same as expected, if not, fail.
assert.ElementsMatch(t, ts.expectedIPs, got) require.ElementsMatch(t, ts.expectedIPs, got)
}) })
} }
} }

View File

@ -62,7 +62,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
ctx context.Context, ctx context.Context,
req *fence.FenceClusterNetworkRequest, req *fence.FenceClusterNetworkRequest,
) (*fence.FenceClusterNetworkResponse, error) { ) (*fence.FenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters) err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }
@ -73,7 +73,7 @@ func (fcs *FenceControllerServer) FenceClusterNetwork(
} }
defer cr.DeleteCredentials() defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters()) nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }
@ -91,7 +91,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
ctx context.Context, ctx context.Context,
req *fence.UnfenceClusterNetworkRequest, req *fence.UnfenceClusterNetworkRequest,
) (*fence.UnfenceClusterNetworkResponse, error) { ) (*fence.UnfenceClusterNetworkResponse, error) {
err := validateNetworkFenceReq(req.GetCidrs(), req.Parameters) err := validateNetworkFenceReq(req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }
@ -102,7 +102,7 @@ func (fcs *FenceControllerServer) UnfenceClusterNetwork(
} }
defer cr.DeleteCredentials() defer cr.DeleteCredentials()
nwFence, err := nf.NewNetworkFence(ctx, cr, req.Cidrs, req.GetParameters()) nwFence, err := nf.NewNetworkFence(ctx, cr, req.GetCidrs(), req.GetParameters())
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, status.Error(codes.Internal, err.Error())
} }

View File

@ -17,9 +17,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/csi-addons/spec/lib/go/fence" "github.com/csi-addons/spec/lib/go/fence"
"github.com/stretchr/testify/require"
) )
// TestFenceClusterNetwork is a minimal test for the FenceClusterNetwork() // TestFenceClusterNetwork is a minimal test for the FenceClusterNetwork()
@ -37,7 +36,7 @@ func TestFenceClusterNetwork(t *testing.T) {
} }
_, err := controller.FenceClusterNetwork(context.TODO(), req) _, err := controller.FenceClusterNetwork(context.TODO(), req)
assert.Error(t, err) require.Error(t, err)
} }
// TestUnfenceClusterNetwork is a minimal test for the UnfenceClusterNetwork() // TestUnfenceClusterNetwork is a minimal test for the UnfenceClusterNetwork()
@ -53,5 +52,5 @@ func TestUnfenceClusterNetwork(t *testing.T) {
Cidrs: nil, Cidrs: nil,
} }
_, err := controller.UnfenceClusterNetwork(context.TODO(), req) _, err := controller.UnfenceClusterNetwork(context.TODO(), req)
assert.Error(t, err) require.Error(t, err)
} }

View File

@ -20,9 +20,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/stretchr/testify/assert"
rs "github.com/csi-addons/spec/lib/go/reclaimspace" rs "github.com/csi-addons/spec/lib/go/reclaimspace"
"github.com/stretchr/testify/require"
) )
// TestControllerReclaimSpace is a minimal test for the // TestControllerReclaimSpace is a minimal test for the
@ -39,7 +38,7 @@ func TestControllerReclaimSpace(t *testing.T) {
} }
_, err := controller.ControllerReclaimSpace(context.TODO(), req) _, err := controller.ControllerReclaimSpace(context.TODO(), req)
assert.Error(t, err) require.Error(t, err)
} }
// TestNodeReclaimSpace is a minimal test for the NodeReclaimSpace() procedure. // TestNodeReclaimSpace is a minimal test for the NodeReclaimSpace() procedure.
@ -58,5 +57,5 @@ func TestNodeReclaimSpace(t *testing.T) {
} }
_, err := node.NodeReclaimSpace(context.TODO(), req) _, err := node.NodeReclaimSpace(context.TODO(), req)
assert.Error(t, err) require.Error(t, err)
} }

View File

@ -709,7 +709,7 @@ func (rs *ReplicationServer) ResyncVolume(ctx context.Context,
return nil, status.Errorf(codes.Internal, "failed to parse image creation time: %s", sErr.Error()) return nil, status.Errorf(codes.Internal, "failed to parse image creation time: %s", sErr.Error())
} }
log.DebugLog(ctx, "image %s, savedImageTime=%v, currentImageTime=%v", rbdVol, st, creationTime.AsTime()) log.DebugLog(ctx, "image %s, savedImageTime=%v, currentImageTime=%v", rbdVol, st, creationTime.AsTime())
if req.Force && st.Equal(creationTime.AsTime()) { if req.GetForce() && st.Equal(creationTime.AsTime()) {
err = rbdVol.ResyncVol(localStatus) err = rbdVol.ResyncVol(localStatus)
if err != nil { if err != nil {
return nil, getGRPCError(err) return nil, getGRPCError(err)
@ -738,7 +738,7 @@ func (rs *ReplicationServer) ResyncVolume(ctx context.Context,
// timestampToString converts the time.Time object to string. // timestampToString converts the time.Time object to string.
func timestampToString(st *timestamppb.Timestamp) string { func timestampToString(st *timestamppb.Timestamp) string {
return fmt.Sprintf("seconds:%d nanos:%d", st.Seconds, st.Nanos) return fmt.Sprintf("seconds:%d nanos:%d", st.GetSeconds(), st.GetNanos())
} }
// timestampFromString parses the timestamp string and returns the time.Time // timestampFromString parses the timestamp string and returns the time.Time
@ -989,7 +989,7 @@ func checkVolumeResyncStatus(ctx context.Context, localStatus librbd.SiteMirrorI
if err != nil { if err != nil {
return fmt.Errorf("failed to get last sync info: %w", err) return fmt.Errorf("failed to get last sync info: %w", err)
} }
if resp.LastSyncTime == nil { if resp.GetLastSyncTime() == nil {
return errors.New("last sync time is nil") return errors.New("last sync time is nil")
} }

View File

@ -30,7 +30,7 @@ import (
librbd "github.com/ceph/go-ceph/rbd" librbd "github.com/ceph/go-ceph/rbd"
"github.com/ceph/go-ceph/rbd/admin" "github.com/ceph/go-ceph/rbd/admin"
"github.com/csi-addons/spec/lib/go/replication" "github.com/csi-addons/spec/lib/go/replication"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
@ -511,19 +511,29 @@ func TestValidateLastSyncInfo(t *testing.T) {
tt.expectedErr, err) tt.expectedErr, err)
} }
if teststruct != nil { if teststruct != nil {
if teststruct.LastSyncTime.GetSeconds() != tt.info.LastSyncTime.GetSeconds() { if teststruct.GetLastSyncTime().GetSeconds() != tt.info.GetLastSyncTime().GetSeconds() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncTime, tt.info.LastSyncTime) t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.name,
teststruct.GetLastSyncTime(),
tt.info.GetLastSyncTime())
} }
if tt.info.LastSyncDuration == nil && teststruct.LastSyncDuration != nil { if tt.info.GetLastSyncDuration() == nil && teststruct.GetLastSyncDuration() != nil {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncDuration, t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.info.LastSyncDuration) tt.name,
teststruct.GetLastSyncDuration(),
tt.info.GetLastSyncDuration())
} }
if teststruct.LastSyncDuration.GetSeconds() != tt.info.LastSyncDuration.GetSeconds() { if teststruct.GetLastSyncDuration().GetSeconds() != tt.info.GetLastSyncDuration().GetSeconds() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncDuration, t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.info.LastSyncDuration) tt.name,
teststruct.GetLastSyncDuration(),
tt.info.GetLastSyncDuration())
} }
if teststruct.LastSyncBytes != tt.info.LastSyncBytes { if teststruct.GetLastSyncBytes() != tt.info.GetLastSyncBytes() {
t.Errorf("name: %v, getLastSyncInfo() %v, expected %v", tt.name, teststruct.LastSyncBytes, tt.info.LastSyncBytes) t.Errorf("name: %v, getLastSyncInfo() %v, expected %v",
tt.name,
teststruct.GetLastSyncBytes(),
tt.info.GetLastSyncBytes())
} }
} }
}) })
@ -594,7 +604,7 @@ func TestGetGRPCError(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
result := getGRPCError(tt.err) result := getGRPCError(tt.err)
assert.Equal(t, tt.expectedErr, result) require.Equal(t, tt.expectedErr, result)
}) })
} }
} }

View File

@ -19,7 +19,6 @@ package server
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -39,7 +38,7 @@ func TestNewCSIAddonsServer(t *testing.T) {
cas, err := NewCSIAddonsServer("") cas, err := NewCSIAddonsServer("")
require.Error(t, err) require.Error(t, err)
assert.Nil(t, cas) require.Nil(t, cas)
}) })
t.Run("no UDS endpoint", func(t *testing.T) { t.Run("no UDS endpoint", func(t *testing.T) {
@ -47,6 +46,6 @@ func TestNewCSIAddonsServer(t *testing.T) {
cas, err := NewCSIAddonsServer("endpoint at /tmp/...") cas, err := NewCSIAddonsServer("endpoint at /tmp/...")
require.Error(t, err) require.Error(t, err)
assert.Nil(t, cas) require.Nil(t, cas)
}) })
} }

View File

@ -70,8 +70,6 @@ func NewCSIDriver(name, v, nodeID string) *CSIDriver {
// ValidateControllerServiceRequest validates the controller // ValidateControllerServiceRequest validates the controller
// plugin capabilities. // plugin capabilities.
//
//nolint:interfacer // c can be of type fmt.Stringer, but that does not make the API clearer
func (d *CSIDriver) ValidateControllerServiceRequest(c csi.ControllerServiceCapability_RPC_Type) error { func (d *CSIDriver) ValidateControllerServiceRequest(c csi.ControllerServiceCapability_RPC_Type) error {
if c == csi.ControllerServiceCapability_RPC_UNKNOWN { if c == csi.ControllerServiceCapability_RPC_UNKNOWN {
return nil return nil
@ -133,8 +131,6 @@ func (d *CSIDriver) AddGroupControllerServiceCapabilities(cl []csi.GroupControll
// ValidateGroupControllerServiceRequest validates the group controller // ValidateGroupControllerServiceRequest validates the group controller
// plugin capabilities. // plugin capabilities.
//
//nolint:interfacer // c can be of type fmt.Stringer, but that does not make the API clearer
func (d *CSIDriver) ValidateGroupControllerServiceRequest(c csi.GroupControllerServiceCapability_RPC_Type) error { func (d *CSIDriver) ValidateGroupControllerServiceRequest(c csi.GroupControllerServiceCapability_RPC_Type) error {
if c == csi.GroupControllerServiceCapability_RPC_UNKNOWN { if c == csi.GroupControllerServiceCapability_RPC_UNKNOWN {
return nil return nil

View File

@ -86,7 +86,7 @@ func ConstructMountOptions(mountOptions []string, volCap *csi.VolumeCapability)
return false return false
} }
for _, f := range m.MountFlags { for _, f := range m.GetMountFlags() {
if !hasOption(mountOptions, f) { if !hasOption(mountOptions, f) {
mountOptions = append(mountOptions, f) mountOptions = append(mountOptions, f)
} }

View File

@ -121,52 +121,52 @@ func getReqID(req interface{}) string {
reqID := "" reqID := ""
switch r := req.(type) { switch r := req.(type) {
case *csi.CreateVolumeRequest: case *csi.CreateVolumeRequest:
reqID = r.Name reqID = r.GetName()
case *csi.DeleteVolumeRequest: case *csi.DeleteVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.CreateSnapshotRequest: case *csi.CreateSnapshotRequest:
reqID = r.Name reqID = r.GetName()
case *csi.DeleteSnapshotRequest: case *csi.DeleteSnapshotRequest:
reqID = r.SnapshotId reqID = r.GetSnapshotId()
case *csi.ControllerExpandVolumeRequest: case *csi.ControllerExpandVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.NodeStageVolumeRequest: case *csi.NodeStageVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.NodeUnstageVolumeRequest: case *csi.NodeUnstageVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.NodePublishVolumeRequest: case *csi.NodePublishVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.NodeUnpublishVolumeRequest: case *csi.NodeUnpublishVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.NodeExpandVolumeRequest: case *csi.NodeExpandVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *csi.CreateVolumeGroupSnapshotRequest: case *csi.CreateVolumeGroupSnapshotRequest:
reqID = r.Name reqID = r.GetName()
case *csi.DeleteVolumeGroupSnapshotRequest: case *csi.DeleteVolumeGroupSnapshotRequest:
reqID = r.GroupSnapshotId reqID = r.GetGroupSnapshotId()
case *csi.GetVolumeGroupSnapshotRequest: case *csi.GetVolumeGroupSnapshotRequest:
reqID = r.GroupSnapshotId reqID = r.GetGroupSnapshotId()
// Replication // Replication
case *replication.EnableVolumeReplicationRequest: case *replication.EnableVolumeReplicationRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *replication.DisableVolumeReplicationRequest: case *replication.DisableVolumeReplicationRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *replication.PromoteVolumeRequest: case *replication.PromoteVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *replication.DemoteVolumeRequest: case *replication.DemoteVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *replication.ResyncVolumeRequest: case *replication.ResyncVolumeRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
case *replication.GetVolumeReplicationInfoRequest: case *replication.GetVolumeReplicationInfoRequest:
reqID = r.VolumeId reqID = r.GetVolumeId()
} }
return reqID return reqID
@ -353,9 +353,9 @@ func IsFileRWO(caps []*csi.VolumeCapability) bool {
// to preserve backward compatibility we allow RWO filemode, ideally SINGLE_NODE_WRITER check is good enough, // to preserve backward compatibility we allow RWO filemode, ideally SINGLE_NODE_WRITER check is good enough,
// however more granular level check could help us in future, so keeping it here as an additional measure. // however more granular level check could help us in future, so keeping it here as an additional measure.
for _, cap := range caps { for _, cap := range caps {
if cap.AccessMode != nil { if cap.GetAccessMode() != nil {
if cap.GetMount() != nil { if cap.GetMount() != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, case csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER, csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER,
csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER: csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER:
@ -372,8 +372,8 @@ func IsFileRWO(caps []*csi.VolumeCapability) bool {
// or block mode. // or block mode.
func IsReaderOnly(caps []*csi.VolumeCapability) bool { func IsReaderOnly(caps []*csi.VolumeCapability) bool {
for _, cap := range caps { for _, cap := range caps {
if cap.AccessMode != nil { if cap.GetAccessMode() != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY, case csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY,
csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY: csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY:
return true return true
@ -397,8 +397,8 @@ func IsBlockMultiWriter(caps []*csi.VolumeCapability) (bool, bool) {
var block bool var block bool
for _, cap := range caps { for _, cap := range caps {
if cap.AccessMode != nil { if cap.GetAccessMode() != nil {
switch cap.AccessMode.Mode { //nolint:exhaustive // only check what we want switch cap.GetAccessMode().GetMode() { //nolint:exhaustive // only check what we want
case csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, case csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER,
csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER: csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER:
multiWriter = true multiWriter = true

View File

@ -25,7 +25,6 @@ import (
"github.com/container-storage-interface/spec/lib/go/csi" "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/csi-addons/spec/lib/go/replication" "github.com/csi-addons/spec/lib/go/replication"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
mount "k8s.io/mount-utils" mount "k8s.io/mount-utils"
) )
@ -127,12 +126,12 @@ func TestFilesystemNodeGetVolumeStats(t *testing.T) {
} }
require.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, len(stats.Usage), 0) require.NotEmpty(t, stats.GetUsage())
for _, usage := range stats.Usage { for _, usage := range stats.GetUsage() {
assert.NotEqual(t, usage.Available, -1) require.NotEqual(t, -1, usage.GetAvailable())
assert.NotEqual(t, usage.Total, -1) require.NotEqual(t, -1, usage.GetTotal())
assert.NotEqual(t, usage.Used, -1) require.NotEqual(t, -1, usage.GetUsed())
assert.NotEqual(t, usage.Unit, 0) require.NotEqual(t, 0, usage.GetUnit())
} }
// tests done, no need to retry again // tests done, no need to retry again
@ -143,9 +142,9 @@ func TestFilesystemNodeGetVolumeStats(t *testing.T) {
func TestRequirePositive(t *testing.T) { func TestRequirePositive(t *testing.T) {
t.Parallel() t.Parallel()
assert.Equal(t, requirePositive(0), int64(0)) require.Equal(t, int64(0), requirePositive(0))
assert.Equal(t, requirePositive(-1), int64(0)) require.Equal(t, int64(0), requirePositive(-1))
assert.Equal(t, requirePositive(1), int64(1)) require.Equal(t, int64(1), requirePositive(1))
} }
func TestIsBlockMultiNode(t *testing.T) { func TestIsBlockMultiNode(t *testing.T) {
@ -204,8 +203,8 @@ func TestIsBlockMultiNode(t *testing.T) {
for _, test := range tests { for _, test := range tests {
isBlock, isMultiNode := IsBlockMultiNode(test.caps) isBlock, isMultiNode := IsBlockMultiNode(test.caps)
assert.Equal(t, isBlock, test.isBlock, test.name) require.Equal(t, isBlock, test.isBlock, test.name)
assert.Equal(t, isMultiNode, test.isMultiNode, test.name) require.Equal(t, isMultiNode, test.isMultiNode, test.name)
} }
} }

View File

@ -561,7 +561,7 @@ func (conn *Connection) ReserveName(ctx context.Context,
imagePool string, imagePoolID int64, imagePool string, imagePoolID int64,
reqName, namePrefix, parentName, kmsConf, volUUID, owner, reqName, namePrefix, parentName, kmsConf, volUUID, owner,
backingSnapshotID string, backingSnapshotID string,
encryptionType util.EncryptionType, //nolint:interfacer // prefer util.EncryptionType over fmt.Stringer encryptionType util.EncryptionType,
) (string, string, error) { ) (string, string, error) {
// TODO: Take in-arg as ImageAttributes? // TODO: Take in-arg as ImageAttributes?
var ( var (

View File

@ -19,11 +19,11 @@ package kms
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestAWSMetadataKMSRegistered(t *testing.T) { func TestAWSMetadataKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeAWSMetadata] _, ok := kmsManager.providers[kmsTypeAWSMetadata]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -19,11 +19,11 @@ package kms
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestAWSSTSMetadataKMSRegistered(t *testing.T) { func TestAWSSTSMetadataKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeAWSSTSMetadata] _, ok := kmsManager.providers[kmsTypeAWSSTSMetadata]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -19,11 +19,11 @@ package kms
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestAzureKMSRegistered(t *testing.T) { func TestAzureKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeAzure] _, ok := kmsManager.providers[kmsTypeAzure]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -19,11 +19,11 @@ package kms
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestKeyProtectMetadataKMSRegistered(t *testing.T) { func TestKeyProtectMetadataKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeKeyProtectMetadata] _, ok := kmsManager.providers[kmsTypeKeyProtectMetadata]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -19,11 +19,11 @@ package kms
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestKMIPKMSRegistered(t *testing.T) { func TestKMIPKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeKMIP] _, ok := kmsManager.providers[kmsTypeKMIP]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -19,7 +19,7 @@ package kms
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func noinitKMS(args ProviderInitArgs) (EncryptionKMS, error) { func noinitKMS(args ProviderInitArgs) (EncryptionKMS, error) {
@ -47,9 +47,9 @@ func TestRegisterProvider(t *testing.T) {
for _, test := range tests { for _, test := range tests {
provider := test.provider provider := test.provider
if test.panics { if test.panics {
assert.Panics(t, func() { RegisterProvider(provider) }) require.Panics(t, func() { RegisterProvider(provider) })
} else { } else {
assert.True(t, RegisterProvider(provider)) require.True(t, RegisterProvider(provider))
} }
} }
} }

View File

@ -20,7 +20,7 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestSetConfigInt(t *testing.T) { func TestSetConfigInt(t *testing.T) {
@ -81,7 +81,7 @@ func TestSetConfigInt(t *testing.T) {
t.Errorf("setConfigInt() error = %v, wantErr %v", err, currentTT.err) t.Errorf("setConfigInt() error = %v, wantErr %v", err, currentTT.err)
} }
if err != nil { if err != nil {
assert.NotEqual(t, currentTT.value, currentTT.args.option) require.NotEqual(t, currentTT.value, currentTT.args.option)
} }
}) })
} }

View File

@ -20,7 +20,6 @@ import (
"context" "context"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -32,24 +31,24 @@ func TestNewSecretsKMS(t *testing.T) {
kms, err := newSecretsKMS(ProviderInitArgs{ kms, err := newSecretsKMS(ProviderInitArgs{
Secrets: secrets, Secrets: secrets,
}) })
assert.Error(t, err) require.Error(t, err)
assert.Nil(t, kms) require.Nil(t, kms)
// set a passphrase and it should pass // set a passphrase and it should pass
secrets[encryptionPassphraseKey] = "plaintext encryption key" secrets[encryptionPassphraseKey] = "plaintext encryption key"
kms, err = newSecretsKMS(ProviderInitArgs{ kms, err = newSecretsKMS(ProviderInitArgs{
Secrets: secrets, Secrets: secrets,
}) })
assert.NotNil(t, kms) require.NotNil(t, kms)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestGenerateNonce(t *testing.T) { func TestGenerateNonce(t *testing.T) {
t.Parallel() t.Parallel()
size := 64 size := 64
nonce, err := generateNonce(size) nonce, err := generateNonce(size)
assert.Equal(t, size, len(nonce)) require.Len(t, nonce, size)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestGenerateCipher(t *testing.T) { func TestGenerateCipher(t *testing.T) {
@ -59,8 +58,8 @@ func TestGenerateCipher(t *testing.T) {
salt := "unique-id-for-the-volume" salt := "unique-id-for-the-volume"
aead, err := generateCipher(passphrase, salt) aead, err := generateCipher(passphrase, salt)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, aead) require.NotNil(t, aead)
} }
func TestInitSecretsMetadataKMS(t *testing.T) { func TestInitSecretsMetadataKMS(t *testing.T) {
@ -73,16 +72,16 @@ func TestInitSecretsMetadataKMS(t *testing.T) {
// passphrase it not set, init should fail // passphrase it not set, init should fail
kms, err := initSecretsMetadataKMS(args) kms, err := initSecretsMetadataKMS(args)
assert.Error(t, err) require.Error(t, err)
assert.Nil(t, kms) require.Nil(t, kms)
// set a passphrase to get a working KMS // set a passphrase to get a working KMS
args.Secrets[encryptionPassphraseKey] = "my-passphrase-from-kubernetes" args.Secrets[encryptionPassphraseKey] = "my-passphrase-from-kubernetes"
kms, err = initSecretsMetadataKMS(args) kms, err = initSecretsMetadataKMS(args)
assert.NoError(t, err) require.NoError(t, err)
require.NotNil(t, kms) require.NotNil(t, kms)
assert.Equal(t, DEKStoreMetadata, kms.RequiresDEKStore()) require.Equal(t, DEKStoreMetadata, kms.RequiresDEKStore())
} }
func TestWorkflowSecretsMetadataKMS(t *testing.T) { func TestWorkflowSecretsMetadataKMS(t *testing.T) {
@ -98,7 +97,7 @@ func TestWorkflowSecretsMetadataKMS(t *testing.T) {
volumeID := "csi-vol-1b00f5f8-b1c1-11e9-8421-9243c1f659f0" volumeID := "csi-vol-1b00f5f8-b1c1-11e9-8421-9243c1f659f0"
kms, err := initSecretsMetadataKMS(args) kms, err := initSecretsMetadataKMS(args)
assert.NoError(t, err) require.NoError(t, err)
require.NotNil(t, kms) require.NotNil(t, kms)
// plainDEK is the (LUKS) passphrase for the volume // plainDEK is the (LUKS) passphrase for the volume
@ -107,25 +106,25 @@ func TestWorkflowSecretsMetadataKMS(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
encryptedDEK, err := kms.EncryptDEK(ctx, volumeID, plainDEK) encryptedDEK, err := kms.EncryptDEK(ctx, volumeID, plainDEK)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, "", encryptedDEK) require.NotEqual(t, "", encryptedDEK)
assert.NotEqual(t, plainDEK, encryptedDEK) require.NotEqual(t, plainDEK, encryptedDEK)
// with an incorrect volumeID, decrypting should fail // with an incorrect volumeID, decrypting should fail
decryptedDEK, err := kms.DecryptDEK(ctx, "incorrect-volumeID", encryptedDEK) decryptedDEK, err := kms.DecryptDEK(ctx, "incorrect-volumeID", encryptedDEK)
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, "", decryptedDEK) require.Equal(t, "", decryptedDEK)
assert.NotEqual(t, plainDEK, decryptedDEK) require.NotEqual(t, plainDEK, decryptedDEK)
// with the right volumeID, decrypting should return the plainDEK // with the right volumeID, decrypting should return the plainDEK
decryptedDEK, err = kms.DecryptDEK(ctx, volumeID, encryptedDEK) decryptedDEK, err = kms.DecryptDEK(ctx, volumeID, encryptedDEK)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, "", decryptedDEK) require.NotEqual(t, "", decryptedDEK)
assert.Equal(t, plainDEK, decryptedDEK) require.Equal(t, plainDEK, decryptedDEK)
} }
func TestSecretsMetadataKMSRegistered(t *testing.T) { func TestSecretsMetadataKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeSecretsMetadata] _, ok := kmsManager.providers[kmsTypeSecretsMetadata]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -20,13 +20,13 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestVaultTenantSAKMSRegistered(t *testing.T) { func TestVaultTenantSAKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeVaultTenantSA] _, ok := kmsManager.providers[kmsTypeVaultTenantSA]
assert.True(t, ok) require.True(t, ok)
} }
func TestTenantSAParseConfig(t *testing.T) { func TestTenantSAParseConfig(t *testing.T) {

View File

@ -22,7 +22,6 @@ import (
"testing" "testing"
loss "github.com/libopenstorage/secrets" loss "github.com/libopenstorage/secrets"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -113,8 +112,8 @@ func TestDefaultVaultDestroyKeys(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
keyContext := vc.getDeleteKeyContext() keyContext := vc.getDeleteKeyContext()
destroySecret, ok := keyContext[loss.DestroySecret] destroySecret, ok := keyContext[loss.DestroySecret]
assert.NotEqual(t, destroySecret, "") require.NotEqual(t, "", destroySecret)
assert.True(t, ok) require.True(t, ok)
// setting vaultDestroyKeys to !true should remove the loss.DestroySecret entry // setting vaultDestroyKeys to !true should remove the loss.DestroySecret entry
config["vaultDestroyKeys"] = "false" config["vaultDestroyKeys"] = "false"
@ -122,11 +121,11 @@ func TestDefaultVaultDestroyKeys(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
keyContext = vc.getDeleteKeyContext() keyContext = vc.getDeleteKeyContext()
_, ok = keyContext[loss.DestroySecret] _, ok = keyContext[loss.DestroySecret]
assert.False(t, ok) require.False(t, ok)
} }
func TestVaultKMSRegistered(t *testing.T) { func TestVaultKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeVault] _, ok := kmsManager.providers[kmsTypeVault]
assert.True(t, ok) require.True(t, ok)
} }

View File

@ -25,7 +25,6 @@ import (
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
loss "github.com/libopenstorage/secrets" loss "github.com/libopenstorage/secrets"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -205,18 +204,18 @@ func TestTransformConfig(t *testing.T) {
config, err := transformConfig(cm) config, err := transformConfig(cm)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, config["encryptionKMSType"], cm["KMS_PROVIDER"]) require.Equal(t, cm["KMS_PROVIDER"], config["encryptionKMSType"])
assert.Equal(t, config["vaultAddress"], cm["VAULT_ADDR"]) require.Equal(t, cm["VAULT_ADDR"], config["vaultAddress"])
assert.Equal(t, config["vaultBackend"], cm["VAULT_BACKEND"]) require.Equal(t, cm["VAULT_BACKEND"], config["vaultBackend"])
assert.Equal(t, config["vaultBackendPath"], cm["VAULT_BACKEND_PATH"]) require.Equal(t, cm["VAULT_BACKEND_PATH"], config["vaultBackendPath"])
assert.Equal(t, config["vaultDestroyKeys"], cm["VAULT_DESTROY_KEYS"]) require.Equal(t, cm["VAULT_DESTROY_KEYS"], config["vaultDestroyKeys"])
assert.Equal(t, config["vaultCAFromSecret"], cm["VAULT_CACERT"]) require.Equal(t, cm["VAULT_CACERT"], config["vaultCAFromSecret"])
assert.Equal(t, config["vaultTLSServerName"], cm["VAULT_TLS_SERVER_NAME"]) require.Equal(t, cm["VAULT_TLS_SERVER_NAME"], config["vaultTLSServerName"])
assert.Equal(t, config["vaultClientCertFromSecret"], cm["VAULT_CLIENT_CERT"]) require.Equal(t, cm["VAULT_CLIENT_CERT"], config["vaultClientCertFromSecret"])
assert.Equal(t, config["vaultClientCertKeyFromSecret"], cm["VAULT_CLIENT_KEY"]) require.Equal(t, cm["VAULT_CLIENT_KEY"], config["vaultClientCertKeyFromSecret"])
assert.Equal(t, config["vaultAuthNamespace"], cm["VAULT_AUTH_NAMESPACE"]) require.Equal(t, cm["VAULT_AUTH_NAMESPACE"], config["vaultAuthNamespace"])
assert.Equal(t, config["vaultNamespace"], cm["VAULT_NAMESPACE"]) require.Equal(t, cm["VAULT_NAMESPACE"], config["vaultNamespace"])
assert.Equal(t, config["vaultCAVerify"], "false") require.Equal(t, "false", config["vaultCAVerify"])
} }
func TestTransformConfigDefaults(t *testing.T) { func TestTransformConfigDefaults(t *testing.T) {
@ -226,15 +225,15 @@ func TestTransformConfigDefaults(t *testing.T) {
config, err := transformConfig(cm) config, err := transformConfig(cm)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, config["encryptionKMSType"], cm["KMS_PROVIDER"]) require.Equal(t, cm["KMS_PROVIDER"], config["encryptionKMSType"])
assert.Equal(t, config["vaultDestroyKeys"], vaultDefaultDestroyKeys) require.Equal(t, vaultDefaultDestroyKeys, config["vaultDestroyKeys"])
assert.Equal(t, config["vaultCAVerify"], strconv.FormatBool(vaultDefaultCAVerify)) require.Equal(t, strconv.FormatBool(vaultDefaultCAVerify), config["vaultCAVerify"])
} }
func TestVaultTokensKMSRegistered(t *testing.T) { func TestVaultTokensKMSRegistered(t *testing.T) {
t.Parallel() t.Parallel()
_, ok := kmsManager.providers[kmsTypeVaultTokens] _, ok := kmsManager.providers[kmsTypeVaultTokens]
assert.True(t, ok) require.True(t, ok)
} }
func TestSetTenantAuthNamespace(t *testing.T) { func TestSetTenantAuthNamespace(t *testing.T) {
@ -259,7 +258,7 @@ func TestSetTenantAuthNamespace(t *testing.T) {
kms.setTenantAuthNamespace(config) kms.setTenantAuthNamespace(config)
assert.Equal(tt, vaultNamespace, config["vaultAuthNamespace"]) require.Equal(tt, vaultNamespace, config["vaultAuthNamespace"])
}) })
t.Run("inherit vaultAuthNamespace", func(tt *testing.T) { t.Run("inherit vaultAuthNamespace", func(tt *testing.T) {
@ -283,7 +282,7 @@ func TestSetTenantAuthNamespace(t *testing.T) {
// when inheriting from the global config, the config of the // when inheriting from the global config, the config of the
// tenant should not have vaultAuthNamespace configured // tenant should not have vaultAuthNamespace configured
assert.Equal(tt, nil, config["vaultAuthNamespace"]) require.Nil(tt, config["vaultAuthNamespace"])
}) })
t.Run("unset vaultAuthNamespace", func(tt *testing.T) { t.Run("unset vaultAuthNamespace", func(tt *testing.T) {
@ -306,7 +305,7 @@ func TestSetTenantAuthNamespace(t *testing.T) {
// global vaultAuthNamespace is not set, tenant // global vaultAuthNamespace is not set, tenant
// vaultAuthNamespace will be configured as vaultNamespace by // vaultAuthNamespace will be configured as vaultNamespace by
// default // default
assert.Equal(tt, nil, config["vaultAuthNamespace"]) require.Nil(tt, config["vaultAuthNamespace"])
}) })
t.Run("no vaultNamespace", func(tt *testing.T) { t.Run("no vaultNamespace", func(tt *testing.T) {
@ -326,6 +325,6 @@ func TestSetTenantAuthNamespace(t *testing.T) {
kms.setTenantAuthNamespace(config) kms.setTenantAuthNamespace(config)
assert.Equal(tt, nil, config["vaultAuthNamespace"]) require.Nil(tt, config["vaultAuthNamespace"])
}) })
} }

View File

@ -84,9 +84,9 @@ func (cs *Server) CreateVolume(
return nil, err return nil, err
} }
backend := res.Volume backend := res.GetVolume()
log.DebugLog(ctx, "CephFS volume created: %s", backend.VolumeId) log.DebugLog(ctx, "CephFS volume created: %s", backend.GetVolumeId())
secret := req.GetSecrets() secret := req.GetSecrets()
cr, err := util.NewAdminCredentials(secret) cr, err := util.NewAdminCredentials(secret)
@ -97,7 +97,7 @@ func (cs *Server) CreateVolume(
} }
defer cr.DeleteCredentials() defer cr.DeleteCredentials()
nfsVolume, err := NewNFSVolume(ctx, backend.VolumeId) nfsVolume, err := NewNFSVolume(ctx, backend.GetVolumeId())
if err != nil { if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error()) return nil, status.Error(codes.InvalidArgument, err.Error())
} }

View File

@ -127,12 +127,12 @@ func (nv *NFSVolume) CreateExport(backend *csi.Volume) error {
if !nv.connected { if !nv.connected {
return fmt.Errorf("can not created export for %q: %w", nv, ErrNotConnected) return fmt.Errorf("can not created export for %q: %w", nv, ErrNotConnected)
} }
vctx := backend.GetVolumeContext()
fs := backend.VolumeContext["fsName"] fs := vctx["fsName"]
nfsCluster := backend.VolumeContext["nfsCluster"] nfsCluster := vctx["nfsCluster"]
path := backend.VolumeContext["subvolumePath"] path := vctx["subvolumePath"]
secTypes := backend.VolumeContext["secTypes"] secTypes := vctx["secTypes"]
clients := backend.VolumeContext["clients"] clients := vctx["clients"]
err := nv.setNFSCluster(nfsCluster) err := nv.setNFSCluster(nfsCluster)
if err != nil { if err != nil {

View File

@ -68,10 +68,10 @@ func (cs *ControllerServer) validateVolumeReq(ctx context.Context, req *csi.Crea
return err return err
} }
// Check sanity of request Name, Volume Capabilities // Check sanity of request Name, Volume Capabilities
if req.Name == "" { if req.GetName() == "" {
return status.Error(codes.InvalidArgument, "volume Name cannot be empty") return status.Error(codes.InvalidArgument, "volume Name cannot be empty")
} }
if req.VolumeCapabilities == nil { if req.GetVolumeCapabilities() == nil {
return status.Error(codes.InvalidArgument, "volume Capabilities cannot be empty") return status.Error(codes.InvalidArgument, "volume Capabilities cannot be empty")
} }
options := req.GetParameters() options := req.GetParameters()
@ -105,7 +105,7 @@ func (cs *ControllerServer) validateVolumeReq(ctx context.Context, req *csi.Crea
return err return err
} }
err = validateStriping(req.Parameters) err = validateStriping(req.GetParameters())
if err != nil { if err != nil {
return status.Error(codes.InvalidArgument, err.Error()) return status.Error(codes.InvalidArgument, err.Error())
} }
@ -156,13 +156,13 @@ func (cs *ControllerServer) parseVolCreateRequest(
// below capability check indicates that we support both {SINGLE_NODE or MULTI_NODE} WRITERs and the `isMultiWriter` // below capability check indicates that we support both {SINGLE_NODE or MULTI_NODE} WRITERs and the `isMultiWriter`
// flag has been set accordingly. // flag has been set accordingly.
isMultiWriter, isBlock := csicommon.IsBlockMultiWriter(req.VolumeCapabilities) isMultiWriter, isBlock := csicommon.IsBlockMultiWriter(req.GetVolumeCapabilities())
// below return value has set, if it is RWO mode File PVC. // below return value has set, if it is RWO mode File PVC.
isRWOFile := csicommon.IsFileRWO(req.VolumeCapabilities) isRWOFile := csicommon.IsFileRWO(req.GetVolumeCapabilities())
// below return value has set, if it is ReadOnly capability. // below return value has set, if it is ReadOnly capability.
isROOnly := csicommon.IsReaderOnly(req.VolumeCapabilities) isROOnly := csicommon.IsReaderOnly(req.GetVolumeCapabilities())
// We want to fail early if the user is trying to create a RWX on a non-block type device // We want to fail early if the user is trying to create a RWX on a non-block type device
if !isRWOFile && !isBlock && !isROOnly { if !isRWOFile && !isBlock && !isROOnly {
return nil, status.Error( return nil, status.Error(
@ -782,13 +782,13 @@ func checkContentSource(
req *csi.CreateVolumeRequest, req *csi.CreateVolumeRequest,
cr *util.Credentials, cr *util.Credentials,
) (*rbdVolume, *rbdSnapshot, error) { ) (*rbdVolume, *rbdSnapshot, error) {
if req.VolumeContentSource == nil { if req.GetVolumeContentSource() == nil {
return nil, nil, nil return nil, nil, nil
} }
volumeSource := req.VolumeContentSource volumeSource := req.GetVolumeContentSource()
switch volumeSource.Type.(type) { switch volumeSource.GetType().(type) {
case *csi.VolumeContentSource_Snapshot: case *csi.VolumeContentSource_Snapshot:
snapshot := req.VolumeContentSource.GetSnapshot() snapshot := req.GetVolumeContentSource().GetSnapshot()
if snapshot == nil { if snapshot == nil {
return nil, nil, status.Error(codes.NotFound, "volume Snapshot cannot be empty") return nil, nil, status.Error(codes.NotFound, "volume Snapshot cannot be empty")
} }
@ -808,7 +808,7 @@ func checkContentSource(
return nil, rbdSnap, nil return nil, rbdSnap, nil
case *csi.VolumeContentSource_Volume: case *csi.VolumeContentSource_Volume:
vol := req.VolumeContentSource.GetVolume() vol := req.GetVolumeContentSource().GetVolume()
if vol == nil { if vol == nil {
return nil, nil, status.Error(codes.NotFound, "volume cannot be empty") return nil, nil, status.Error(codes.NotFound, "volume cannot be empty")
} }
@ -1066,11 +1066,11 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
} }
if len(req.VolumeCapabilities) == 0 { if len(req.GetVolumeCapabilities()) == 0 {
return nil, status.Error(codes.InvalidArgument, "empty volume capabilities in request") return nil, status.Error(codes.InvalidArgument, "empty volume capabilities in request")
} }
for _, capability := range req.VolumeCapabilities { for _, capability := range req.GetVolumeCapabilities() {
if capability.GetAccessMode().GetMode() != csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER { if capability.GetAccessMode().GetMode() != csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER {
return &csi.ValidateVolumeCapabilitiesResponse{Message: ""}, nil return &csi.ValidateVolumeCapabilitiesResponse{Message: ""}, nil
} }
@ -1078,7 +1078,7 @@ func (cs *ControllerServer) ValidateVolumeCapabilities(
return &csi.ValidateVolumeCapabilitiesResponse{ return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{ Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{
VolumeCapabilities: req.VolumeCapabilities, VolumeCapabilities: req.GetVolumeCapabilities(),
}, },
}, nil }, nil
} }
@ -1297,10 +1297,10 @@ func (cs *ControllerServer) validateSnapshotReq(ctx context.Context, req *csi.Cr
} }
// Check sanity of request Snapshot Name, Source Volume Id // Check sanity of request Snapshot Name, Source Volume Id
if req.Name == "" { if req.GetName() == "" {
return status.Error(codes.InvalidArgument, "snapshot Name cannot be empty") return status.Error(codes.InvalidArgument, "snapshot Name cannot be empty")
} }
if req.SourceVolumeId == "" { if req.GetSourceVolumeId() == "" {
return status.Error(codes.InvalidArgument, "source Volume ID cannot be empty") return status.Error(codes.InvalidArgument, "source Volume ID cannot be empty")
} }

View File

@ -20,7 +20,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ceph/ceph-csi/internal/util" "github.com/ceph/ceph-csi/internal/util"
@ -44,7 +43,7 @@ func TestSetupCSIAddonsServer(t *testing.T) {
// verify the socket file has been created // verify the socket file has been created
_, err = os.Stat(tmpDir + "/csi-addons.sock") _, err = os.Stat(tmpDir + "/csi-addons.sock")
assert.NoError(t, err) require.NoError(t, err)
// stop the gRPC server // stop the gRPC server
drv.cas.Stop() drv.cas.Stop()

View File

@ -312,7 +312,7 @@ func (ri *rbdImage) initKMS(ctx context.Context, volOptions, credentials map[str
case util.EncryptionTypeFile: case util.EncryptionTypeFile:
err = ri.configureFileEncryption(ctx, kmsID, credentials) err = ri.configureFileEncryption(ctx, kmsID, credentials)
case util.EncryptionTypeInvalid: case util.EncryptionTypeInvalid:
return fmt.Errorf("invalid encryption type") return errors.New("invalid encryption type")
case util.EncryptionTypeNone: case util.EncryptionTypeNone:
return nil return nil
} }

View File

@ -163,7 +163,7 @@ func (ns *NodeServer) populateRbdVol(
isBlock := req.GetVolumeCapability().GetBlock() != nil isBlock := req.GetVolumeCapability().GetBlock() != nil
disableInUseChecks := false disableInUseChecks := false
// MULTI_NODE_MULTI_WRITER is supported by default for Block access type volumes // MULTI_NODE_MULTI_WRITER is supported by default for Block access type volumes
if req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER { if req.GetVolumeCapability().GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER {
if !isBlock { if !isBlock {
log.WarningLog( log.WarningLog(
ctx, ctx,
@ -400,7 +400,7 @@ func (ns *NodeServer) stageTransaction(
var err error var err error
// Allow image to be mounted on multiple nodes if it is ROX // Allow image to be mounted on multiple nodes if it is ROX
if req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY { if req.GetVolumeCapability().GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY {
log.ExtendedLog(ctx, "setting disableInUseChecks on rbd volume to: %v", req.GetVolumeId) log.ExtendedLog(ctx, "setting disableInUseChecks on rbd volume to: %v", req.GetVolumeId)
volOptions.DisableInUseChecks = true volOptions.DisableInUseChecks = true
volOptions.readOnly = true volOptions.readOnly = true
@ -777,8 +777,9 @@ func (ns *NodeServer) mountVolumeToStagePath(
isBlock := req.GetVolumeCapability().GetBlock() != nil isBlock := req.GetVolumeCapability().GetBlock() != nil
rOnly := "ro" rOnly := "ro"
if req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY || mode := req.GetVolumeCapability().GetAccessMode().GetMode()
req.VolumeCapability.AccessMode.Mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY { if mode == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
mode == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
if !csicommon.MountOptionContains(opt, rOnly) { if !csicommon.MountOptionContains(opt, rOnly) {
opt = append(opt, rOnly) opt = append(opt, rOnly)
} }

View File

@ -27,7 +27,7 @@ import (
"github.com/ceph/ceph-csi/internal/util" "github.com/ceph/ceph-csi/internal/util"
"github.com/container-storage-interface/spec/lib/go/csi" "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestGetStagingPath(t *testing.T) { func TestGetStagingPath(t *testing.T) {
@ -196,7 +196,7 @@ func TestNodeServer_appendReadAffinityMapOptions(t *testing.T) {
Mounter: currentTT.args.mounter, Mounter: currentTT.args.mounter,
} }
rv.appendReadAffinityMapOptions(currentTT.args.readAffinityMapOptions) rv.appendReadAffinityMapOptions(currentTT.args.readAffinityMapOptions)
assert.Equal(t, currentTT.want, rv.MapOptions) require.Equal(t, currentTT.want, rv.MapOptions)
}) })
} }
} }
@ -310,10 +310,10 @@ func TestReadAffinity_GetReadAffinityMapOptions(t *testing.T) {
tmpConfPath, tc.clusterID, ns.CLIReadAffinityOptions, nodeLabels, tmpConfPath, tc.clusterID, ns.CLIReadAffinityOptions, nodeLabels,
) )
if err != nil { if err != nil {
assert.Fail(t, err.Error()) require.Fail(t, err.Error())
} }
assert.Equal(t, tc.want, readAffinityMapOptions) require.Equal(t, tc.want, readAffinityMapOptions)
}) })
} }
} }

View File

@ -229,7 +229,7 @@ func waitForPath(ctx context.Context, pool, namespace, image string, maxRetries
func SetRbdNbdToolFeatures() { func SetRbdNbdToolFeatures() {
var stderr string var stderr string
// check if the module is loaded or compiled in // check if the module is loaded or compiled in
_, err := os.Stat(fmt.Sprintf("/sys/module/%s", moduleNbd)) _, err := os.Stat("/sys/module/" + moduleNbd)
if os.IsNotExist(err) { if os.IsNotExist(err) {
// try to load the module // try to load the module
_, stderr, err = util.ExecCommand(context.TODO(), "modprobe", moduleNbd) _, stderr, err = util.ExecCommand(context.TODO(), "modprobe", moduleNbd)
@ -377,7 +377,7 @@ func appendNbdDeviceTypeAndOptions(cmdArgs []string, userOptions, cookie string)
} }
if hasNBDCookieSupport { if hasNBDCookieSupport {
cmdArgs = append(cmdArgs, fmt.Sprintf("--cookie=%s", cookie)) cmdArgs = append(cmdArgs, "--cookie="+cookie)
} }
} }
@ -409,7 +409,7 @@ func appendKRbdDeviceTypeAndOptions(cmdArgs []string, userOptions string) []stri
// provided for rbd integrated cli to rbd-nbd cli format specific. // provided for rbd integrated cli to rbd-nbd cli format specific.
func appendRbdNbdCliOptions(cmdArgs []string, userOptions, cookie string) []string { func appendRbdNbdCliOptions(cmdArgs []string, userOptions, cookie string) []string {
if !strings.Contains(userOptions, useNbdNetlink) { if !strings.Contains(userOptions, useNbdNetlink) {
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", useNbdNetlink)) cmdArgs = append(cmdArgs, "--"+useNbdNetlink)
} }
if !strings.Contains(userOptions, setNbdReattach) { if !strings.Contains(userOptions, setNbdReattach) {
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%d", setNbdReattach, defaultNbdReAttachTimeout)) cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%d", setNbdReattach, defaultNbdReAttachTimeout))
@ -418,12 +418,12 @@ func appendRbdNbdCliOptions(cmdArgs []string, userOptions, cookie string) []stri
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%d", setNbdIOTimeout, defaultNbdIOTimeout)) cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%d", setNbdIOTimeout, defaultNbdIOTimeout))
} }
if hasNBDCookieSupport { if hasNBDCookieSupport {
cmdArgs = append(cmdArgs, fmt.Sprintf("--cookie=%s", cookie)) cmdArgs = append(cmdArgs, "--cookie="+cookie)
} }
if userOptions != "" { if userOptions != "" {
options := strings.Split(userOptions, ",") options := strings.Split(userOptions, ",")
for _, opt := range options { for _, opt := range options {
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", opt)) cmdArgs = append(cmdArgs, "--"+opt)
} }
} }
@ -566,7 +566,7 @@ func detachRBDImageOrDeviceSpec(
return err return err
} }
if len(mapper) > 0 { if mapper != "" {
// mapper found, so it is open Luks device // mapper found, so it is open Luks device
err = util.CloseEncryptedVolume(ctx, mapperFile) err = util.CloseEncryptedVolume(ctx, mapperFile)
if err != nil { if err != nil {

View File

@ -19,6 +19,7 @@ package rbd
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"path/filepath" "path/filepath"
"sync" "sync"
@ -79,7 +80,7 @@ func getSecret(c *k8s.Clientset, ns, name string) (map[string]string, error) {
func formatStagingTargetPath(c *k8s.Clientset, pv *v1.PersistentVolume, stagingPath string) (string, error) { func formatStagingTargetPath(c *k8s.Clientset, pv *v1.PersistentVolume, stagingPath string) (string, error) {
// Kubernetes 1.24+ uses a hash of the volume-id in the path name // Kubernetes 1.24+ uses a hash of the volume-id in the path name
unique := sha256.Sum256([]byte(pv.Spec.CSI.VolumeHandle)) unique := sha256.Sum256([]byte(pv.Spec.CSI.VolumeHandle))
targetPath := filepath.Join(stagingPath, pv.Spec.CSI.Driver, fmt.Sprintf("%x", unique), "globalmount") targetPath := filepath.Join(stagingPath, pv.Spec.CSI.Driver, hex.EncodeToString(unique[:]), "globalmount")
major, minor, err := kubeclient.GetServerVersion(c) major, minor, err := kubeclient.GetServerVersion(c)
if err != nil { if err != nil {

View File

@ -294,7 +294,7 @@ func (rv *rbdVolume) Exists(ctx context.Context, parentVol *rbdVolume) (bool, er
// NOTE: Return volsize should be on-disk volsize, not request vol size, so // NOTE: Return volsize should be on-disk volsize, not request vol size, so
// save it for size checks before fetching image data // save it for size checks before fetching image data
requestSize := rv.VolSize //nolint:ifshort // FIXME: rename and split function into helpers requestSize := rv.VolSize
// Fetch on-disk image attributes and compare against request // Fetch on-disk image attributes and compare against request
err = rv.getImageInfo() err = rv.getImageInfo()
if err != nil { if err != nil {

View File

@ -1163,8 +1163,6 @@ func generateVolumeFromVolumeID(
// GenVolFromVolID generates a rbdVolume structure from the provided identifier, updating // GenVolFromVolID generates a rbdVolume structure from the provided identifier, updating
// the structure with elements from on-disk image metadata as well. // the structure with elements from on-disk image metadata as well.
//
//nolint:golint // TODO: returning unexported rbdVolume type, use an interface instead.
func GenVolFromVolID( func GenVolFromVolID(
ctx context.Context, ctx context.Context,
volumeID string, volumeID string,
@ -1231,7 +1229,7 @@ func generateVolumeFromMapping(
vi.ClusterID) vi.ClusterID)
// Add mapping clusterID to Identifier // Add mapping clusterID to Identifier
nvi.ClusterID = mappedClusterID nvi.ClusterID = mappedClusterID
poolID := fmt.Sprintf("%d", (vi.LocationID)) poolID := strconv.FormatInt(vi.LocationID, 10)
for _, pools := range cm.RBDpoolIDMappingInfo { for _, pools := range cm.RBDpoolIDMappingInfo {
for key, val := range pools { for key, val := range pools {
mappedPoolID := util.GetMappedID(key, val, poolID) mappedPoolID := util.GetMappedID(key, val, poolID)
@ -1525,7 +1523,7 @@ func (rv *rbdVolume) setImageOptions(ctx context.Context, options *librbd.ImageO
logMsg := fmt.Sprintf("setting image options on %s", rv) logMsg := fmt.Sprintf("setting image options on %s", rv)
if rv.DataPool != "" { if rv.DataPool != "" {
logMsg += fmt.Sprintf(", data pool %s", rv.DataPool) logMsg += ", data pool %s" + rv.DataPool
err = options.SetString(librbd.RbdImageOptionDataPool, rv.DataPool) err = options.SetString(librbd.RbdImageOptionDataPool, rv.DataPool)
if err != nil { if err != nil {
return fmt.Errorf("failed to set data pool: %w", err) return fmt.Errorf("failed to set data pool: %w", err)

View File

@ -24,7 +24,7 @@ import (
"testing" "testing"
librbd "github.com/ceph/go-ceph/rbd" librbd "github.com/ceph/go-ceph/rbd"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestHasSnapshotFeature(t *testing.T) { func TestHasSnapshotFeature(t *testing.T) {
@ -165,11 +165,11 @@ func TestValidateImageFeatures(t *testing.T) {
for _, test := range tests { for _, test := range tests {
err := test.rbdVol.validateImageFeatures(test.imageFeatures) err := test.rbdVol.validateImageFeatures(test.imageFeatures)
if test.isErr { if test.isErr {
assert.EqualError(t, err, test.errMsg) require.EqualError(t, err, test.errMsg)
continue continue
} }
assert.Nil(t, err) require.NoError(t, err)
} }
} }

View File

@ -48,7 +48,7 @@ func ExecuteCommandWithNSEnter(ctx context.Context, netPath, program string, arg
return "", "", fmt.Errorf("failed to get stat for %s %w", netPath, err) return "", "", fmt.Errorf("failed to get stat for %s %w", netPath, err)
} }
// nsenter --net=%s -- <program> <args> // nsenter --net=%s -- <program> <args>
args = append([]string{fmt.Sprintf("--net=%s", netPath), "--", program}, args...) args = append([]string{"--net=" + netPath, "--", program}, args...)
sanitizedArgs := StripSecretInArgs(args) sanitizedArgs := StripSecretInArgs(args)
cmd := exec.Command(nsenter, args...) // #nosec:G204, commands executing not vulnerable. cmd := exec.Command(nsenter, args...) // #nosec:G204, commands executing not vulnerable.
cmd.Stdout = &stdoutBuf cmd.Stdout = &stdoutBuf

View File

@ -159,7 +159,7 @@ func TestGetClusterMappingInfo(t *testing.T) {
}) })
} }
clusterMappingConfigFile = fmt.Sprintf("%s/mapping.json", mappingBasePath) clusterMappingConfigFile = mappingBasePath + "/mapping.json"
err = os.WriteFile(clusterMappingConfigFile, mappingFileContent, 0o600) err = os.WriteFile(clusterMappingConfigFile, mappingFileContent, 0o600)
if err != nil { if err != nil {
t.Errorf("failed to write mapping content error = %v", err) t.Errorf("failed to write mapping content error = %v", err)

View File

@ -19,7 +19,7 @@ package util
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func Test_getCrushLocationMap(t *testing.T) { func Test_getCrushLocationMap(t *testing.T) {
@ -105,7 +105,7 @@ func Test_getCrushLocationMap(t *testing.T) {
currentTT := tt currentTT := tt
t.Run(currentTT.name, func(t *testing.T) { t.Run(currentTT.name, func(t *testing.T) {
t.Parallel() t.Parallel()
assert.Equal(t, require.Equal(t,
currentTT.want, currentTT.want,
getCrushLocationMap(currentTT.args.crushLocationLabels, currentTT.args.nodeLabels)) getCrushLocationMap(currentTT.args.crushLocationLabels, currentTT.args.nodeLabels))
}) })

View File

@ -23,7 +23,6 @@ import (
"github.com/ceph/ceph-csi/internal/kms" "github.com/ceph/ceph-csi/internal/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -35,8 +34,8 @@ func TestGenerateNewEncryptionPassphrase(t *testing.T) {
// b64Passphrase is URL-encoded, decode to verify the length of the // b64Passphrase is URL-encoded, decode to verify the length of the
// passphrase // passphrase
passphrase, err := base64.URLEncoding.DecodeString(b64Passphrase) passphrase, err := base64.URLEncoding.DecodeString(b64Passphrase)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, defaultEncryptionPassphraseSize, len(passphrase)) require.Len(t, passphrase, defaultEncryptionPassphraseSize)
} }
func TestKMSWorkflow(t *testing.T) { func TestKMSWorkflow(t *testing.T) {
@ -47,52 +46,52 @@ func TestKMSWorkflow(t *testing.T) {
} }
kmsProvider, err := kms.GetDefaultKMS(secrets) kmsProvider, err := kms.GetDefaultKMS(secrets)
assert.NoError(t, err) require.NoError(t, err)
require.NotNil(t, kmsProvider) require.NotNil(t, kmsProvider)
ve, err := NewVolumeEncryption("", kmsProvider) ve, err := NewVolumeEncryption("", kmsProvider)
assert.NoError(t, err) require.NoError(t, err)
require.NotNil(t, ve) require.NotNil(t, ve)
assert.Equal(t, kms.DefaultKMSType, ve.GetID()) require.Equal(t, kms.DefaultKMSType, ve.GetID())
volumeID := "volume-id" volumeID := "volume-id"
ctx := context.TODO() ctx := context.TODO()
err = ve.StoreNewCryptoPassphrase(ctx, volumeID, defaultEncryptionPassphraseSize) err = ve.StoreNewCryptoPassphrase(ctx, volumeID, defaultEncryptionPassphraseSize)
assert.NoError(t, err) require.NoError(t, err)
passphrase, err := ve.GetCryptoPassphrase(ctx, volumeID) passphrase, err := ve.GetCryptoPassphrase(ctx, volumeID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, secrets["encryptionPassphrase"], passphrase) require.Equal(t, secrets["encryptionPassphrase"], passphrase)
} }
func TestEncryptionType(t *testing.T) { func TestEncryptionType(t *testing.T) {
t.Parallel() t.Parallel()
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("wat?")) require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("wat?"))
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("both")) require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("both"))
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("file,block")) require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("file,block"))
assert.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("block,file")) require.EqualValues(t, EncryptionTypeInvalid, ParseEncryptionType("block,file"))
assert.EqualValues(t, EncryptionTypeBlock, ParseEncryptionType("block")) require.EqualValues(t, EncryptionTypeBlock, ParseEncryptionType("block"))
assert.EqualValues(t, EncryptionTypeFile, ParseEncryptionType("file")) require.EqualValues(t, EncryptionTypeFile, ParseEncryptionType("file"))
assert.EqualValues(t, EncryptionTypeNone, ParseEncryptionType("")) require.EqualValues(t, EncryptionTypeNone, ParseEncryptionType(""))
for _, s := range []string{"file", "block", ""} { for _, s := range []string{"file", "block", ""} {
assert.EqualValues(t, s, ParseEncryptionType(s).String()) require.EqualValues(t, s, ParseEncryptionType(s).String())
} }
} }
func TestFetchEncryptionType(t *testing.T) { func TestFetchEncryptionType(t *testing.T) {
t.Parallel() t.Parallel()
volOpts := map[string]string{} volOpts := map[string]string{}
assert.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeBlock)) require.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeBlock))
assert.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeFile)) require.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeFile))
assert.EqualValues(t, EncryptionTypeNone, FetchEncryptionType(volOpts, EncryptionTypeNone)) require.EqualValues(t, EncryptionTypeNone, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "" volOpts["encryptionType"] = ""
assert.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone)) require.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "block" volOpts["encryptionType"] = "block"
assert.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeNone)) require.EqualValues(t, EncryptionTypeBlock, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "file" volOpts["encryptionType"] = "file"
assert.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeNone)) require.EqualValues(t, EncryptionTypeFile, FetchEncryptionType(volOpts, EncryptionTypeNone))
volOpts["encryptionType"] = "INVALID" volOpts["encryptionType"] = "INVALID"
assert.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone)) require.EqualValues(t, EncryptionTypeInvalid, FetchEncryptionType(volOpts, EncryptionTypeNone))
} }

View File

@ -463,5 +463,5 @@ func Unlock(
return initializeAndUnlock(ctx, fscryptContext, encryptedPath, protectorName, keyFn) return initializeAndUnlock(ctx, fscryptContext, encryptedPath, protectorName, keyFn)
} }
return fmt.Errorf("unsupported") return errors.New("unsupported")
} }

View File

@ -20,7 +20,7 @@ import (
kmsapi "github.com/ceph/ceph-csi/internal/kms" kmsapi "github.com/ceph/ceph-csi/internal/kms"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestGetPassphraseFromKMS(t *testing.T) { func TestGetPassphraseFromKMS(t *testing.T) {
@ -31,7 +31,7 @@ func TestGetPassphraseFromKMS(t *testing.T) {
continue continue
} }
kms := kmsapi.GetKMSTestDummy(provider.UniqueID) kms := kmsapi.GetKMSTestDummy(provider.UniqueID)
assert.NotNil(t, kms) require.NotNil(t, kms)
volEnc, err := NewVolumeEncryption(provider.UniqueID, kms) volEnc, err := NewVolumeEncryption(provider.UniqueID, kms)
if errors.Is(err, ErrDEKStoreNeeded) { if errors.Is(err, ErrDEKStoreNeeded) {
@ -40,14 +40,14 @@ func TestGetPassphraseFromKMS(t *testing.T) {
continue // currently unsupported by fscrypt integration continue // currently unsupported by fscrypt integration
} }
} }
assert.NotNil(t, volEnc) require.NotNil(t, volEnc)
if kms.RequiresDEKStore() == kmsapi.DEKStoreIntegrated { if kms.RequiresDEKStore() == kmsapi.DEKStoreIntegrated {
continue continue
} }
secret, err := kms.GetSecret(context.TODO(), "") secret, err := kms.GetSecret(context.TODO(), "")
assert.NoError(t, err, provider.UniqueID) require.NoError(t, err, provider.UniqueID)
assert.NotEmpty(t, secret, provider.UniqueID) require.NotEmpty(t, secret, provider.UniqueID)
} }
} }

View File

@ -70,7 +70,7 @@ func getCgroupPidsFile() (string, error) {
} }
} }
if slice == "" { if slice == "" {
return "", fmt.Errorf("could not find a cgroup for 'pids'") return "", errors.New("could not find a cgroup for 'pids'")
} }
return pidsMax, nil return pidsMax, nil
@ -112,7 +112,7 @@ func GetPIDLimit() (int, error) {
func SetPIDLimit(limit int) error { func SetPIDLimit(limit int) error {
limitStr := "max" limitStr := "max"
if limit != -1 { if limit != -1 {
limitStr = fmt.Sprintf("%d", limit) limitStr = strconv.Itoa(limit)
} }
pidsMax, err := getCgroupPidsFile() pidsMax, err := getCgroupPidsFile()

View File

@ -19,7 +19,7 @@ package util
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestReadAffinity_ConstructReadAffinityMapOption(t *testing.T) { func TestReadAffinity_ConstructReadAffinityMapOption(t *testing.T) {
@ -62,7 +62,7 @@ func TestReadAffinity_ConstructReadAffinityMapOption(t *testing.T) {
currentTT := tt currentTT := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
assert.Contains(t, currentTT.wantAny, ConstructReadAffinityMapOption(currentTT.crushLocationmap)) require.Contains(t, currentTT.wantAny, ConstructReadAffinityMapOption(currentTT.crushLocationmap))
}) })
} }
} }

View File

@ -70,7 +70,6 @@ func TryRADOSAborted(opErr error) error {
return opErr return opErr
} }
//nolint:errorlint // Can't use errors.As() because rados.radosError is private.
errnoErr, ok := radosOpErr.OpError.(interface{ ErrorCode() int }) errnoErr, ok := radosOpErr.OpError.(interface{ ErrorCode() int })
if !ok { if !ok {
return opErr return opErr

View File

@ -22,7 +22,7 @@ import (
"github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper" "github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper"
"github.com/ceph/ceph-csi/internal/util/reftracker/reftype" "github.com/ceph/ceph-csi/internal/util/reftracker/reftype"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
const rtName = "hello-rt" const rtName = "hello-rt"
@ -36,8 +36,8 @@ func TestRTAdd(t *testing.T) {
ioctx := radoswrapper.NewFakeIOContext(radoswrapper.NewFakeRados()) ioctx := radoswrapper.NewFakeIOContext(radoswrapper.NewFakeRados())
created, err := Add(ioctx, "", nil) created, err := Add(ioctx, "", nil)
assert.Error(ts, err) require.Error(ts, err)
assert.False(ts, created) require.False(ts, created)
}) })
// Verify input validation for nil and empty refs. // Verify input validation for nil and empty refs.
@ -51,8 +51,8 @@ func TestRTAdd(t *testing.T) {
} }
for _, ref := range refs { for _, ref := range refs {
created, err := Add(ioctx, rtName, ref) created, err := Add(ioctx, rtName, ref)
assert.Error(ts, err) require.Error(ts, err)
assert.False(ts, created) require.False(ts, created)
} }
}) })
@ -66,8 +66,8 @@ func TestRTAdd(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
}) })
// Add refs where each Add() has some of the refs overlapping // Add refs where each Add() has some of the refs overlapping
@ -80,8 +80,8 @@ func TestRTAdd(t *testing.T) {
"ref1": {}, "ref1": {},
"ref2": {}, "ref2": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
refsTable := []map[string]struct{}{ refsTable := []map[string]struct{}{
{"ref2": {}, "ref3": {}}, {"ref2": {}, "ref3": {}},
@ -90,8 +90,8 @@ func TestRTAdd(t *testing.T) {
} }
for _, refs := range refsTable { for _, refs := range refsTable {
created, err = Add(ioctx, rtName, refs) created, err = Add(ioctx, rtName, refs)
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, created) require.False(ts, created)
} }
}) })
} }
@ -110,8 +110,8 @@ func TestRTRemove(t *testing.T) {
} }
for _, ref := range refs { for _, ref := range refs {
created, err := Remove(ioctx, rtName, ref) created, err := Remove(ioctx, rtName, ref)
assert.Error(ts, err) require.Error(ts, err)
assert.False(ts, created) require.False(ts, created)
} }
}) })
@ -124,8 +124,8 @@ func TestRTRemove(t *testing.T) {
deleted, err := Remove(ioctx, "xxx", map[string]reftype.RefType{ deleted, err := Remove(ioctx, "xxx", map[string]reftype.RefType{
"ref1": reftype.Normal, "ref1": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Removing only non-existent refs should not result in reftracker object // Removing only non-existent refs should not result in reftracker object
@ -140,16 +140,16 @@ func TestRTRemove(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"refX": reftype.Normal, "refX": reftype.Normal,
"refY": reftype.Normal, "refY": reftype.Normal,
"refZ": reftype.Normal, "refZ": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
}) })
// Removing all refs plus some surplus should result in reftracker object // Removing all refs plus some surplus should result in reftracker object
@ -162,8 +162,8 @@ func TestRTRemove(t *testing.T) {
created, err := Add(ioctx, rtName, map[string]struct{}{ created, err := Add(ioctx, rtName, map[string]struct{}{
"ref": {}, "ref": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"refX": reftype.Normal, "refX": reftype.Normal,
@ -171,8 +171,8 @@ func TestRTRemove(t *testing.T) {
"ref": reftype.Normal, "ref": reftype.Normal,
"refZ": reftype.Normal, "refZ": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Bulk removal of all refs should result in reftracker object deletion. // Bulk removal of all refs should result in reftracker object deletion.
@ -189,12 +189,12 @@ func TestRTRemove(t *testing.T) {
} }
created, err := Add(ioctx, rtName, refsToAdd) created, err := Add(ioctx, rtName, refsToAdd)
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, refsToRemove) deleted, err := Remove(ioctx, rtName, refsToRemove)
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Removal of all refs one-by-one should result in reftracker object deletion // Removal of all refs one-by-one should result in reftracker object deletion
@ -209,23 +209,23 @@ func TestRTRemove(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
for _, k := range []string{"ref3", "ref2"} { for _, k := range []string{"ref3", "ref2"} {
deleted, errRemove := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, errRemove := Remove(ioctx, rtName, map[string]reftype.RefType{
k: reftype.Normal, k: reftype.Normal,
}) })
assert.NoError(ts, errRemove) require.NoError(ts, errRemove)
assert.False(ts, deleted) require.False(ts, deleted)
} }
// Remove the last reference. It should remove the whole reftracker object too. // Remove the last reference. It should remove the whole reftracker object too.
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal, "ref1": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Cycle through reftracker object twice. // Cycle through reftracker object twice.
@ -246,12 +246,12 @@ func TestRTRemove(t *testing.T) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
created, err := Add(ioctx, rtName, refsToAdd) created, err := Add(ioctx, rtName, refsToAdd)
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, refsToRemove) deleted, err := Remove(ioctx, rtName, refsToRemove)
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
} }
}) })
@ -265,8 +265,8 @@ func TestRTRemove(t *testing.T) {
"ref1": {}, "ref1": {},
"ref2": {}, "ref2": {},
}) })
assert.True(ts, created) require.True(ts, created)
assert.NoError(ts, err) require.NoError(ts, err)
refsTable := []map[string]struct{}{ refsTable := []map[string]struct{}{
{"ref2": {}, "ref3": {}}, {"ref2": {}, "ref3": {}},
{"ref3": {}, "ref4": {}}, {"ref3": {}, "ref4": {}},
@ -274,8 +274,8 @@ func TestRTRemove(t *testing.T) {
} }
for _, refs := range refsTable { for _, refs := range refsTable {
created, err = Add(ioctx, rtName, refs) created, err = Add(ioctx, rtName, refs)
assert.False(ts, created) require.False(ts, created)
assert.NoError(ts, err) require.NoError(ts, err)
} }
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
@ -285,8 +285,8 @@ func TestRTRemove(t *testing.T) {
"ref4": reftype.Normal, "ref4": reftype.Normal,
"ref5": reftype.Normal, "ref5": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
} }
@ -307,12 +307,12 @@ func TestRTMask(t *testing.T) {
} }
created, err := Add(ioctx, rtName, refsToAdd) created, err := Add(ioctx, rtName, refsToAdd)
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, refsToRemove) deleted, err := Remove(ioctx, rtName, refsToRemove)
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Masking all refs one-by-one should result in reftracker object deletion in // Masking all refs one-by-one should result in reftracker object deletion in
@ -327,15 +327,15 @@ func TestRTMask(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
for _, k := range []string{"ref3", "ref2"} { for _, k := range []string{"ref3", "ref2"} {
deleted, errRemove := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, errRemove := Remove(ioctx, rtName, map[string]reftype.RefType{
k: reftype.Mask, k: reftype.Mask,
}) })
assert.NoError(ts, errRemove) require.NoError(ts, errRemove)
assert.False(ts, deleted) require.False(ts, deleted)
} }
// Remove the last reference. It should delete the whole reftracker object // Remove the last reference. It should delete the whole reftracker object
@ -343,8 +343,8 @@ func TestRTMask(t *testing.T) {
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask, "ref1": reftype.Mask,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Bulk removing two (out of 3) refs and then masking the ref that's left // Bulk removing two (out of 3) refs and then masking the ref that's left
@ -359,21 +359,21 @@ func TestRTMask(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal, "ref1": reftype.Normal,
"ref2": reftype.Normal, "ref2": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Mask, "ref3": reftype.Mask,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Bulk masking two (out of 3) refs and then removing the ref that's left // Bulk masking two (out of 3) refs and then removing the ref that's left
@ -388,21 +388,21 @@ func TestRTMask(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask, "ref1": reftype.Mask,
"ref2": reftype.Mask, "ref2": reftype.Mask,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Normal, "ref3": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Verify that masking refs hides them from future Add()s. // Verify that masking refs hides them from future Add()s.
@ -416,28 +416,28 @@ func TestRTMask(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask, "ref1": reftype.Mask,
"ref2": reftype.Mask, "ref2": reftype.Mask,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
created, err = Add(ioctx, rtName, map[string]struct{}{ created, err = Add(ioctx, rtName, map[string]struct{}{
"ref1": {}, "ref1": {},
"ref2": {}, "ref2": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, created) require.False(ts, created)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Normal, "ref3": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
// Verify that masked refs may be removed with reftype.Normal and re-added. // Verify that masked refs may be removed with reftype.Normal and re-added.
@ -451,41 +451,41 @@ func TestRTMask(t *testing.T) {
"ref2": {}, "ref2": {},
"ref3": {}, "ref3": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, created) require.True(ts, created)
deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err := Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Mask, "ref1": reftype.Mask,
"ref2": reftype.Mask, "ref2": reftype.Mask,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal, "ref1": reftype.Normal,
"ref2": reftype.Normal, "ref2": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
created, err = Add(ioctx, rtName, map[string]struct{}{ created, err = Add(ioctx, rtName, map[string]struct{}{
"ref1": {}, "ref1": {},
"ref2": {}, "ref2": {},
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, created) require.False(ts, created)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref3": reftype.Normal, "ref3": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.False(ts, deleted) require.False(ts, deleted)
deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{ deleted, err = Remove(ioctx, rtName, map[string]reftype.RefType{
"ref1": reftype.Normal, "ref1": reftype.Normal,
"ref2": reftype.Normal, "ref2": reftype.Normal,
}) })
assert.NoError(ts, err) require.NoError(ts, err)
assert.True(ts, deleted) require.True(ts, deleted)
}) })
} }

View File

@ -19,7 +19,7 @@ package reftype
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestRefTypeBytes(t *testing.T) { func TestRefTypeBytes(t *testing.T) {
@ -41,7 +41,7 @@ func TestRefTypeBytes(t *testing.T) {
for i := range expectedBytes { for i := range expectedBytes {
bs := ToBytes(refTypes[i]) bs := ToBytes(refTypes[i])
assert.Equal(ts, expectedBytes[i], bs) require.Equal(ts, expectedBytes[i], bs)
} }
}) })
@ -50,14 +50,14 @@ func TestRefTypeBytes(t *testing.T) {
for i := range refTypes { for i := range refTypes {
refType, err := FromBytes(expectedBytes[i]) refType, err := FromBytes(expectedBytes[i])
assert.NoError(ts, err) require.NoError(ts, err)
assert.Equal(ts, refTypes[i], refType) require.Equal(ts, refTypes[i], refType)
} }
_, err := FromBytes(refTypeInvalidBytes) _, err := FromBytes(refTypeInvalidBytes)
assert.Error(ts, err) require.Error(ts, err)
_, err = FromBytes(refTypeWrongSizeBytes) _, err = FromBytes(refTypeWrongSizeBytes)
assert.Error(ts, err) require.Error(ts, err)
}) })
} }

View File

@ -19,7 +19,7 @@ package v1
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestV1RefCountBytes(t *testing.T) { func TestV1RefCountBytes(t *testing.T) {
@ -35,17 +35,17 @@ func TestV1RefCountBytes(t *testing.T) {
ts.Parallel() ts.Parallel()
bs := refCountValue.toBytes() bs := refCountValue.toBytes()
assert.Equal(ts, refCountBytes, bs) require.Equal(ts, refCountBytes, bs)
}) })
t.Run("FromBytes", func(ts *testing.T) { t.Run("FromBytes", func(ts *testing.T) {
ts.Parallel() ts.Parallel()
rc, err := refCountFromBytes(refCountBytes) rc, err := refCountFromBytes(refCountBytes)
assert.NoError(ts, err) require.NoError(ts, err)
assert.Equal(ts, refCountValue, rc) require.Equal(ts, refCountValue, rc)
_, err = refCountFromBytes(wrongSizeRefCountBytes) _, err = refCountFromBytes(wrongSizeRefCountBytes)
assert.Error(ts, err) require.Error(ts, err)
}) })
} }

View File

@ -205,7 +205,7 @@ func Remove(
if rcToSubtract > readRes.total { if rcToSubtract > readRes.total {
// BUG: this should never happen! // BUG: this should never happen!
return false, fmt.Errorf("refcount underflow, reftracker object corrupted") return false, goerrors.New("refcount underflow, reftracker object corrupted")
} }
newRC := readRes.total - rcToSubtract newRC := readRes.total - rcToSubtract

View File

@ -17,14 +17,13 @@ limitations under the License.
package v1 package v1
import ( import (
goerrors "errors"
"testing" "testing"
"github.com/ceph/ceph-csi/internal/util/reftracker/errors" "github.com/ceph/ceph-csi/internal/util/reftracker/errors"
"github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper" "github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper"
"github.com/ceph/ceph-csi/internal/util/reftracker/reftype" "github.com/ceph/ceph-csi/internal/util/reftracker/reftype"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestV1Read(t *testing.T) { func TestV1Read(t *testing.T) {
@ -73,17 +72,17 @@ func TestV1Read(t *testing.T) {
) )
err := Add(validObj, rtName, gen, refsToAdd) err := Add(validObj, rtName, gen, refsToAdd)
assert.NoError(t, err) require.NoError(t, err)
for i := range invalidObjs { for i := range invalidObjs {
err = Add(invalidObjs[i], rtName, gen, refsToAdd) err = Add(invalidObjs[i], rtName, gen, refsToAdd)
assert.Error(t, err) require.Error(t, err)
} }
// Check for correct error type for wrong gen num. // Check for correct error type for wrong gen num.
err = Add(invalidObjs[1], rtName, gen, refsToAdd) err = Add(invalidObjs[1], rtName, gen, refsToAdd)
assert.Error(t, err) require.Error(t, err)
assert.True(t, goerrors.Is(err, errors.ErrObjectOutOfDate)) require.ErrorIs(t, err, errors.ErrObjectOutOfDate)
} }
func TestV1Init(t *testing.T) { func TestV1Init(t *testing.T) {
@ -106,10 +105,10 @@ func TestV1Init(t *testing.T) {
) )
err := Init(emptyRados, rtName, refsToInit) err := Init(emptyRados, rtName, refsToInit)
assert.NoError(t, err) require.NoError(t, err)
err = Init(alreadyExists, rtName, refsToInit) err = Init(alreadyExists, rtName, refsToInit)
assert.Error(t, err) require.Error(t, err)
} }
func TestV1Add(t *testing.T) { func TestV1Add(t *testing.T) {
@ -224,19 +223,19 @@ func TestV1Add(t *testing.T) {
ioctx.Rados.Objs[rtName] = shouldSucceed[i].before ioctx.Rados.Objs[rtName] = shouldSucceed[i].before
err := Add(ioctx, rtName, 0, shouldSucceed[i].refsToAdd) err := Add(ioctx, rtName, 0, shouldSucceed[i].refsToAdd)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName]) require.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName])
} }
for i := range shouldFail { for i := range shouldFail {
err := Add(shouldFail[i], rtName, 0, map[string]struct{}{"ref1": {}}) err := Add(shouldFail[i], rtName, 0, map[string]struct{}{"ref1": {}})
assert.Error(t, err) require.Error(t, err)
} }
// Check for correct error type for wrong gen num. // Check for correct error type for wrong gen num.
err := Add(shouldFail[1], rtName, 0, map[string]struct{}{"ref1": {}}) err := Add(shouldFail[1], rtName, 0, map[string]struct{}{"ref1": {}})
assert.Error(t, err) require.Error(t, err)
assert.True(t, goerrors.Is(err, errors.ErrObjectOutOfDate)) require.ErrorIs(t, err, errors.ErrObjectOutOfDate)
} }
func TestV1Remove(t *testing.T) { func TestV1Remove(t *testing.T) {
@ -412,12 +411,12 @@ func TestV1Remove(t *testing.T) {
ioctx.Rados.Objs[rtName] = shouldSucceed[i].before ioctx.Rados.Objs[rtName] = shouldSucceed[i].before
deleted, err := Remove(ioctx, rtName, 0, shouldSucceed[i].refsToRemove) deleted, err := Remove(ioctx, rtName, 0, shouldSucceed[i].refsToRemove)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, shouldSucceed[i].deleted, deleted) require.Equal(t, shouldSucceed[i].deleted, deleted)
assert.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName]) require.Equal(t, shouldSucceed[i].after, ioctx.Rados.Objs[rtName])
} }
_, err := Remove(badGen, rtName, 0, map[string]reftype.RefType{"ref": reftype.Normal}) _, err := Remove(badGen, rtName, 0, map[string]reftype.RefType{"ref": reftype.Normal})
assert.Error(t, err) require.Error(t, err)
assert.True(t, goerrors.Is(err, errors.ErrObjectOutOfDate)) require.ErrorIs(t, err, errors.ErrObjectOutOfDate)
} }

View File

@ -21,7 +21,7 @@ import (
"github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper" "github.com/ceph/ceph-csi/internal/util/reftracker/radoswrapper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
var ( var (
@ -38,18 +38,18 @@ func TestVersionBytes(t *testing.T) {
ts.Parallel() ts.Parallel()
bs := ToBytes(v1Value) bs := ToBytes(v1Value)
assert.Equal(ts, v1Bytes, bs) require.Equal(ts, v1Bytes, bs)
}) })
t.Run("FromBytes", func(ts *testing.T) { t.Run("FromBytes", func(ts *testing.T) {
ts.Parallel() ts.Parallel()
ver, err := FromBytes(v1Bytes) ver, err := FromBytes(v1Bytes)
assert.NoError(ts, err) require.NoError(ts, err)
assert.Equal(ts, v1Value, ver) require.Equal(ts, v1Value, ver)
_, err = FromBytes(wrongSizeVersionBytes) _, err = FromBytes(wrongSizeVersionBytes)
assert.Error(ts, err) require.Error(ts, err)
}) })
} }
@ -101,11 +101,11 @@ func TestVersionRead(t *testing.T) {
) )
ver, err := Read(validObj, rtName) ver, err := Read(validObj, rtName)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, v1Value, ver) require.Equal(t, v1Value, ver)
for i := range invalidObjs { for i := range invalidObjs {
_, err = Read(invalidObjs[i], rtName) _, err = Read(invalidObjs[i], rtName)
assert.Error(t, err) require.Error(t, err)
} }
} }

View File

@ -212,7 +212,7 @@ func parseKernelRelease(release string) (int, int, int, int, error) {
extraversion := 0 extraversion := 0
if n > minVersions { if n > minVersions {
n, err = fmt.Sscanf(extra, ".%d%s", &sublevel, &extra) n, err = fmt.Sscanf(extra, ".%d%s", &sublevel, &extra)
if err != nil && n == 0 && len(extra) > 0 && extra[0] != '-' && extra[0] == '.' { if err != nil && n == 0 && extra != "" && extra[0] != '-' && extra[0] == '.' {
return 0, 0, 0, 0, fmt.Errorf("failed to parse subversion from %s: %w", release, err) return 0, 0, 0, 0, fmt.Errorf("failed to parse subversion from %s: %w", release, err)
} }

View File

@ -87,7 +87,7 @@ func ValidateNodeUnpublishVolumeRequest(req *csi.NodeUnpublishVolumeRequest) err
// volume is from source as empty ReadOnlyMany is not supported. // volume is from source as empty ReadOnlyMany is not supported.
func CheckReadOnlyManyIsSupported(req *csi.CreateVolumeRequest) error { func CheckReadOnlyManyIsSupported(req *csi.CreateVolumeRequest) error {
for _, capability := range req.GetVolumeCapabilities() { for _, capability := range req.GetVolumeCapabilities() {
if m := capability.GetAccessMode().Mode; m == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY || if m := capability.GetAccessMode().GetMode(); m == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY ||
m == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY { m == csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY {
if req.GetVolumeContentSource() == nil { if req.GetVolumeContentSource() == nil {
return status.Error(codes.InvalidArgument, "readOnly accessMode is supported only with content source") return status.Error(codes.InvalidArgument, "readOnly accessMode is supported only with content source")

View File

@ -1,5 +1,5 @@
--- ---
# https://github.com/golangci/golangci-lint/blob/master/.golangci.example.yml # https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
# This file contains all available configuration options # This file contains all available configuration options
# with their default values. # with their default values.
@ -12,7 +12,7 @@ run:
concurrency: 4 concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m # timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 10m timeout: 20m
# exit code when at least one issue was found, default is 1 # exit code when at least one issue was found, default is 1
issues-exit-code: 1 issues-exit-code: 1

View File

@ -33,6 +33,9 @@
#define CONSTBASE R16 #define CONSTBASE R16
#define BLOCKS R17 #define BLOCKS R17
// for VPERMXOR
#define MASK R18
DATA consts<>+0x00(SB)/8, $0x3320646e61707865 DATA consts<>+0x00(SB)/8, $0x3320646e61707865
DATA consts<>+0x08(SB)/8, $0x6b20657479622d32 DATA consts<>+0x08(SB)/8, $0x6b20657479622d32
DATA consts<>+0x10(SB)/8, $0x0000000000000001 DATA consts<>+0x10(SB)/8, $0x0000000000000001
@ -53,7 +56,11 @@ DATA consts<>+0x80(SB)/8, $0x6b2065746b206574
DATA consts<>+0x88(SB)/8, $0x6b2065746b206574 DATA consts<>+0x88(SB)/8, $0x6b2065746b206574
DATA consts<>+0x90(SB)/8, $0x0000000100000000 DATA consts<>+0x90(SB)/8, $0x0000000100000000
DATA consts<>+0x98(SB)/8, $0x0000000300000002 DATA consts<>+0x98(SB)/8, $0x0000000300000002
GLOBL consts<>(SB), RODATA, $0xa0 DATA consts<>+0xa0(SB)/8, $0x5566774411223300
DATA consts<>+0xa8(SB)/8, $0xddeeffcc99aabb88
DATA consts<>+0xb0(SB)/8, $0x6677445522330011
DATA consts<>+0xb8(SB)/8, $0xeeffccddaabb8899
GLOBL consts<>(SB), RODATA, $0xc0
//func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32) //func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32)
TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40 TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
@ -70,6 +77,9 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD $48, R10 MOVD $48, R10
MOVD $64, R11 MOVD $64, R11
SRD $6, LEN, BLOCKS SRD $6, LEN, BLOCKS
// for VPERMXOR
MOVD $consts<>+0xa0(SB), MASK
MOVD $16, R20
// V16 // V16
LXVW4X (CONSTBASE)(R0), VS48 LXVW4X (CONSTBASE)(R0), VS48
ADD $80,CONSTBASE ADD $80,CONSTBASE
@ -87,6 +97,10 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
// V28 // V28
LXVW4X (CONSTBASE)(R11), VS60 LXVW4X (CONSTBASE)(R11), VS60
// Load mask constants for VPERMXOR
LXVW4X (MASK)(R0), V20
LXVW4X (MASK)(R20), V21
// splat slot from V19 -> V26 // splat slot from V19 -> V26
VSPLTW $0, V19, V26 VSPLTW $0, V19, V26
@ -97,7 +111,7 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD $10, R14 MOVD $10, R14
MOVD R14, CTR MOVD R14, CTR
PCALIGN $16
loop_outer_vsx: loop_outer_vsx:
// V0, V1, V2, V3 // V0, V1, V2, V3
LXVW4X (R0)(CONSTBASE), VS32 LXVW4X (R0)(CONSTBASE), VS32
@ -128,22 +142,17 @@ loop_outer_vsx:
VSPLTISW $12, V28 VSPLTISW $12, V28
VSPLTISW $8, V29 VSPLTISW $8, V29
VSPLTISW $7, V30 VSPLTISW $7, V30
PCALIGN $16
loop_vsx: loop_vsx:
VADDUWM V0, V4, V0 VADDUWM V0, V4, V0
VADDUWM V1, V5, V1 VADDUWM V1, V5, V1
VADDUWM V2, V6, V2 VADDUWM V2, V6, V2
VADDUWM V3, V7, V3 VADDUWM V3, V7, V3
VXOR V12, V0, V12 VPERMXOR V12, V0, V21, V12
VXOR V13, V1, V13 VPERMXOR V13, V1, V21, V13
VXOR V14, V2, V14 VPERMXOR V14, V2, V21, V14
VXOR V15, V3, V15 VPERMXOR V15, V3, V21, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VRLW V15, V27, V15
VADDUWM V8, V12, V8 VADDUWM V8, V12, V8
VADDUWM V9, V13, V9 VADDUWM V9, V13, V9
@ -165,15 +174,10 @@ loop_vsx:
VADDUWM V2, V6, V2 VADDUWM V2, V6, V2
VADDUWM V3, V7, V3 VADDUWM V3, V7, V3
VXOR V12, V0, V12 VPERMXOR V12, V0, V20, V12
VXOR V13, V1, V13 VPERMXOR V13, V1, V20, V13
VXOR V14, V2, V14 VPERMXOR V14, V2, V20, V14
VXOR V15, V3, V15 VPERMXOR V15, V3, V20, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VRLW V15, V29, V15
VADDUWM V8, V12, V8 VADDUWM V8, V12, V8
VADDUWM V9, V13, V9 VADDUWM V9, V13, V9
@ -195,15 +199,10 @@ loop_vsx:
VADDUWM V2, V7, V2 VADDUWM V2, V7, V2
VADDUWM V3, V4, V3 VADDUWM V3, V4, V3
VXOR V15, V0, V15 VPERMXOR V15, V0, V21, V15
VXOR V12, V1, V12 VPERMXOR V12, V1, V21, V12
VXOR V13, V2, V13 VPERMXOR V13, V2, V21, V13
VXOR V14, V3, V14 VPERMXOR V14, V3, V21, V14
VRLW V15, V27, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VADDUWM V10, V15, V10 VADDUWM V10, V15, V10
VADDUWM V11, V12, V11 VADDUWM V11, V12, V11
@ -225,15 +224,10 @@ loop_vsx:
VADDUWM V2, V7, V2 VADDUWM V2, V7, V2
VADDUWM V3, V4, V3 VADDUWM V3, V4, V3
VXOR V15, V0, V15 VPERMXOR V15, V0, V20, V15
VXOR V12, V1, V12 VPERMXOR V12, V1, V20, V12
VXOR V13, V2, V13 VPERMXOR V13, V2, V20, V13
VXOR V14, V3, V14 VPERMXOR V14, V3, V20, V14
VRLW V15, V29, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VADDUWM V10, V15, V10 VADDUWM V10, V15, V10
VADDUWM V11, V12, V11 VADDUWM V11, V12, V11
@ -249,48 +243,48 @@ loop_vsx:
VRLW V6, V30, V6 VRLW V6, V30, V6
VRLW V7, V30, V7 VRLW V7, V30, V7
VRLW V4, V30, V4 VRLW V4, V30, V4
BC 16, LT, loop_vsx BDNZ loop_vsx
VADDUWM V12, V26, V12 VADDUWM V12, V26, V12
WORD $0x13600F8C // VMRGEW V0, V1, V27 VMRGEW V0, V1, V27
WORD $0x13821F8C // VMRGEW V2, V3, V28 VMRGEW V2, V3, V28
WORD $0x10000E8C // VMRGOW V0, V1, V0 VMRGOW V0, V1, V0
WORD $0x10421E8C // VMRGOW V2, V3, V2 VMRGOW V2, V3, V2
WORD $0x13A42F8C // VMRGEW V4, V5, V29 VMRGEW V4, V5, V29
WORD $0x13C63F8C // VMRGEW V6, V7, V30 VMRGEW V6, V7, V30
XXPERMDI VS32, VS34, $0, VS33 XXPERMDI VS32, VS34, $0, VS33
XXPERMDI VS32, VS34, $3, VS35 XXPERMDI VS32, VS34, $3, VS35
XXPERMDI VS59, VS60, $0, VS32 XXPERMDI VS59, VS60, $0, VS32
XXPERMDI VS59, VS60, $3, VS34 XXPERMDI VS59, VS60, $3, VS34
WORD $0x10842E8C // VMRGOW V4, V5, V4 VMRGOW V4, V5, V4
WORD $0x10C63E8C // VMRGOW V6, V7, V6 VMRGOW V6, V7, V6
WORD $0x13684F8C // VMRGEW V8, V9, V27 VMRGEW V8, V9, V27
WORD $0x138A5F8C // VMRGEW V10, V11, V28 VMRGEW V10, V11, V28
XXPERMDI VS36, VS38, $0, VS37 XXPERMDI VS36, VS38, $0, VS37
XXPERMDI VS36, VS38, $3, VS39 XXPERMDI VS36, VS38, $3, VS39
XXPERMDI VS61, VS62, $0, VS36 XXPERMDI VS61, VS62, $0, VS36
XXPERMDI VS61, VS62, $3, VS38 XXPERMDI VS61, VS62, $3, VS38
WORD $0x11084E8C // VMRGOW V8, V9, V8 VMRGOW V8, V9, V8
WORD $0x114A5E8C // VMRGOW V10, V11, V10 VMRGOW V10, V11, V10
WORD $0x13AC6F8C // VMRGEW V12, V13, V29 VMRGEW V12, V13, V29
WORD $0x13CE7F8C // VMRGEW V14, V15, V30 VMRGEW V14, V15, V30
XXPERMDI VS40, VS42, $0, VS41 XXPERMDI VS40, VS42, $0, VS41
XXPERMDI VS40, VS42, $3, VS43 XXPERMDI VS40, VS42, $3, VS43
XXPERMDI VS59, VS60, $0, VS40 XXPERMDI VS59, VS60, $0, VS40
XXPERMDI VS59, VS60, $3, VS42 XXPERMDI VS59, VS60, $3, VS42
WORD $0x118C6E8C // VMRGOW V12, V13, V12 VMRGOW V12, V13, V12
WORD $0x11CE7E8C // VMRGOW V14, V15, V14 VMRGOW V14, V15, V14
VSPLTISW $4, V27 VSPLTISW $4, V27
VADDUWM V26, V27, V26 VADDUWM V26, V27, V26
@ -431,7 +425,7 @@ tail_vsx:
ADD $-1, R11, R12 ADD $-1, R11, R12
ADD $-1, INP ADD $-1, INP
ADD $-1, OUT ADD $-1, OUT
PCALIGN $16
looptail_vsx: looptail_vsx:
// Copying the result to OUT // Copying the result to OUT
// in bytes. // in bytes.
@ -439,7 +433,7 @@ looptail_vsx:
MOVBZU 1(INP), TMP MOVBZU 1(INP), TMP
XOR KEY, TMP, KEY XOR KEY, TMP, KEY
MOVBU KEY, 1(OUT) MOVBU KEY, 1(OUT)
BC 16, LT, looptail_vsx BDNZ looptail_vsx
// Clear the stack values // Clear the stack values
STXVW4X VS48, (R11)(R0) STXVW4X VS48, (R11)(R0)

View File

@ -426,6 +426,35 @@ func (l ServerAuthError) Error() string {
return "[" + strings.Join(errs, ", ") + "]" return "[" + strings.Join(errs, ", ") + "]"
} }
// ServerAuthCallbacks defines server-side authentication callbacks.
type ServerAuthCallbacks struct {
// PasswordCallback behaves like [ServerConfig.PasswordCallback].
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
GSSAPIWithMICConfig *GSSAPIWithMICConfig
}
// PartialSuccessError can be returned by any of the [ServerConfig]
// authentication callbacks to indicate to the client that authentication has
// partially succeeded, but further steps are required.
type PartialSuccessError struct {
// Next defines the authentication callbacks to apply to further steps. The
// available methods communicated to the client are based on the non-nil
// ServerAuthCallbacks fields.
Next ServerAuthCallbacks
}
func (p *PartialSuccessError) Error() string {
return "ssh: authenticated with partial success"
}
// ErrNoAuth is the error value returned if no // ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal // authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries // part of the authentication loop, since the client first tries
@ -439,8 +468,18 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
var perms *Permissions var perms *Permissions
authFailures := 0 authFailures := 0
noneAuthCount := 0
var authErrs []error var authErrs []error
var displayedBanner bool var displayedBanner bool
partialSuccessReturned := false
// Set the initial authentication callbacks from the config. They can be
// changed if a PartialSuccessError is returned.
authConfig := ServerAuthCallbacks{
PasswordCallback: config.PasswordCallback,
PublicKeyCallback: config.PublicKeyCallback,
KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
GSSAPIWithMICConfig: config.GSSAPIWithMICConfig,
}
userAuthLoop: userAuthLoop:
for { for {
@ -471,6 +510,11 @@ userAuthLoop:
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
} }
if s.user != userAuthReq.User && partialSuccessReturned {
return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
s.user, userAuthReq.User)
}
s.user = userAuthReq.User s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil { if !displayedBanner && config.BannerCallback != nil {
@ -491,20 +535,18 @@ userAuthLoop:
switch userAuthReq.Method { switch userAuthReq.Method {
case "none": case "none":
if config.NoClientAuth { noneAuthCount++
// We don't allow none authentication after a partial success
// response.
if config.NoClientAuth && !partialSuccessReturned {
if config.NoClientAuthCallback != nil { if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s) perms, authErr = config.NoClientAuthCallback(s)
} else { } else {
authErr = nil authErr = nil
} }
} }
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password": case "password":
if config.PasswordCallback == nil { if authConfig.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured") authErr = errors.New("ssh: password auth not configured")
break break
} }
@ -518,17 +560,17 @@ userAuthLoop:
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
perms, authErr = config.PasswordCallback(s, password) perms, authErr = authConfig.PasswordCallback(s, password)
case "keyboard-interactive": case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil { if authConfig.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configured") authErr = errors.New("ssh: keyboard-interactive auth not configured")
break break
} }
prompter := &sshClientKeyboardInteractive{s} prompter := &sshClientKeyboardInteractive{s}
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
case "publickey": case "publickey":
if config.PublicKeyCallback == nil { if authConfig.PublicKeyCallback == nil {
authErr = errors.New("ssh: publickey auth not configured") authErr = errors.New("ssh: publickey auth not configured")
break break
} }
@ -562,11 +604,18 @@ userAuthLoop:
if !ok { if !ok {
candidate.user = s.user candidate.user = s.user
candidate.pubKeyData = pubKeyData candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { _, isPartialSuccessError := candidate.result.(*PartialSuccessError)
candidate.result = checkSourceAddress(
if (candidate.result == nil || isPartialSuccessError) &&
candidate.perms != nil &&
candidate.perms.CriticalOptions != nil &&
candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
if err := checkSourceAddress(
s.RemoteAddr(), s.RemoteAddr(),
candidate.perms.CriticalOptions[sourceAddressCriticalOption]) candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
candidate.result = err
}
} }
cache.add(candidate) cache.add(candidate)
} }
@ -578,8 +627,8 @@ userAuthLoop:
if len(payload) > 0 { if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if candidate.result == nil { if candidate.result == nil || isPartialSuccessError {
okMsg := userAuthPubKeyOkMsg{ okMsg := userAuthPubKeyOkMsg{
Algo: algo, Algo: algo,
PubKey: pubKeyData, PubKey: pubKeyData,
@ -629,11 +678,11 @@ userAuthLoop:
perms = candidate.perms perms = candidate.perms
} }
case "gssapi-with-mic": case "gssapi-with-mic":
if config.GSSAPIWithMICConfig == nil { if authConfig.GSSAPIWithMICConfig == nil {
authErr = errors.New("ssh: gssapi-with-mic auth not configured") authErr = errors.New("ssh: gssapi-with-mic auth not configured")
break break
} }
gssapiConfig := config.GSSAPIWithMICConfig gssapiConfig := authConfig.GSSAPIWithMICConfig
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload) userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
if err != nil { if err != nil {
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
@ -689,7 +738,28 @@ userAuthLoop:
break userAuthLoop break userAuthLoop
} }
var failureMsg userAuthFailureMsg
if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
// After a partial success error we don't allow changing the user
// name and execute the NoClientAuthCallback.
partialSuccessReturned = true
// In case a partial success is returned, the server may send
// a new set of authentication methods.
authConfig = partialSuccess.Next
// Reset pubkey cache, as the new PublicKeyCallback might
// accept a different set of public keys.
cache = pubKeyCache{}
// Send back a partial success message to the user.
failureMsg.PartialSuccess = true
} else {
// Allow initial attempt of 'none' without penalty.
if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 {
authFailures++ authFailures++
}
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries { if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
// If we have hit the max attempts, don't bother sending the // If we have hit the max attempts, don't bother sending the
// final SSH_MSG_USERAUTH_FAILURE message, since there are // final SSH_MSG_USERAUTH_FAILURE message, since there are
@ -709,29 +779,29 @@ userAuthLoop:
// disconnect, should we only send that message.) // disconnect, should we only send that message.)
// //
// Either way, OpenSSH disconnects immediately after the last // Either way, OpenSSH disconnects immediately after the last
// failed authnetication attempt, and given they are typically // failed authentication attempt, and given they are typically
// considered the golden implementation it seems reasonable // considered the golden implementation it seems reasonable
// to match that behavior. // to match that behavior.
continue continue
} }
}
var failureMsg userAuthFailureMsg if authConfig.PasswordCallback != nil {
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password") failureMsg.Methods = append(failureMsg.Methods, "password")
} }
if config.PublicKeyCallback != nil { if authConfig.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey") failureMsg.Methods = append(failureMsg.Methods, "publickey")
} }
if config.KeyboardInteractiveCallback != nil { if authConfig.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
} }
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil && if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
config.GSSAPIWithMICConfig.AllowLogin != nil { authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic") failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
} }
if len(failureMsg.Methods) == 0 { if len(failureMsg.Methods) == 0 {
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") return nil, errors.New("ssh: no authentication methods available")
} }
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil { if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {

View File

@ -1564,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if size > remainSize { if size > remainSize {
hdec.SetEmitEnabled(false) hdec.SetEmitEnabled(false)
mh.Truncated = true mh.Truncated = true
remainSize = 0
return return
} }
remainSize -= size remainSize -= size
@ -1576,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
var hc headersOrContinuation = hf var hc headersOrContinuation = hf
for { for {
frag := hc.HeaderBlockFragment() frag := hc.HeaderBlockFragment()
// Avoid parsing large amounts of headers that we will then discard.
// If the sender exceeds the max header list size by too much,
// skip parsing the fragment and close the connection.
//
// "Too much" is either any CONTINUATION frame after we've already
// exceeded the max header list size (in which case remainSize is 0),
// or a frame whose encoded size is more than twice the remaining
// header list bytes we're willing to accept.
if int64(len(frag)) > int64(2*remainSize) {
if VerboseLogs {
log.Printf("http2: header list too large")
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return nil, ConnectionError(ErrCodeProtocol)
}
// Also close the connection after any CONTINUATION frame following an
// invalid header, since we stop tracking the size of the headers after
// an invalid one.
if invalid != nil {
if VerboseLogs {
log.Printf("http2: invalid header: %v", invalid)
}
// It would be nice to send a RST_STREAM before sending the GOAWAY,
// but the structure of the server's frame writer makes this difficult.
return nil, ConnectionError(ErrCodeProtocol)
}
if _, err := hdec.Write(frag); err != nil { if _, err := hdec.Write(frag); err != nil {
return nil, ConnectionError(ErrCodeCompression) return nil, ConnectionError(ErrCodeCompression)
} }

View File

@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) {
} }
} }
var errClosedPipeWrite = errors.New("write on closed buffer") var (
errClosedPipeWrite = errors.New("write on closed buffer")
errUninitializedPipeWrite = errors.New("write on uninitialized buffer")
)
// Write copies bytes from p into the buffer and wakes a reader. // Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold. // It is an error to write more data than the buffer can hold.
@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) {
if p.err != nil || p.breakErr != nil { if p.err != nil || p.breakErr != nil {
return 0, errClosedPipeWrite return 0, errClosedPipeWrite
} }
// pipe.setBuffer is never invoked, leaving the buffer uninitialized.
// We shouldn't try to write to an uninitialized pipe,
// but returning an error is better than panicking.
if p.b == nil {
return 0, errUninitializedPipeWrite
}
return p.b.Write(d) return p.b.Write(d)
} }

View File

@ -124,6 +124,7 @@ type Server struct {
// IdleTimeout specifies how long until idle clients should be // IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered // closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout. // activity for the purposes of IdleTimeout.
// If zero or negative, there is no timeout.
IdleTimeout time.Duration IdleTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow // MaxUploadBufferPerConnection is the size of the initial flow
@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
// passes the connection off to us with the deadline already set. // passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream. // Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here. // Disarm the net.Conn write deadline here.
if sc.hs.WriteTimeout != 0 { if sc.hs.WriteTimeout > 0 {
sc.conn.SetWriteDeadline(time.Time{}) sc.conn.SetWriteDeadline(time.Time{})
} }
@ -924,7 +925,7 @@ func (sc *serverConn) serve() {
sc.setConnState(http.StateActive) sc.setConnState(http.StateActive)
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout != 0 { if sc.srv.IdleTimeout > 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop() defer sc.idleTimer.Stop()
} }
@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) {
delete(sc.streams, st.id) delete(sc.streams, st.id)
if len(sc.streams) == 0 { if len(sc.streams) == 0 {
sc.setConnState(http.StateIdle) sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout != 0 { if sc.srv.IdleTimeout > 0 {
sc.idleTimer.Reset(sc.srv.IdleTimeout) sc.idleTimer.Reset(sc.srv.IdleTimeout)
} }
if h1ServerKeepAlivesDisabled(sc.hs) { if h1ServerKeepAlivesDisabled(sc.hs) {
@ -2017,7 +2018,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// similar to how the http1 server works. Here it's // similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout // technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway. // (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
} }
@ -2038,7 +2039,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
// Disable any read deadline set by the net/http package // Disable any read deadline set by the net/http package
// prior to the upgrade. // prior to the upgrade.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
} }
@ -2116,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.conn = &sc.flow // link to conn-level counter st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize) st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize()) st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout != 0 { if sc.hs.WriteTimeout > 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
} }

331
vendor/golang.org/x/net/http2/testsync.go generated vendored Normal file
View File

@ -0,0 +1,331 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"context"
"sync"
"time"
)
// testSyncHooks coordinates goroutines in tests.
//
// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
// - the goroutine running RoundTrip;
// - the clientStream.doRequest goroutine, which writes the request; and
// - the clientStream.readLoop goroutine, which reads the response.
//
// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
// are blocked waiting for some condition such as reading the Request.Body or waiting for
// flow control to become available.
//
// The testSyncHooks also manage timers and synthetic time in tests.
// This permits us to, for example, start a request and cause it to time out waiting for
// response headers without resorting to time.Sleep calls.
type testSyncHooks struct {
// active/inactive act as a mutex and condition variable.
//
// - neither chan contains a value: testSyncHooks is locked.
// - active contains a value: unlocked, and at least one goroutine is not blocked
// - inactive contains a value: unlocked, and all goroutines are blocked
active chan struct{}
inactive chan struct{}
// goroutine counts
total int // total goroutines
condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
blocked []*testBlockedGoroutine // otherwise blocked
// fake time
now time.Time
timers []*fakeTimer
// Transport testing: Report various events.
newclientconn func(*ClientConn)
newstream func(*clientStream)
}
// testBlockedGoroutine is a blocked goroutine.
type testBlockedGoroutine struct {
f func() bool // blocked until f returns true
ch chan struct{} // closed when unblocked
}
func newTestSyncHooks() *testSyncHooks {
h := &testSyncHooks{
active: make(chan struct{}, 1),
inactive: make(chan struct{}, 1),
condwait: map[*sync.Cond]int{},
}
h.inactive <- struct{}{}
h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
return h
}
// lock acquires the testSyncHooks mutex.
func (h *testSyncHooks) lock() {
select {
case <-h.active:
case <-h.inactive:
}
}
// waitInactive waits for all goroutines to become inactive.
func (h *testSyncHooks) waitInactive() {
for {
<-h.inactive
if !h.unlock() {
break
}
}
}
// unlock releases the testSyncHooks mutex.
// It reports whether any goroutines are active.
func (h *testSyncHooks) unlock() (active bool) {
// Look for a blocked goroutine which can be unblocked.
blocked := h.blocked[:0]
unblocked := false
for _, b := range h.blocked {
if !unblocked && b.f() {
unblocked = true
close(b.ch)
} else {
blocked = append(blocked, b)
}
}
h.blocked = blocked
// Count goroutines blocked on condition variables.
condwait := 0
for _, count := range h.condwait {
condwait += count
}
if h.total > condwait+len(blocked) {
h.active <- struct{}{}
return true
} else {
h.inactive <- struct{}{}
return false
}
}
// goRun starts a new goroutine.
func (h *testSyncHooks) goRun(f func()) {
h.lock()
h.total++
h.unlock()
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
f()
}()
}
// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
// It waits until f returns true before proceeding.
//
// Example usage:
//
// h.blockUntil(func() bool {
// // Is the context done yet?
// select {
// case <-ctx.Done():
// default:
// return false
// }
// return true
// })
// // Wait for the context to become done.
// <-ctx.Done()
//
// The function f passed to blockUntil must be non-blocking and idempotent.
func (h *testSyncHooks) blockUntil(f func() bool) {
if f() {
return
}
ch := make(chan struct{})
h.lock()
h.blocked = append(h.blocked, &testBlockedGoroutine{
f: f,
ch: ch,
})
h.unlock()
<-ch
}
// broadcast is sync.Cond.Broadcast.
func (h *testSyncHooks) condBroadcast(cond *sync.Cond) {
h.lock()
delete(h.condwait, cond)
h.unlock()
cond.Broadcast()
}
// broadcast is sync.Cond.Wait.
func (h *testSyncHooks) condWait(cond *sync.Cond) {
h.lock()
h.condwait[cond]++
h.unlock()
}
// newTimer creates a new fake timer.
func (h *testSyncHooks) newTimer(d time.Duration) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
c: make(chan time.Time),
}
h.timers = append(h.timers, t)
return t
}
// afterFunc creates a new fake AfterFunc timer.
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
h.lock()
defer h.unlock()
t := &fakeTimer{
hooks: h,
when: h.now.Add(d),
f: f,
}
h.timers = append(h.timers, t)
return t
}
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
t := h.afterFunc(d, cancel)
return ctx, func() {
t.Stop()
cancel()
}
}
func (h *testSyncHooks) timeUntilEvent() time.Duration {
h.lock()
defer h.unlock()
var next time.Time
for _, t := range h.timers {
if next.IsZero() || t.when.Before(next) {
next = t.when
}
}
if d := next.Sub(h.now); d > 0 {
return d
}
return 0
}
// advance advances time and causes synthetic timers to fire.
func (h *testSyncHooks) advance(d time.Duration) {
h.lock()
defer h.unlock()
h.now = h.now.Add(d)
timers := h.timers[:0]
for _, t := range h.timers {
t := t // remove after go.mod depends on go1.22
t.mu.Lock()
switch {
case t.when.After(h.now):
timers = append(timers, t)
case t.when.IsZero():
// stopped timer
default:
t.when = time.Time{}
if t.c != nil {
close(t.c)
}
if t.f != nil {
h.total++
go func() {
defer func() {
h.lock()
h.total--
h.unlock()
}()
t.f()
}()
}
}
t.mu.Unlock()
}
h.timers = timers
}
// A timer wraps a time.Timer, or a synthetic equivalent in tests.
// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
type timer interface {
C() <-chan time.Time
Stop() bool
Reset(d time.Duration) bool
}
// timeTimer implements timer using real time.
type timeTimer struct {
t *time.Timer
c chan time.Time
}
// newTimeTimer creates a new timer using real time.
func newTimeTimer(d time.Duration) timer {
ch := make(chan time.Time)
t := time.AfterFunc(d, func() {
close(ch)
})
return &timeTimer{t, ch}
}
// newTimeAfterFunc creates an AfterFunc timer using real time.
func newTimeAfterFunc(d time.Duration, f func()) timer {
return &timeTimer{
t: time.AfterFunc(d, f),
}
}
func (t timeTimer) C() <-chan time.Time { return t.c }
func (t timeTimer) Stop() bool { return t.t.Stop() }
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
// fakeTimer implements timer using fake time.
type fakeTimer struct {
hooks *testSyncHooks
mu sync.Mutex
when time.Time // when the timer will fire
c chan time.Time // closed when the timer fires; mutually exclusive with f
f func() // called when the timer fires; mutually exclusive with c
}
func (t *fakeTimer) C() <-chan time.Time { return t.c }
func (t *fakeTimer) Stop() bool {
t.mu.Lock()
defer t.mu.Unlock()
stopped := t.when.IsZero()
t.when = time.Time{}
return stopped
}
func (t *fakeTimer) Reset(d time.Duration) bool {
if t.c != nil || t.f == nil {
panic("fakeTimer only supports Reset on AfterFunc timers")
}
t.mu.Lock()
defer t.mu.Unlock()
t.hooks.lock()
defer t.hooks.unlock()
active := !t.when.IsZero()
t.when = t.hooks.now.Add(d)
if !active {
t.hooks.timers = append(t.hooks.timers, t)
}
return active
}

View File

@ -147,6 +147,12 @@ type Transport struct {
// waiting for their turn. // waiting for their turn.
StrictMaxConcurrentStreams bool StrictMaxConcurrentStreams bool
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
// Zero means no limit.
IdleConnTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using ping // ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection. // frame will be carried out if no frame is received on the connection.
// Note that a ping response will is considered a received frame, so if // Note that a ping response will is considered a received frame, so if
@ -178,6 +184,8 @@ type Transport struct {
connPoolOnce sync.Once connPoolOnce sync.Once
connPoolOrDef ClientConnPool // non-nil version of ConnPool connPoolOrDef ClientConnPool // non-nil version of ConnPool
syncHooks *testSyncHooks
} }
func (t *Transport) maxHeaderListSize() uint32 { func (t *Transport) maxHeaderListSize() uint32 {
@ -302,7 +310,7 @@ type ClientConn struct {
readerErr error // set before readerDone is closed readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never idleTimeout time.Duration // or 0 for never
idleTimer *time.Timer idleTimer timer
mu sync.Mutex // guards following mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes cond *sync.Cond // hold mu; broadcast on flow/closed changes
@ -344,6 +352,60 @@ type ClientConn struct {
werr error // first write error that has occurred werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder henc *hpack.Encoder
syncHooks *testSyncHooks // can be nil
}
// Hook points used for testing.
// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
// Inside tests, see the testSyncHooks function docs.
// goRun starts a new goroutine.
func (cc *ClientConn) goRun(f func()) {
if cc.syncHooks != nil {
cc.syncHooks.goRun(f)
return
}
go f()
}
// condBroadcast is cc.cond.Broadcast.
func (cc *ClientConn) condBroadcast() {
if cc.syncHooks != nil {
cc.syncHooks.condBroadcast(cc.cond)
}
cc.cond.Broadcast()
}
// condWait is cc.cond.Wait.
func (cc *ClientConn) condWait() {
if cc.syncHooks != nil {
cc.syncHooks.condWait(cc.cond)
}
cc.cond.Wait()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (cc *ClientConn) newTimer(d time.Duration) timer {
if cc.syncHooks != nil {
return cc.syncHooks.newTimer(d)
}
return newTimeTimer(d)
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
if cc.syncHooks != nil {
return cc.syncHooks.afterFunc(d, f)
}
return newTimeAfterFunc(d, f)
}
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if cc.syncHooks != nil {
return cc.syncHooks.contextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
} }
// clientStream is the state for a single HTTP/2 stream. One of these // clientStream is the state for a single HTTP/2 stream. One of these
@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil. // TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil { if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control. // Wake up writeRequestBody if it is waiting on flow control.
cs.cc.cond.Broadcast() cs.cc.condBroadcast()
} }
} }
@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock() defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil { if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked() cs.closeReqBodyLocked()
cc.cond.Broadcast() cc.condBroadcast()
} }
} }
@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() {
} }
cs.reqBodyClosed = make(chan struct{}) cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed reqBodyClosed := cs.reqBodyClosed
go func() { cs.cc.goRun(func() {
cs.reqBody.Close() cs.reqBody.Close()
close(reqBodyClosed) close(reqBodyClosed)
}() })
} }
type stickyErrWriter struct { type stickyErrWriter struct {
@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port) return net.JoinHostPort(host, port)
} }
var retryBackoffHook func(time.Duration) *time.Timer
func backoffNewTimer(d time.Duration) *time.Timer {
if retryBackoffHook != nil {
return retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options. // RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1)) backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64()) backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff) d := time.Second * time.Duration(backoff)
timer := backoffNewTimer(d) var tm timer
if t.syncHooks != nil {
tm = t.syncHooks.newTimer(d)
t.syncHooks.blockUntil(func() bool {
select { select {
case <-timer.C: case <-tm.C():
case <-req.Context().Done():
default:
return false
}
return true
})
} else {
tm = newTimeTimer(d)
}
select {
case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue continue
case <-req.Context().Done(): case <-req.Context().Done():
timer.Stop() tm.Stop()
err = req.Context().Err() err = req.Context().Err()
} }
} }
@ -658,6 +725,9 @@ func canRetryError(err error) bool {
} }
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
if t.syncHooks != nil {
return t.newClientConn(nil, singleUse, t.syncHooks)
}
host, _, err := net.SplitHostPort(addr) host, _, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil { if err != nil {
return nil, err return nil, err
} }
return t.newClientConn(tconn, singleUse) return t.newClientConn(tconn, singleUse, nil)
} }
func (t *Transport) newTLSConfig(host string) *tls.Config { func (t *Transport) newTLSConfig(host string) *tls.Config {
@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 {
} }
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives()) return t.newClientConn(c, t.disableKeepAlives(), nil)
} }
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) {
cc := &ClientConn{ cc := &ClientConn{
t: t, t: t,
tconn: c, tconn: c,
@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
wantSettingsAck: true, wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}), pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1), reqHeaderMu: make(chan struct{}, 1),
syncHooks: hooks,
}
if hooks != nil {
hooks.newclientconn(cc)
c = cc.tconn
} }
if d := t.idleConnTimeout(); d != 0 { if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d cc.idleTimeout = d
cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
} }
if VerboseLogs { if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
return nil, cc.werr return nil, cc.werr
} }
go cc.readLoop() cc.goRun(cc.readLoop)
return cc, nil return cc, nil
} }
@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout() pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will // We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received. // trigger the healthCheck again if there is no frame received.
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
defer cancel() defer cancel()
cc.vlogf("http2: Transport sending health check") cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx) err := cc.Ping(ctx)
@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close // Wait for all in-flight streams to complete or connection to close
done := make(chan struct{}) done := make(chan struct{})
cancelled := false // guarded by cc.mu cancelled := false // guarded by cc.mu
go func() { cc.goRun(func() {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
for { for {
@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
if cancelled { if cancelled {
break break
} }
cc.cond.Wait() cc.condWait()
} }
}() })
shutdownEnterWaitStateHook() shutdownEnterWaitStateHook()
select { select {
case <-done: case <-done:
@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock() cc.mu.Lock()
// Free the goroutine above // Free the goroutine above
cancelled = true cancelled = true
cc.cond.Broadcast() cc.condBroadcast()
cc.mu.Unlock() cc.mu.Unlock()
return ctx.Err() return ctx.Err()
} }
@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) {
for _, cs := range cc.streams { for _, cs := range cc.streams {
cs.abortStreamLocked(err) cs.abortStreamLocked(err)
} }
cc.cond.Broadcast() cc.condBroadcast()
cc.mu.Unlock() cc.mu.Unlock()
cc.closeConn() cc.closeConn()
} }
@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() {
} }
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return cc.roundTrip(req, nil)
}
func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) {
ctx := req.Context() ctx := req.Context()
cs := &clientStream{ cs := &clientStream{
cc: cc, cc: cc,
@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
respHeaderRecv: make(chan struct{}), respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}), donec: make(chan struct{}),
} }
go cs.doRequest(req) cc.goRun(func() {
cs.doRequest(req)
})
waitDone := func() error { waitDone := func() error {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.donec:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select { select {
case <-cs.donec: case <-cs.donec:
return nil return nil
@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return err return err
} }
if streamf != nil {
streamf(cs)
}
for { for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select { select {
case <-cs.respHeaderRecv: case <-cs.respHeaderRecv:
return handleResponseHeaders() return handleResponseHeaders()
@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
if cc.reqHeaderMu == nil { if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests panic("RoundTrip on uninitialized ClientConn") // for tests
} }
var newStreamHook func(*clientStream)
if cc.syncHooks != nil {
newStreamHook = cc.syncHooks.newstream
cc.syncHooks.blockUntil(func() bool {
select {
case cc.reqHeaderMu <- struct{}{}:
<-cc.reqHeaderMu
case <-cs.reqCancel:
case <-ctx.Done():
default:
return false
}
return true
})
}
select { select {
case cc.reqHeaderMu <- struct{}{}: case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel: case <-cs.reqCancel:
@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
} }
cc.mu.Unlock() cc.mu.Unlock()
if newStreamHook != nil {
newStreamHook(cs)
}
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() && if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Accept-Encoding") == "" &&
@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
var respHeaderTimer <-chan time.Time var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{} var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 { if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d) timer := cc.newTimer(d)
defer timer.Stop() defer timer.Stop()
respHeaderTimer = timer.C respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv respHeaderRecv = cs.respHeaderRecv
} }
// Wait until the peer half-closes its end of the stream, // Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise), // or until the request is aborted (via context, error, or otherwise),
// whichever comes first. // whichever comes first.
for { for {
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-cs.peerClosed:
case <-respHeaderTimer:
case <-respHeaderRecv:
case <-cs.abort:
case <-ctx.Done():
case <-cs.reqCancel:
default:
return false
}
return true
})
}
select { select {
case <-cs.peerClosed: case <-cs.peerClosed:
return nil return nil
@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
return nil return nil
} }
cc.pendingRequests++ cc.pendingRequests++
cc.cond.Wait() cc.condWait()
cc.pendingRequests-- cc.pendingRequests--
select { select {
case <-cs.abort: case <-cs.abort:
@ -1871,10 +2015,26 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
cs.flow.take(take) cs.flow.take(take)
return take, nil return take, nil
} }
cc.cond.Wait() cc.condWait()
} }
} }
func validateHeaders(hdrs http.Header) string {
for k, vv := range hdrs {
if !httpguts.ValidHeaderFieldName(k) {
return fmt.Sprintf("name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error,
// because it may be sensitive.
return fmt.Sprintf("value for header %q", k)
}
}
}
return ""
}
var errNilRequestURL = errors.New("http2: Request.URI is nil") var errNilRequestURL = errors.New("http2: Request.URI is nil")
// requires cc.wmu be held. // requires cc.wmu be held.
@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
} }
} }
// Check for any invalid headers and return an error before we // Check for any invalid headers+trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to // potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests) // continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header { if err := validateHeaders(req.Header); err != "" {
if !httpguts.ValidHeaderFieldName(k) { return nil, fmt.Errorf("invalid HTTP header %s", err)
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error, because it may be sensitive.
return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
}
} }
if err := validateHeaders(req.Trailer); err != "" {
return nil, fmt.Errorf("invalid HTTP trailer %s", err)
} }
enumerateHeaders := func(f func(name, value string)) { enumerateHeaders := func(f func(name, value string)) {
@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
} }
// Wake up writeRequestBody via clientStream.awaitFlowControl and // Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request. // wake up RoundTrip if there is a pending request.
cc.cond.Broadcast() cc.condBroadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err) cs.abortStreamLocked(err)
} }
} }
cc.cond.Broadcast() cc.condBroadcast()
cc.mu.Unlock() cc.mu.Unlock()
} }
@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc cc := rl.cc
gotSettings := false gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer var t timer
if readIdleTimeout != 0 { if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck) t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
} }
for { for {
f, err := cc.fr.ReadFrame() f, err := cc.fr.ReadFrame()
@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error {
}) })
return nil return nil
} }
if !cs.firstByte { if !cs.pastHeaders {
cc.logf("protocol error: received DATA before a HEADERS frame") cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, StreamError{ rl.endStreamError(cs, StreamError{
StreamID: f.StreamID, StreamID: f.StreamID,
@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
for _, cs := range cc.streams { for _, cs := range cc.streams {
cs.flow.add(delta) cs.flow.add(delta)
} }
cc.cond.Broadcast() cc.condBroadcast()
cc.initialWindowSize = s.Val cc.initialWindowSize = s.Val
case SettingHeaderTableSize: case SettingHeaderTableSize:
@ -2922,7 +3076,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
return ConnectionError(ErrCodeFlowControl) return ConnectionError(ErrCodeFlowControl)
} }
cc.cond.Broadcast() cc.condBroadcast()
return nil return nil
} }
@ -2964,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
} }
cc.mu.Unlock() cc.mu.Unlock()
} }
errc := make(chan error, 1) var pingError error
go func() { errc := make(chan struct{})
cc.goRun(func() {
cc.wmu.Lock() cc.wmu.Lock()
defer cc.wmu.Unlock() defer cc.wmu.Unlock()
if err := cc.fr.WritePing(false, p); err != nil { if pingError = cc.fr.WritePing(false, p); pingError != nil {
errc <- err close(errc)
return return
} }
if err := cc.bw.Flush(); err != nil { if pingError = cc.bw.Flush(); pingError != nil {
errc <- err close(errc)
return return
} }
}() })
if cc.syncHooks != nil {
cc.syncHooks.blockUntil(func() bool {
select {
case <-c:
case <-errc:
case <-ctx.Done():
case <-cc.readerDone:
default:
return false
}
return true
})
}
select { select {
case <-c: case <-c:
return nil return nil
case err := <-errc: case <-errc:
return err return pingError
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-cc.readerDone: case <-cc.readerDone:
@ -3150,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
} }
func (t *Transport) idleConnTimeout() time.Duration { func (t *Transport) idleConnTimeout() time.Duration {
// to keep things backwards compatible, we use non-zero values of
// IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
// http1 transport, followed by 0
if t.IdleConnTimeout != 0 {
return t.IdleConnTimeout
}
if t.t1 != nil { if t.t1 != nil {
return t.t1.IdleConnTimeout return t.t1.IdleConnTimeout
} }
return 0 return 0
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || openbsd || solaris //go:build aix || darwin || dragonfly || freebsd || openbsd || solaris || zos
package unix package unix

View File

@ -1520,6 +1520,14 @@ func (m *mmapper) Munmap(data []byte) (err error) {
return nil return nil
} }
func Mmap(fd int, offset int64, length int, prot int, flags int) (data []byte, err error) {
return mapper.Mmap(fd, offset, length, prot, flags)
}
func Munmap(b []byte) (err error) {
return mapper.Munmap(b)
}
func Read(fd int, p []byte) (n int, err error) { func Read(fd int, p []byte) (n int, err error) {
n, err = read(fd, p) n, err = read(fd, p)
if raceenabled { if raceenabled {

View File

@ -165,6 +165,7 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys CreateFile(name *uint16, access uint32, mode uint32, sa *SecurityAttributes, createmode uint32, attrs uint32, templatefile Handle) (handle Handle, err error) [failretval==InvalidHandle] = CreateFileW //sys CreateFile(name *uint16, access uint32, mode uint32, sa *SecurityAttributes, createmode uint32, attrs uint32, templatefile Handle) (handle Handle, err error) [failretval==InvalidHandle] = CreateFileW
//sys CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *SecurityAttributes) (handle Handle, err error) [failretval==InvalidHandle] = CreateNamedPipeW //sys CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *SecurityAttributes) (handle Handle, err error) [failretval==InvalidHandle] = CreateNamedPipeW
//sys ConnectNamedPipe(pipe Handle, overlapped *Overlapped) (err error) //sys ConnectNamedPipe(pipe Handle, overlapped *Overlapped) (err error)
//sys DisconnectNamedPipe(pipe Handle) (err error)
//sys GetNamedPipeInfo(pipe Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) //sys GetNamedPipeInfo(pipe Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error)
//sys GetNamedPipeHandleState(pipe Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys GetNamedPipeHandleState(pipe Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys SetNamedPipeHandleState(pipe Handle, state *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32) (err error) = SetNamedPipeHandleState //sys SetNamedPipeHandleState(pipe Handle, state *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32) (err error) = SetNamedPipeHandleState
@ -348,8 +349,19 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys SetProcessPriorityBoost(process Handle, disable bool) (err error) = kernel32.SetProcessPriorityBoost //sys SetProcessPriorityBoost(process Handle, disable bool) (err error) = kernel32.SetProcessPriorityBoost
//sys GetProcessWorkingSetSizeEx(hProcess Handle, lpMinimumWorkingSetSize *uintptr, lpMaximumWorkingSetSize *uintptr, flags *uint32) //sys GetProcessWorkingSetSizeEx(hProcess Handle, lpMinimumWorkingSetSize *uintptr, lpMaximumWorkingSetSize *uintptr, flags *uint32)
//sys SetProcessWorkingSetSizeEx(hProcess Handle, dwMinimumWorkingSetSize uintptr, dwMaximumWorkingSetSize uintptr, flags uint32) (err error) //sys SetProcessWorkingSetSizeEx(hProcess Handle, dwMinimumWorkingSetSize uintptr, dwMaximumWorkingSetSize uintptr, flags uint32) (err error)
//sys ClearCommBreak(handle Handle) (err error)
//sys ClearCommError(handle Handle, lpErrors *uint32, lpStat *ComStat) (err error)
//sys EscapeCommFunction(handle Handle, dwFunc uint32) (err error)
//sys GetCommState(handle Handle, lpDCB *DCB) (err error)
//sys GetCommModemStatus(handle Handle, lpModemStat *uint32) (err error)
//sys GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) //sys GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error)
//sys PurgeComm(handle Handle, dwFlags uint32) (err error)
//sys SetCommBreak(handle Handle) (err error)
//sys SetCommMask(handle Handle, dwEvtMask uint32) (err error)
//sys SetCommState(handle Handle, lpDCB *DCB) (err error)
//sys SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) //sys SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error)
//sys SetupComm(handle Handle, dwInQueue uint32, dwOutQueue uint32) (err error)
//sys WaitCommEvent(handle Handle, lpEvtMask *uint32, lpOverlapped *Overlapped) (err error)
//sys GetActiveProcessorCount(groupNumber uint16) (ret uint32) //sys GetActiveProcessorCount(groupNumber uint16) (ret uint32)
//sys GetMaximumProcessorCount(groupNumber uint16) (ret uint32) //sys GetMaximumProcessorCount(groupNumber uint16) (ret uint32)
//sys EnumWindows(enumFunc uintptr, param unsafe.Pointer) (err error) = user32.EnumWindows //sys EnumWindows(enumFunc uintptr, param unsafe.Pointer) (err error) = user32.EnumWindows
@ -1834,3 +1846,73 @@ func ResizePseudoConsole(pconsole Handle, size Coord) error {
// accept arguments that can be casted to uintptr, and Coord can't. // accept arguments that can be casted to uintptr, and Coord can't.
return resizePseudoConsole(pconsole, *((*uint32)(unsafe.Pointer(&size)))) return resizePseudoConsole(pconsole, *((*uint32)(unsafe.Pointer(&size))))
} }
// DCB constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-dcb.
const (
CBR_110 = 110
CBR_300 = 300
CBR_600 = 600
CBR_1200 = 1200
CBR_2400 = 2400
CBR_4800 = 4800
CBR_9600 = 9600
CBR_14400 = 14400
CBR_19200 = 19200
CBR_38400 = 38400
CBR_57600 = 57600
CBR_115200 = 115200
CBR_128000 = 128000
CBR_256000 = 256000
DTR_CONTROL_DISABLE = 0x00000000
DTR_CONTROL_ENABLE = 0x00000010
DTR_CONTROL_HANDSHAKE = 0x00000020
RTS_CONTROL_DISABLE = 0x00000000
RTS_CONTROL_ENABLE = 0x00001000
RTS_CONTROL_HANDSHAKE = 0x00002000
RTS_CONTROL_TOGGLE = 0x00003000
NOPARITY = 0
ODDPARITY = 1
EVENPARITY = 2
MARKPARITY = 3
SPACEPARITY = 4
ONESTOPBIT = 0
ONE5STOPBITS = 1
TWOSTOPBITS = 2
)
// EscapeCommFunction constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-escapecommfunction.
const (
SETXOFF = 1
SETXON = 2
SETRTS = 3
CLRRTS = 4
SETDTR = 5
CLRDTR = 6
SETBREAK = 8
CLRBREAK = 9
)
// PurgeComm constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-purgecomm.
const (
PURGE_TXABORT = 0x0001
PURGE_RXABORT = 0x0002
PURGE_TXCLEAR = 0x0004
PURGE_RXCLEAR = 0x0008
)
// SetCommMask constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-setcommmask.
const (
EV_RXCHAR = 0x0001
EV_RXFLAG = 0x0002
EV_TXEMPTY = 0x0004
EV_CTS = 0x0008
EV_DSR = 0x0010
EV_RLSD = 0x0020
EV_BREAK = 0x0040
EV_ERR = 0x0080
EV_RING = 0x0100
)

View File

@ -3380,3 +3380,27 @@ type BLOB struct {
Size uint32 Size uint32
BlobData *byte BlobData *byte
} }
type ComStat struct {
Flags uint32
CBInQue uint32
CBOutQue uint32
}
type DCB struct {
DCBlength uint32
BaudRate uint32
Flags uint32
wReserved uint16
XonLim uint16
XoffLim uint16
ByteSize uint8
Parity uint8
StopBits uint8
XonChar byte
XoffChar byte
ErrorChar byte
EofChar byte
EvtChar byte
wReserved1 uint16
}

View File

@ -188,6 +188,8 @@ var (
procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject") procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject")
procCancelIo = modkernel32.NewProc("CancelIo") procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx") procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procClearCommBreak = modkernel32.NewProc("ClearCommBreak")
procClearCommError = modkernel32.NewProc("ClearCommError")
procCloseHandle = modkernel32.NewProc("CloseHandle") procCloseHandle = modkernel32.NewProc("CloseHandle")
procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole") procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
@ -212,7 +214,9 @@ var (
procDeleteProcThreadAttributeList = modkernel32.NewProc("DeleteProcThreadAttributeList") procDeleteProcThreadAttributeList = modkernel32.NewProc("DeleteProcThreadAttributeList")
procDeleteVolumeMountPointW = modkernel32.NewProc("DeleteVolumeMountPointW") procDeleteVolumeMountPointW = modkernel32.NewProc("DeleteVolumeMountPointW")
procDeviceIoControl = modkernel32.NewProc("DeviceIoControl") procDeviceIoControl = modkernel32.NewProc("DeviceIoControl")
procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe")
procDuplicateHandle = modkernel32.NewProc("DuplicateHandle") procDuplicateHandle = modkernel32.NewProc("DuplicateHandle")
procEscapeCommFunction = modkernel32.NewProc("EscapeCommFunction")
procExitProcess = modkernel32.NewProc("ExitProcess") procExitProcess = modkernel32.NewProc("ExitProcess")
procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW") procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW")
procFindClose = modkernel32.NewProc("FindClose") procFindClose = modkernel32.NewProc("FindClose")
@ -236,6 +240,8 @@ var (
procGenerateConsoleCtrlEvent = modkernel32.NewProc("GenerateConsoleCtrlEvent") procGenerateConsoleCtrlEvent = modkernel32.NewProc("GenerateConsoleCtrlEvent")
procGetACP = modkernel32.NewProc("GetACP") procGetACP = modkernel32.NewProc("GetACP")
procGetActiveProcessorCount = modkernel32.NewProc("GetActiveProcessorCount") procGetActiveProcessorCount = modkernel32.NewProc("GetActiveProcessorCount")
procGetCommModemStatus = modkernel32.NewProc("GetCommModemStatus")
procGetCommState = modkernel32.NewProc("GetCommState")
procGetCommTimeouts = modkernel32.NewProc("GetCommTimeouts") procGetCommTimeouts = modkernel32.NewProc("GetCommTimeouts")
procGetCommandLineW = modkernel32.NewProc("GetCommandLineW") procGetCommandLineW = modkernel32.NewProc("GetCommandLineW")
procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW") procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW")
@ -322,6 +328,7 @@ var (
procProcess32NextW = modkernel32.NewProc("Process32NextW") procProcess32NextW = modkernel32.NewProc("Process32NextW")
procProcessIdToSessionId = modkernel32.NewProc("ProcessIdToSessionId") procProcessIdToSessionId = modkernel32.NewProc("ProcessIdToSessionId")
procPulseEvent = modkernel32.NewProc("PulseEvent") procPulseEvent = modkernel32.NewProc("PulseEvent")
procPurgeComm = modkernel32.NewProc("PurgeComm")
procQueryDosDeviceW = modkernel32.NewProc("QueryDosDeviceW") procQueryDosDeviceW = modkernel32.NewProc("QueryDosDeviceW")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW") procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
procQueryInformationJobObject = modkernel32.NewProc("QueryInformationJobObject") procQueryInformationJobObject = modkernel32.NewProc("QueryInformationJobObject")
@ -335,6 +342,9 @@ var (
procResetEvent = modkernel32.NewProc("ResetEvent") procResetEvent = modkernel32.NewProc("ResetEvent")
procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole") procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole")
procResumeThread = modkernel32.NewProc("ResumeThread") procResumeThread = modkernel32.NewProc("ResumeThread")
procSetCommBreak = modkernel32.NewProc("SetCommBreak")
procSetCommMask = modkernel32.NewProc("SetCommMask")
procSetCommState = modkernel32.NewProc("SetCommState")
procSetCommTimeouts = modkernel32.NewProc("SetCommTimeouts") procSetCommTimeouts = modkernel32.NewProc("SetCommTimeouts")
procSetConsoleCursorPosition = modkernel32.NewProc("SetConsoleCursorPosition") procSetConsoleCursorPosition = modkernel32.NewProc("SetConsoleCursorPosition")
procSetConsoleMode = modkernel32.NewProc("SetConsoleMode") procSetConsoleMode = modkernel32.NewProc("SetConsoleMode")
@ -342,7 +352,6 @@ var (
procSetDefaultDllDirectories = modkernel32.NewProc("SetDefaultDllDirectories") procSetDefaultDllDirectories = modkernel32.NewProc("SetDefaultDllDirectories")
procSetDllDirectoryW = modkernel32.NewProc("SetDllDirectoryW") procSetDllDirectoryW = modkernel32.NewProc("SetDllDirectoryW")
procSetEndOfFile = modkernel32.NewProc("SetEndOfFile") procSetEndOfFile = modkernel32.NewProc("SetEndOfFile")
procSetFileValidData = modkernel32.NewProc("SetFileValidData")
procSetEnvironmentVariableW = modkernel32.NewProc("SetEnvironmentVariableW") procSetEnvironmentVariableW = modkernel32.NewProc("SetEnvironmentVariableW")
procSetErrorMode = modkernel32.NewProc("SetErrorMode") procSetErrorMode = modkernel32.NewProc("SetErrorMode")
procSetEvent = modkernel32.NewProc("SetEvent") procSetEvent = modkernel32.NewProc("SetEvent")
@ -351,6 +360,7 @@ var (
procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle") procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle")
procSetFilePointer = modkernel32.NewProc("SetFilePointer") procSetFilePointer = modkernel32.NewProc("SetFilePointer")
procSetFileTime = modkernel32.NewProc("SetFileTime") procSetFileTime = modkernel32.NewProc("SetFileTime")
procSetFileValidData = modkernel32.NewProc("SetFileValidData")
procSetHandleInformation = modkernel32.NewProc("SetHandleInformation") procSetHandleInformation = modkernel32.NewProc("SetHandleInformation")
procSetInformationJobObject = modkernel32.NewProc("SetInformationJobObject") procSetInformationJobObject = modkernel32.NewProc("SetInformationJobObject")
procSetNamedPipeHandleState = modkernel32.NewProc("SetNamedPipeHandleState") procSetNamedPipeHandleState = modkernel32.NewProc("SetNamedPipeHandleState")
@ -361,6 +371,7 @@ var (
procSetStdHandle = modkernel32.NewProc("SetStdHandle") procSetStdHandle = modkernel32.NewProc("SetStdHandle")
procSetVolumeLabelW = modkernel32.NewProc("SetVolumeLabelW") procSetVolumeLabelW = modkernel32.NewProc("SetVolumeLabelW")
procSetVolumeMountPointW = modkernel32.NewProc("SetVolumeMountPointW") procSetVolumeMountPointW = modkernel32.NewProc("SetVolumeMountPointW")
procSetupComm = modkernel32.NewProc("SetupComm")
procSizeofResource = modkernel32.NewProc("SizeofResource") procSizeofResource = modkernel32.NewProc("SizeofResource")
procSleepEx = modkernel32.NewProc("SleepEx") procSleepEx = modkernel32.NewProc("SleepEx")
procTerminateJobObject = modkernel32.NewProc("TerminateJobObject") procTerminateJobObject = modkernel32.NewProc("TerminateJobObject")
@ -379,6 +390,7 @@ var (
procVirtualQueryEx = modkernel32.NewProc("VirtualQueryEx") procVirtualQueryEx = modkernel32.NewProc("VirtualQueryEx")
procVirtualUnlock = modkernel32.NewProc("VirtualUnlock") procVirtualUnlock = modkernel32.NewProc("VirtualUnlock")
procWTSGetActiveConsoleSessionId = modkernel32.NewProc("WTSGetActiveConsoleSessionId") procWTSGetActiveConsoleSessionId = modkernel32.NewProc("WTSGetActiveConsoleSessionId")
procWaitCommEvent = modkernel32.NewProc("WaitCommEvent")
procWaitForMultipleObjects = modkernel32.NewProc("WaitForMultipleObjects") procWaitForMultipleObjects = modkernel32.NewProc("WaitForMultipleObjects")
procWaitForSingleObject = modkernel32.NewProc("WaitForSingleObject") procWaitForSingleObject = modkernel32.NewProc("WaitForSingleObject")
procWriteConsoleW = modkernel32.NewProc("WriteConsoleW") procWriteConsoleW = modkernel32.NewProc("WriteConsoleW")
@ -1641,6 +1653,22 @@ func CancelIoEx(s Handle, o *Overlapped) (err error) {
return return
} }
func ClearCommBreak(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procClearCommBreak.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ClearCommError(handle Handle, lpErrors *uint32, lpStat *ComStat) (err error) {
r1, _, e1 := syscall.Syscall(procClearCommError.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(lpErrors)), uintptr(unsafe.Pointer(lpStat)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func CloseHandle(handle Handle) (err error) { func CloseHandle(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procCloseHandle.Addr(), 1, uintptr(handle), 0, 0) r1, _, e1 := syscall.Syscall(procCloseHandle.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 { if r1 == 0 {
@ -1845,6 +1873,14 @@ func DeviceIoControl(handle Handle, ioControlCode uint32, inBuffer *byte, inBuff
return return
} }
func DisconnectNamedPipe(pipe Handle) (err error) {
r1, _, e1 := syscall.Syscall(procDisconnectNamedPipe.Addr(), 1, uintptr(pipe), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetProcessHandle Handle, lpTargetHandle *Handle, dwDesiredAccess uint32, bInheritHandle bool, dwOptions uint32) (err error) { func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetProcessHandle Handle, lpTargetHandle *Handle, dwDesiredAccess uint32, bInheritHandle bool, dwOptions uint32) (err error) {
var _p0 uint32 var _p0 uint32
if bInheritHandle { if bInheritHandle {
@ -1857,6 +1893,14 @@ func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetP
return return
} }
func EscapeCommFunction(handle Handle, dwFunc uint32) (err error) {
r1, _, e1 := syscall.Syscall(procEscapeCommFunction.Addr(), 2, uintptr(handle), uintptr(dwFunc), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ExitProcess(exitcode uint32) { func ExitProcess(exitcode uint32) {
syscall.Syscall(procExitProcess.Addr(), 1, uintptr(exitcode), 0, 0) syscall.Syscall(procExitProcess.Addr(), 1, uintptr(exitcode), 0, 0)
return return
@ -2058,6 +2102,22 @@ func GetActiveProcessorCount(groupNumber uint16) (ret uint32) {
return return
} }
func GetCommModemStatus(handle Handle, lpModemStat *uint32) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommModemStatus.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpModemStat)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func GetCommState(handle Handle, lpDCB *DCB) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommState.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpDCB)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) { func GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0) r1, _, e1 := syscall.Syscall(procGetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0)
if r1 == 0 { if r1 == 0 {
@ -2810,6 +2870,14 @@ func PulseEvent(event Handle) (err error) {
return return
} }
func PurgeComm(handle Handle, dwFlags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procPurgeComm.Addr(), 2, uintptr(handle), uintptr(dwFlags), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func QueryDosDevice(deviceName *uint16, targetPath *uint16, max uint32) (n uint32, err error) { func QueryDosDevice(deviceName *uint16, targetPath *uint16, max uint32) (n uint32, err error) {
r0, _, e1 := syscall.Syscall(procQueryDosDeviceW.Addr(), 3, uintptr(unsafe.Pointer(deviceName)), uintptr(unsafe.Pointer(targetPath)), uintptr(max)) r0, _, e1 := syscall.Syscall(procQueryDosDeviceW.Addr(), 3, uintptr(unsafe.Pointer(deviceName)), uintptr(unsafe.Pointer(targetPath)), uintptr(max))
n = uint32(r0) n = uint32(r0)
@ -2924,6 +2992,30 @@ func ResumeThread(thread Handle) (ret uint32, err error) {
return return
} }
func SetCommBreak(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommBreak.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommMask(handle Handle, dwEvtMask uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommMask.Addr(), 2, uintptr(handle), uintptr(dwEvtMask), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommState(handle Handle, lpDCB *DCB) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommState.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpDCB)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) { func SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0) r1, _, e1 := syscall.Syscall(procSetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0)
if r1 == 0 { if r1 == 0 {
@ -2989,14 +3081,6 @@ func SetEndOfFile(handle Handle) (err error) {
return return
} }
func SetFileValidData(handle Handle, validDataLength int64) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileValidData.Addr(), 2, uintptr(handle), uintptr(validDataLength), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetEnvironmentVariable(name *uint16, value *uint16) (err error) { func SetEnvironmentVariable(name *uint16, value *uint16) (err error) {
r1, _, e1 := syscall.Syscall(procSetEnvironmentVariableW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(value)), 0) r1, _, e1 := syscall.Syscall(procSetEnvironmentVariableW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(value)), 0)
if r1 == 0 { if r1 == 0 {
@ -3060,6 +3144,14 @@ func SetFileTime(handle Handle, ctime *Filetime, atime *Filetime, wtime *Filetim
return return
} }
func SetFileValidData(handle Handle, validDataLength int64) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileValidData.Addr(), 2, uintptr(handle), uintptr(validDataLength), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetHandleInformation(handle Handle, mask uint32, flags uint32) (err error) { func SetHandleInformation(handle Handle, mask uint32, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetHandleInformation.Addr(), 3, uintptr(handle), uintptr(mask), uintptr(flags)) r1, _, e1 := syscall.Syscall(procSetHandleInformation.Addr(), 3, uintptr(handle), uintptr(mask), uintptr(flags))
if r1 == 0 { if r1 == 0 {
@ -3145,6 +3237,14 @@ func SetVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16) (err erro
return return
} }
func SetupComm(handle Handle, dwInQueue uint32, dwOutQueue uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetupComm.Addr(), 3, uintptr(handle), uintptr(dwInQueue), uintptr(dwOutQueue))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SizeofResource(module Handle, resInfo Handle) (size uint32, err error) { func SizeofResource(module Handle, resInfo Handle) (size uint32, err error) {
r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0) r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0)
size = uint32(r0) size = uint32(r0)
@ -3291,6 +3391,14 @@ func WTSGetActiveConsoleSessionId() (sessionID uint32) {
return return
} }
func WaitCommEvent(handle Handle, lpEvtMask *uint32, lpOverlapped *Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procWaitCommEvent.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(lpEvtMask)), uintptr(unsafe.Pointer(lpOverlapped)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func waitForMultipleObjects(count uint32, handles uintptr, waitAll bool, waitMilliseconds uint32) (event uint32, err error) { func waitForMultipleObjects(count uint32, handles uintptr, waitAll bool, waitMilliseconds uint32) (event uint32, err error) {
var _p0 uint32 var _p0 uint32
if waitAll { if waitAll {

12
vendor/modules.txt vendored
View File

@ -711,7 +711,7 @@ go.uber.org/zap/internal/pool
go.uber.org/zap/internal/stacktrace go.uber.org/zap/internal/stacktrace
go.uber.org/zap/zapcore go.uber.org/zap/zapcore
go.uber.org/zap/zapgrpc go.uber.org/zap/zapgrpc
# golang.org/x/crypto v0.21.0 # golang.org/x/crypto v0.22.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/crypto/argon2 golang.org/x/crypto/argon2
golang.org/x/crypto/blake2b golang.org/x/crypto/blake2b
@ -737,7 +737,7 @@ golang.org/x/crypto/ssh/internal/bcrypt_pbkdf
golang.org/x/exp/constraints golang.org/x/exp/constraints
golang.org/x/exp/maps golang.org/x/exp/maps
golang.org/x/exp/slices golang.org/x/exp/slices
# golang.org/x/net v0.22.0 # golang.org/x/net v0.24.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/net/context golang.org/x/net/context
golang.org/x/net/html golang.org/x/net/html
@ -759,14 +759,14 @@ golang.org/x/oauth2/internal
# golang.org/x/sync v0.6.0 # golang.org/x/sync v0.6.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/sync/singleflight golang.org/x/sync/singleflight
# golang.org/x/sys v0.18.0 # golang.org/x/sys v0.19.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/sys/cpu golang.org/x/sys/cpu
golang.org/x/sys/plan9 golang.org/x/sys/plan9
golang.org/x/sys/unix golang.org/x/sys/unix
golang.org/x/sys/windows golang.org/x/sys/windows
golang.org/x/sys/windows/registry golang.org/x/sys/windows/registry
# golang.org/x/term v0.18.0 # golang.org/x/term v0.19.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/term golang.org/x/term
# golang.org/x/text v0.14.0 # golang.org/x/text v0.14.0
@ -998,7 +998,7 @@ k8s.io/api/scheduling/v1beta1
k8s.io/api/storage/v1 k8s.io/api/storage/v1
k8s.io/api/storage/v1alpha1 k8s.io/api/storage/v1alpha1
k8s.io/api/storage/v1beta1 k8s.io/api/storage/v1beta1
# k8s.io/apiextensions-apiserver v0.29.0 => k8s.io/apiextensions-apiserver v0.29.3 # k8s.io/apiextensions-apiserver v0.29.2 => k8s.io/apiextensions-apiserver v0.29.3
## explicit; go 1.21 ## explicit; go 1.21
k8s.io/apiextensions-apiserver/pkg/apis/apiextensions k8s.io/apiextensions-apiserver/pkg/apis/apiextensions
k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1
@ -1682,7 +1682,7 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client
sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client/metrics sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client/metrics
sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics
sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client
# sigs.k8s.io/controller-runtime v0.17.2 # sigs.k8s.io/controller-runtime v0.17.3
## explicit; go 1.21 ## explicit; go 1.21
sigs.k8s.io/controller-runtime/pkg/cache sigs.k8s.io/controller-runtime/pkg/cache
sigs.k8s.io/controller-runtime/pkg/cache/internal sigs.k8s.io/controller-runtime/pkg/cache/internal

View File

@ -20,6 +20,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"sort"
"time" "time"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@ -421,7 +422,12 @@ func defaultOpts(config *rest.Config, opts Options) (Options, error) {
for namespace, cfg := range opts.DefaultNamespaces { for namespace, cfg := range opts.DefaultNamespaces {
cfg = defaultConfig(cfg, optionDefaultsToConfig(&opts)) cfg = defaultConfig(cfg, optionDefaultsToConfig(&opts))
if namespace == metav1.NamespaceAll { if namespace == metav1.NamespaceAll {
cfg.FieldSelector = fields.AndSelectors(appendIfNotNil(namespaceAllSelector(maps.Keys(opts.DefaultNamespaces)), cfg.FieldSelector)...) cfg.FieldSelector = fields.AndSelectors(
appendIfNotNil(
namespaceAllSelector(maps.Keys(opts.DefaultNamespaces)),
cfg.FieldSelector,
)...,
)
} }
opts.DefaultNamespaces[namespace] = cfg opts.DefaultNamespaces[namespace] = cfg
} }
@ -435,7 +441,12 @@ func defaultOpts(config *rest.Config, opts Options) (Options, error) {
return opts, fmt.Errorf("type %T is not namespaced, but its ByObject.Namespaces setting is not nil", obj) return opts, fmt.Errorf("type %T is not namespaced, but its ByObject.Namespaces setting is not nil", obj)
} }
// Default the namespace-level configs first, because they need to use the undefaulted type-level config. if isNamespaced && byObject.Namespaces == nil {
byObject.Namespaces = maps.Clone(opts.DefaultNamespaces)
}
// Default the namespace-level configs first, because they need to use the undefaulted type-level config
// to be able to potentially fall through to settings from DefaultNamespaces.
for namespace, config := range byObject.Namespaces { for namespace, config := range byObject.Namespaces {
// 1. Default from the undefaulted type-level config // 1. Default from the undefaulted type-level config
config = defaultConfig(config, byObjectToConfig(byObject)) config = defaultConfig(config, byObjectToConfig(byObject))
@ -461,14 +472,14 @@ func defaultOpts(config *rest.Config, opts Options) (Options, error) {
byObject.Namespaces[namespace] = config byObject.Namespaces[namespace] = config
} }
// Only default ByObject iself if it isn't namespaced or has no namespaces configured, as only
// then any of this will be honored.
if !isNamespaced || len(byObject.Namespaces) == 0 {
defaultedConfig := defaultConfig(byObjectToConfig(byObject), optionDefaultsToConfig(&opts)) defaultedConfig := defaultConfig(byObjectToConfig(byObject), optionDefaultsToConfig(&opts))
byObject.Label = defaultedConfig.LabelSelector byObject.Label = defaultedConfig.LabelSelector
byObject.Field = defaultedConfig.FieldSelector byObject.Field = defaultedConfig.FieldSelector
byObject.Transform = defaultedConfig.Transform byObject.Transform = defaultedConfig.Transform
byObject.UnsafeDisableDeepCopy = defaultedConfig.UnsafeDisableDeepCopy byObject.UnsafeDisableDeepCopy = defaultedConfig.UnsafeDisableDeepCopy
if isNamespaced && byObject.Namespaces == nil {
byObject.Namespaces = opts.DefaultNamespaces
} }
opts.ByObject[obj] = byObject opts.ByObject[obj] = byObject
@ -498,20 +509,21 @@ func defaultConfig(toDefault, defaultFrom Config) Config {
return toDefault return toDefault
} }
func namespaceAllSelector(namespaces []string) fields.Selector { func namespaceAllSelector(namespaces []string) []fields.Selector {
selectors := make([]fields.Selector, 0, len(namespaces)-1) selectors := make([]fields.Selector, 0, len(namespaces)-1)
sort.Strings(namespaces)
for _, namespace := range namespaces { for _, namespace := range namespaces {
if namespace != metav1.NamespaceAll { if namespace != metav1.NamespaceAll {
selectors = append(selectors, fields.OneTermNotEqualSelector("metadata.namespace", namespace)) selectors = append(selectors, fields.OneTermNotEqualSelector("metadata.namespace", namespace))
} }
} }
return fields.AndSelectors(selectors...) return selectors
} }
func appendIfNotNil[T comparable](a, b T) []T { func appendIfNotNil[T comparable](a []T, b T) []T {
if b != *new(T) { if b != *new(T) {
return []T{a, b} return append(a, b)
} }
return []T{a} return a
} }

Some files were not shown because too many files have changed in this diff Show More